In [None]:
from IPython.display import clear_output
import numpy as np
import numba
import time
from matplotlib import pyplot as plt

%matplotlib inline

In [None]:
# Board size in cells; used as the (first-axis, second-axis) shape of every board below.
LED_COLS, LED_ROWS = 8, 8

In [None]:
@numba.njit(fastmath=True)
def calculate_state_next(state, state_next):
    """Write the next Game of Life generation of ``state`` into ``state_next``.

    Cells outside the grid are treated as dead (no wrap-around). A cell whose
    value is 0 is dead; any other value is alive. The neighbour count is the
    sum of the raw neighbour values, so cells are expected to be 0/1.
    ``state`` is only read; ``state_next`` (same shape) is fully overwritten.
    """
    n_x, n_y = state.shape
    for i in range(n_x):
        # Clip the 3x3 neighbourhood to the grid along the first axis.
        i_lo = max(i - 1, 0)
        i_hi = min(i + 1, n_x - 1)
        for j in range(n_y):
            j_lo = max(j - 1, 0)
            j_hi = min(j + 1, n_y - 1)

            live = 0
            for a in range(i_lo, i_hi + 1):
                for b in range(j_lo, j_hi + 1):
                    live += state[a, b]
            # The window includes the centre cell; a cell is not its own neighbour.
            live -= state[i, j]

            if state[i, j] != 0:
                survives = live == 2 or live == 3   # survival rule
            else:
                survives = live == 3                # birth rule
            state_next[i, j] = 1 if survives else 0

@numba.njit(fastmath=True)  # not parallel itself; update_states parallelises across boards
def update_state(state, steps=1):
    """Advance ``state`` by ``steps`` generations, in place (returns None).

    Two buffers are alternated: each generation is written from ``current``
    into ``scratch`` and the roles are then swapped, so only one extra array
    is allocated regardless of ``steps``.
    """
    current = state
    scratch = np.empty_like(state)
    for _ in range(steps):
        calculate_state_next(current, scratch)
        current, scratch = scratch, current
    # After an odd number of swaps the newest generation lives in the extra
    # buffer; copy it back so the caller's array holds the result.
    if steps % 2 == 1:
        state[:] = current[:]

@numba.njit(fastmath=True, parallel=True)
def update_states(states, steps=1):
    """Advance every board ``states[i]`` by ``steps`` generations, in place.

    Boards are indexed along the first axis and are independent, so the
    outer loop is a ``numba.prange`` that numba may distribute across threads.
    """
    for i in numba.prange(len(states)):
        update_state(states[i], steps)

In [None]:
def plot_state(state):
    """Draw ``state`` as a grid of coloured cells, replacing the previous cell output.

    ``clear_output(wait=True)`` defers clearing until the new figure is ready,
    which keeps repeated calls from flickering.
    """
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.pcolormesh(state, cmap='nipy_spectral', edgecolors='k', lw=6, vmin=0, vmax=1.5)
    ax.set_axis_off()
    plt.show()

# Animate a random board. The loop is bounded so that "Restart & Run All"
# reaches the benchmark cells below; interrupt the kernel to stop early.
N_FRAMES = 30          # number of generations to display
FRAME_DELAY_SEC = 1.0  # pause between frames

state = np.random.randint(2, size=(LED_COLS, LED_ROWS), dtype=np.uint8)
try:
    for _ in range(N_FRAMES):
        update_state(state, steps=1)
        plot_state(state)
        time.sleep(FRAME_DELAY_SEC)
except KeyboardInterrupt:
    # A manual interrupt is the intended way to end the animation early.
    pass

In [None]:
# Benchmark one board: wall time of update_state versus the number of steps.
steps = np.power(2, np.arange(1, 16 + 1)).astype(np.uint32)

# Warm-up: numba compiles update_state once per argument-type signature.
# Calling it with the same types as the loop (uint8 2-D board, uint32 steps)
# keeps JIT compilation time out of the first measured point.
update_state(np.zeros((LED_COLS, LED_ROWS), dtype=np.uint8), steps[0])

dt = []
for steps_ in steps:
    # Build the board outside the timed region so only the update is measured.
    state = np.random.randint(2, size=(LED_COLS, LED_ROWS), dtype=np.uint8)

    time0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    update_state(state, steps_)
    dt_ = time.perf_counter() - time0

    dt.append(dt_)
    if steps_ == steps[-1]:
        print('%d takes = %g sec' % (steps_, dt_))
plt.plot(steps, dt, 'o-')
plt.xscale('log');  plt.yscale('log')
plt.xlabel('steps');  plt.ylabel('time [sec]')
plt.title('update_state on a %dx%d board' % (LED_COLS, LED_ROWS))
plt.show()

# Historical timings for 32768 steps from earlier runs and implementations:
# python code: 32768 takes = 12.5371 sec
# @numba.njit: 32768 takes = 0.0366704 sec
# @numba.njit: 32768 takes = 0.037045 sec
# @numba.njit(fastmath=True): 32768 takes = 0.0372975 sec
# loop inside numba: 32768 takes = 0.0162888 sec

In [None]:
# Benchmark many boards advanced together by update_states (numba.prange).
steps = 2**16
n_states = np.power(2, np.arange(1, 11)).astype(np.uint32)

# Warm-up: compile the parallel kernel for this signature (3-D uint8 stack,
# integer steps) before timing. Otherwise the first point (n_states=2) is
# dominated by JIT compilation.
update_states(np.zeros((1, LED_COLS, LED_ROWS), dtype=np.uint8), 1)

dt = []
for n_states_ in n_states:
    # Build the boards outside the timed region so only the update is measured.
    states = np.random.randint(2, size=(n_states_, LED_COLS, LED_ROWS), dtype=np.uint8)

    time0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    update_states(states, steps)
    dt_ = time.perf_counter() - time0

    dt.append(dt_)
    print('n_states=%d takes = %g sec' % (n_states_, dt_))
plt.plot(n_states, dt, 'o-')
plt.xscale('log');  plt.yscale('log')
plt.xlabel('n_states');  plt.ylabel('time [sec]')
plt.title('update_states, %d steps per board' % steps)
plt.show()