#### Imports & Boilerplate

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from flax import nnx
import imageio
# from IPython.display import Image
from IPython.display import Video
import os

# # Set the default precision to float32. prevents weird precision problems when training some tasks on GPU due to different defaults (maybe useful when using gradients)
# jax.config.update("jax_default_matmul_precision", "float32")
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" # https://docs.jax.dev/en/latest/gpu_memory_allocation.html
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" # https://docs.jax.dev/en/latest/gpu_memory_allocation.html
os.environ["XLA_GPU_STRICT_CONV_ALGORITHM_PICKER"] = "false"

In [12]:
# code for visualizing the steps of an NCA rollout
def visualize(frames, fps=10, scale=5):
    # frames are from CellularAutomataNet.run_until_stable
    # pixels with 4 channels between 0 and 1 that represent RGBA
    
    # Convert frames to numpy arrays if they're JAX arrays
    frames = [np.array(frame) for frame in frames]
    
    # Create a black background
    background_color = np.array([0.0, 0.0, 0.0, 1.0])
    
    # Process frames to properly handle alpha channel
    processed_frames = []
    for frame in frames:
        # Extract RGB and alpha channels
        rgb = frame[..., :3]
        alpha = frame[..., 3:4]
        
        # Alpha blend the RGB with background
        # Formula: result = alpha * foreground + (1 - alpha) * background
        blended = alpha * rgb + (1 - alpha) * background_color[:3]
        
        # Convert to uint8 for imageio
        frame_uint8 = (blended * 255).astype(np.uint8)
        processed_frames.append(frame_uint8)
    
    # Scale up the frames to make pixels more visible
    scaled_frames = []
    for frame in processed_frames:
        # Use simple nearest neighbor scaling to preserve pixel boundaries
        h, w, c = frame.shape
        scaled = np.repeat(np.repeat(frame, scale, axis=0), scale, axis=1)
        scaled_frames.append(scaled)
    
    # Create a temporary file for the video
    filename = "nca.mp4"
    
    # Write frames to the video file
    imageio.mimsave(filename, scaled_frames, fps=fps, codec='libx264', macro_block_size=None)
    
    # Display the video
    display(Video(filename, embed=True))

#### Neural Cellular Automata Architecture

In [13]:
# class for a simple neural network model
class CellularAutomataNet(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):  # input_shape is a tuple (height, width) of the input grid
        super().__init__()
        if rngs is None:
            rngs = nnx.Rngs(0)

        # self.rngs = rngs
        # channels are just RGB for now, but could add channels e.g. for opacity or any other traits or hidden states
        self.num_channels = 4
        # filter with depth num_channels, kernel size 3x3, producing 30 channels
        # NOTE: to be more like the original NCA paper, this should be a depthwise convolution instead (feature_group_count=num_channels)
        self.conv1 = nnx.Conv(in_features=self.num_channels, out_features=self.num_channels, kernel_size=(3, 3), padding='VALID', feature_group_count=1, rngs=rngs, dtype=jnp.float16)
        self.linear1 = nnx.Linear(self.num_channels, self.num_channels, rngs=rngs, dtype=jnp.float16)
        self.linear2 = nnx.Linear(self.num_channels, self.num_channels, rngs=rngs, dtype=jnp.float16)

        # # initial behavior should be "do-nothing", so initialize the weights to be 0
        # self.conv1.kernel *= 0
        # self.linear1.kernel *= 0
        # self.linear2.kernel *= 0
    
    def pad_input(self, x, pad_value=-1):
        """Manually pad input with a custom value."""
        pad_size = 1  # Kernel size (3x3) means 1 pixel padding on each side
        
        if x.ndim == 3:  # (height, width, channels)
            pad_width = [(pad_size, pad_size), (pad_size, pad_size), (0, 0)]
        elif x.ndim == 4:  # (batch, height, width, channels)
            pad_width = [(0, 0), (pad_size, pad_size), (pad_size, pad_size), (0, 0)]
        else:
            raise ValueError(f"Unexpected input shape {x.shape}")

        return jnp.pad(x, pad_width, mode='constant', constant_values=pad_value)
    
    @nnx.jit
    def __call__(self, grid, rng, mask_prob):
        assert grid.shape[-1] == self.num_channels
        initial_grid = grid

        # # padding the input grid with -1
        # # (same as padding='SAME' but not with zero, which already represents black. this way it distinguishes border and black pixels, which can matter for some tasks)
        # # (nevermind.. black is [1, 0, ..., 0]. zero padding is [0, 0, ..., 0])
        grid = self.pad_input(grid, pad_value=-1) # specifically a different value than the black color
        # grid = self.pad_input(grid, pad_value=-0.1)
        # grid = self.pad_input(grid, pad_value=0) # functionally the same as doing padding='SAME'

        # apply the convolutional layer
        grid = self.conv1(grid)
        # apply the activation function
        grid = jax.nn.sigmoid(grid)
        # apply the linear layer
        grid = self.linear1(grid)
        # apply the activation function
        grid = jax.nn.sigmoid(grid)
        # apply the linear layer
        grid = self.linear2(grid)

        # for outputting a delta grid every step: s_(t+1) = f(s_t) + s_t:
        # grid = jnp.tanh(grid) # tanh to keep the values between -1 and 1, could just do no activation function like in "Growing Neural Cellular Automata"

        # stochastic cell update: randomly mask some updates to the grid
        rng, subkey = jax.random.split(rng)
        mask = jax.random.bernoulli(subkey, p=mask_prob, shape=(grid.shape[0], grid.shape[1], 1))

        # mask will insert zeros randomly instead of the original grid values because this is a delta grid
        grid = jnp.where(mask, 0, grid) # mask_prob chance of masking an update
        # grid = jax.nn.sigmoid(grid) # sigmoid to keep the values between 0 and 1, so it can be interpreted as RGB but we don't let hidden states go outside this range for stability

        grid = initial_grid + grid

        # NOTE: maybe stochastic cell update should interpolate so that things are smoother?

        # # mask will insert the original grid values for the masked pixels
        # grid = jnp.where(mask, initial_grid, grid) # mask_prob chance of masking an update

        # interp = jax.random.uniform(subkey, shape=grid.shape, dtype=grid.dtype)
        # grid = jnp.where(mask, interp * grid, grid) # mask_prob chance of masking an update

        # grid = initial_grid + grid

        grid = jnp.clip(grid, 0, 1)  # clip the values to be between 0 and 1

        # Living cell masking. We want to model the growth process that starts with a single cell, and don’t want empty cells to participate in computations or carry any hidden state. We enforce this by explicitly setting all channels of empty cells to zeros. A cell is considered empty if there is no “mature” (alpha>0.1) cell in its 3x3 neightborhood.
        life_mask = nnx.max_pool(initial_grid[:, :, 3:4], window_shape=(3, 3), strides=(1, 1), padding='SAME') > 0.1
        grid = grid * life_mask  # set the channels of empty cells to zeros

        # return the output grid
        return grid, rng

    # one step of the cellular automata
    @nnx.jit
    def update(self, grid, rng, mask_prob):
        hidden_grid, rng = self(grid, rng, mask_prob)
        # # delta grid version
        # grid = jax.nn.softmax(hidden_grid, axis=-1)
        # grid = jnp.argmax(hidden_grid, axis=-1)
        
        grid = hidden_grid  # using hidden_grid directly
        # grid = jnp.clip(hidden_grid, 0, 1)  # clip the values to be between 0 and 1

        return hidden_grid, grid, rng # hidden used for next step, grid is the output

    # mask_prob is the probability of a pixel being update masked, so 0 is global updates / no stochasticity, 0.5 means ~half the pixels are updated per step
    @nnx.jit(static_argnames=["max_steps"])
    def run_until_stable(self, grid, max_steps=100, seed=0, mask_prob=0.5):
        rng = jax.random.PRNGKey(seed)

        # for i in range(max_steps):
        def step_fn(state, step):
            prev_grid, rng = state
            _, new_grid, rng = self.update(prev_grid, rng, mask_prob)  # using hidden_grid directly
            return (new_grid, rng), prev_grid
        
        state = (grid, rng)
        state, grids = jax.lax.scan(step_fn, state, jnp.arange(max_steps))

        return grids # return the grids
    
    @nnx.jit
    def set_params(self, params):
        # set the parameters of the model
        s = 0
        self.conv1.kernel.value = jnp.reshape(params[:self.conv1.kernel.size], self.conv1.kernel.shape)
        s += self.conv1.kernel.size
        self.conv1.bias.value = params[s:s + self.conv1.bias.size]
        s += self.conv1.bias.size
        self.linear1.kernel.value = jnp.reshape(params[s:s + self.linear1.kernel.size], self.linear1.kernel.shape)
        s += self.linear1.kernel.size
        self.linear1.bias.value = params[s:s + self.linear1.bias.size]
        s += self.linear1.bias.size
        self.linear2.kernel.value = jnp.reshape(params[s:s + self.linear2.kernel.size], self.linear2.kernel.shape)
        s += self.linear2.kernel.size
        self.linear2.bias.value = params[s:s + self.linear2.bias.size]
        s += self.linear2.bias.size
        return self


#### fitness function that determines "open-endedness" of an NCA

Courtesy of ChatGPT o3-mini-high :)

This is basically multi-objective. Maybe some ideas relating to the metrics here can become behavioral characteristics to guide a Quality Diversity evolutionary search.

In [14]:
# grids is the images of the rollout before it stabilizes or hits the max steps
@jax.jit
def fitness(grids):
    """
    Fitness function that rewards sustained, organized evolution.
    It encourages coherent movement while penalizing high-frequency random noise.
    
    Assumes:
      - grids shape: (time, height, width, channels)
      - pixel values: floats in [0, 1]
    """
    grids = jnp.array(grids)

    # A simple 3x3 Gaussian blur to filter out high-frequency noise.
    def blur(frame):
        kernel = jnp.array([[1., 2., 1.],
                            [2., 4., 2.],
                            [1., 2., 1.]], dtype=grids.dtype)
        kernel = kernel / jnp.sum(kernel)
        frame = frame[None, ...]  # Add batch dim: (1, H, W, C)
        kernel = kernel[..., None, None]  # (3, 3, 1, 1)
        channels = frame.shape[-1]
        blurred_channels = []
        for c in range(channels):
            channel = frame[..., c:c+1]
            blurred = jax.lax.conv_general_dilated(
                channel, kernel,
                window_strides=(1, 1),
                padding="SAME",
                dimension_numbers=("NHWC", "HWIO", "NHWC")
            )
            blurred_channels.append(blurred)
        blurred = jnp.concatenate(blurred_channels, axis=-1)
        return blurred[0]  # Remove batch dim

    # Preprocess: blur all frames.
    blurred_grids = jax.vmap(blur)(grids)

    # Compute motion differences between consecutive blurred frames.
    motion_diffs = blurred_grids[1:] - blurred_grids[:-1]
    motion_metric = jnp.mean(jnp.abs(motion_diffs))

    # Compute spatial gradients on the motion differences to assess coherence.
    def compute_gradient(frame):
        # Compute gradients along x and y directions.
        grad_x = jnp.abs(frame[:, 1:, :] - frame[:, :-1, :])  # shape (H, W-1, C)
        grad_y = jnp.abs(frame[1:, :, :] - frame[:-1, :, :])  # shape (H-1, W, C)
        # Crop to the overlapping region: (H-1, W-1, C)
        grad_x_cropped = grad_x[:-1, :, :]
        grad_y_cropped = grad_y[:, :-1, :]
        grad_magnitude = jnp.sqrt(grad_x_cropped**2 + grad_y_cropped**2)
        return jnp.mean(grad_magnitude)
    
    gradients = jax.vmap(compute_gradient)(motion_diffs)
    avg_gradient = jnp.mean(gradients)
    
    # Define a target gradient magnitude for organized motion.
    # Too high means random noise; too low might indicate stagnation.
    target_gradient = 0.05  # tweak this based on your observations
    organized_motion_reward = jnp.exp(-((avg_gradient - target_gradient)**2) / 0.001)
    organized_motion = motion_metric * organized_motion_reward

    # Shannon entropy on blurred frames.
    def compute_entropy(frame):
        flat = frame.reshape(-1)
        # Discretize [0, 1] into 256 bins.
        bin_indices = jnp.clip(jnp.floor(flat * 256), 0, 255).astype(jnp.int32)
        hist = jnp.zeros(256, dtype=jnp.int32).at[bin_indices].add(1)
        total = jnp.sum(hist)
        prob = hist / total
        return -jnp.sum(prob * jnp.log(prob + 1e-6))
    
    entropies = jax.vmap(compute_entropy)(blurred_grids)
    global_entropy = jnp.mean(entropies)
    
    # Sustained innovation: average change in entropy over time.
    sustained_innovation = jnp.mean(jnp.abs(entropies[1:] - entropies[:-1]))
    
    # Kolmogorov complexity proxy via compressibility:
    avg_grid = jnp.mean(grids, axis=0)
    avg_grid_blurred = blur(avg_grid)
    avg_grid_entropy = compute_entropy(avg_grid_blurred)
    compressibility = global_entropy - avg_grid_entropy
    target_compressibility = 5.0  # adjust as needed
    kolmogorov_complexity = jnp.exp(-((compressibility - target_compressibility) ** 2))

    # Combine metrics into a final fitness score.
    fitness_score = (
        0.25 * sustained_innovation +
        0.25 * global_entropy +
        0.25 * kolmogorov_complexity +
        0.25 * organized_motion
    )
    
    return fitness_score

@jax.jit
def old_fitness(grids):
    # Assume grids shape is (time, height, width, channels)
    grids = jnp.array(grids)
    
    # 1. Temporal Variation (Behavioral Richness)
    temporal_variance = jnp.mean(jnp.var(grids, axis=0))
    
    # 2. Spatial Variation (Frame-by-frame diversity)
    spatial_variance = jnp.mean(jnp.var(grids, axis=(1, 2)))
    
    # 3. Novelty (Frame-to-frame changes to avoid trivial repetition)
    diff = grids[1:] - grids[:-1]
    novelty = jnp.mean(jnp.sqrt(jnp.sum(diff**2, axis=(1,2,3))))
    
    # 4. Stability (Lyapunov-inspired measure to balance chaos vs. stagnation)
    epsilon = 1e-6
    lyapunov = jnp.mean(jnp.log(jnp.abs(diff) + epsilon))
    target_lyapunov = 0.0  # Ideally, we want moderate sensitivity
    stability = jnp.exp(-jnp.abs(lyapunov - target_lyapunov))
    
    # 5. Structural Complexity (Proxy via entropy on the averaged grid)
    flattened = grids.reshape(grids.shape[0], -1, grids.shape[-1])
    avg_grid = jnp.mean(flattened, axis=0)
    entropy = jnp.mean(jnp.var(avg_grid, axis=-1))
    
    # 6. Pattern Organization (Reward intermediate spatial gradients)
    # Compute gradients on the last frame to capture local coherence.
    last_grid = grids[-1]
    grad_x = jnp.abs(last_grid[:, 1:, :] - last_grid[:, :-1, :])
    grad_y = jnp.abs(last_grid[1:, :, :] - last_grid[:-1, :, :])
    mean_grad = (jnp.mean(grad_x) + jnp.mean(grad_y)) / 2.0
    # Target gradient value is a hyperparameter:
    # too low means uniform; too high means noise. Aim for a "sweet spot".
    target_gradient = 0.5  
    pattern_org = jnp.exp(-jnp.abs(mean_grad - target_gradient))
    
    # Combine metrics with weights tuned to favor organized emergent evolution.
    fitness_score = (
        0.20 * temporal_variance +   # Continuous change without stagnation
        0.15 * spatial_variance +    # Encourage diverse spatial patterns
        0.20 * novelty +             # Reward non-trivial, evolving transitions
        0.15 * stability +           # Balance between chaos and order
        0.10 * entropy +             # Reward complexity that isn't purely random
        0.20 * pattern_org           # Direct incentive for organized local structure
    )
    
    return fitness_score

In [15]:
nca = CellularAutomataNet(None)

# print how many parameters the model has
print(nca)
# for x in jax.tree_util.tree_leaves(nnx.state(nca, nnx.Param)):
#     print(x.shape)
param_count = sum(x.size for x in jax.tree_util.tree_leaves(nnx.state(nca, nnx.Param)))
print(f'The model has {param_count} parameters.')

# # create a random grid of size (height, width, channels)
# grid = jax.random.uniform(jax.random.PRNGKey(0), (100, 100, nca.num_channels))
# grid = grid.at[grid[:, :, 3] < 0.99].set(0)  # set alpha channel to 0 for black pixels
# visualize(nca.run_until_stable(grid, max_steps=100, seed=0, mask_prob=0.5), fps=10, scale=5)

CellularAutomataNet(
  num_channels=4,
  conv1=Conv(
    kernel_shape=(3, 3, 4, 4),
    kernel=Param(
      value=Array(shape=(3, 3, 4, 4), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(4,), dtype=float32)
    ),
    in_features=4,
    out_features=4,
    kernel_size=(3, 3),
    strides=1,
    padding='VALID',
    input_dilation=1,
    kernel_dilation=1,
    feature_group_count=1,
    use_bias=True,
    mask=None,
    dtype=<class 'jax.numpy.float16'>,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x7f70c5badda0>,
    bias_init=<function zeros at 0x7f70d40df9c0>,
    conv_general_dilated=<function conv_general_dilated at 0x7f70d467aac0>
  ),
  linear1=Linear(
    kernel=Param(
      value=Array(shape=(4, 4), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(4,), dtype=float32)
    ),
    in_features=4,
    out_features=4,
    use_bias=True,
    dtype=<class 'jax.numpy.float16'>,

#### basic genetic algorithm to run evolution using fitness function

initial grid matters but for now we will use the same grid across all NCA

In [16]:
pop_size = 64

rng = jax.random.PRNGKey(6)

nca_pop = jax.random.normal(rng, (pop_size, param_count)) * 1 # initialize the population with small random values

initial_grid = jnp.zeros((100, 100, 4), dtype=jnp.float16)
# # pick a random pixel to not be black
# seed_pixel = jax.random.randint(rng, (2,), 0, 100)
# initial_grid = initial_grid.at[seed_pixel[0], seed_pixel[1], 3:4].set(1.0)  # set the pixel to alive
initial_grid = initial_grid.at[50, 50, 3:4].set(1.0)  # set the center pixel to be alive
# initial_grid = jax.random.uniform(rng, (100, 100, 3), dtype=jnp.float16)  # random initial grid (sort of necessary if stochastic cell updates are not used)

@jax.jit
def eval_population(nca_pop, seed=0):
    nca_static = CellularAutomataNet(None)  # create a single static instance for all evaluations
    @jax.jit
    def eval_one(params):
        model = nca_static.set_params(params)  # update parameters of the static model
        grids = model.run_until_stable(
            initial_grid,
            # max_steps=500,
            max_steps=200,
            seed=seed,  # using same seed continuously
            mask_prob=0.5 # stochastic cell update possibly makes finding random noisy NCA too easy
        )
        return grids, fitness(grids)
    # evaluate the population
    grids, fitness_scores = jax.vmap(eval_one)(nca_pop)
    return grids, fitness_scores

# def ask():
#     return nca_pop

# extremely basic crossover and mutation, likely converges easily
def tell(nca_pop, fitness_scores, rng, elite_ratio=0.75, mutation_rate=0.01):
    # sort the population by fitness scores
    sorted_indices = jnp.argsort(fitness_scores, descending=True) # we are maximizing fitness
    sorted_population = nca_pop[sorted_indices]

    # select the top half of the population
    selected_population = sorted_population[:int(pop_size * elite_ratio)]
    
    # create new population by crossover and mutation
    new_population = [elite for elite in selected_population] # keep the elite so fitness doesn't get worse
    for i in range(pop_size - len(selected_population)):
        # Choose parents randomly from the selected population
        rng, parent1_key = jax.random.split(rng)
        # rng, parent2_key = jax.random.split(rng)
        parent1_idx = jax.random.randint(parent1_key, (), 0, pop_size // 2)
        # parent2_idx = jax.random.randint(parent2_key, (), 0, pop_size // 2)
        # parent1 = selected_population[parent1_idx]
        # parent2 = selected_population[parent2_idx]

        # # do slerp crossover, simple average does not work well.
        # # calculate cosine similarity between parents
        # norm_parent1 = parent1 / (jnp.linalg.norm(parent1) + 1e-8)
        # norm_parent2 = parent2 / (jnp.linalg.norm(parent2) + 1e-8)
        # dot_product = jnp.sum(norm_parent1 * norm_parent2)
        # dot_product = jnp.clip(dot_product, -0.9999, 0.9999)  # prevent numerical issues
        
        # # generate random interpolation parameter
        # t = jax.random.uniform(rng, shape=(), minval=0, maxval=1)
        
        # # compute angle between vectors
        # omega = jnp.arccos(dot_product)
        
        # # perform slerp
        # s0 = jnp.sin((1.0 - t) * omega) / jnp.sin(omega)
        # s1 = jnp.sin(t * omega) / jnp.sin(omega)
        # child = s0 * parent1 + s1 * parent2

        # trying with no crossover, just mutation of one parent
        child = selected_population[parent1_idx]
        
        # add mutation
        rng, mutation_key = jax.random.split(rng)
        mutation = jax.random.normal(mutation_key, (param_count,)) * mutation_rate
        child = child + mutation

        rng, _ = jax.random.split(rng)
        new_population.append(child)
    
    return jnp.array(new_population)

#### evolutionary loop

In [17]:
best_fitness = -1
best_params = None
best_grids = None

generations = 500
gens_since_last_improvement = 0
# evolution settings
mutation_rate = 0.03
elite_ratio = 0.25

In [18]:
for gen in range(generations):
    print(f'Generation {gen}')

    # measure population diversity
    diversity = jnp.mean(jnp.std(nca_pop, axis=0)) # simple std of the population to detect if the population is collapsing
    print(f'\tPopulation diversity: {diversity}')

    grids, fitness_scores = eval_population(nca_pop, seed=gen)
    print(f'\tBest fitness: {jnp.max(fitness_scores)}')
    
    # find the best fitness score and corresponding parameters
    max_fitness = jnp.max(fitness_scores)
    if max_fitness > best_fitness:
        best_fitness = max_fitness
        best_params = nca_pop[jnp.argmax(fitness_scores)]
        best_grids = grids[jnp.argmax(fitness_scores)]
        print(f'\tNew best fitness!')
        gens_since_last_improvement = 0
    else:
        gens_since_last_improvement += 1
        print(f'\tNo improvement for {gens_since_last_improvement} generations.')
    
    # update the population
    if gens_since_last_improvement > 10 or diversity < 0.2:
        # if no improvement for 20 generations, increase mutation rate
        mutation_rate += 0.01
        elite_ratio += 0.01
    else:
        mutation_rate = 0.03
        elite_ratio = 0.25
    
    elite_ratio = jnp.clip(elite_ratio, 0.01, 1.0) # keep elite ratio between 0.01 and 1.0
    mutation_rate = jnp.clip(mutation_rate, 0.01, 1.0) # keep mutation rate between 0.01 and 1.0

    nca_pop = tell(nca_pop, fitness_scores, jax.random.PRNGKey(gen), elite_ratio=elite_ratio, mutation_rate=mutation_rate)

    # this seems to prevent out of memory issues on my GPU by letting these get garbage collected
    del grids
    del fitness_scores


Generation 0
	Population diversity: 0.9942861795425415
	Best fitness: 0.5653204321861267
	New best fitness!
Generation 1
	Population diversity: 0.9007865190505981
	Best fitness: 0.6455160975456238
	New best fitness!
Generation 2
	Population diversity: 0.7625061869621277
	Best fitness: 0.6974807977676392
	New best fitness!
Generation 3
	Population diversity: 0.7078838348388672
	Best fitness: 0.7016697525978088
	New best fitness!
Generation 4
	Population diversity: 0.5490154027938843
	Best fitness: 0.702158510684967
	New best fitness!
Generation 5
	Population diversity: 0.5504379868507385
	Best fitness: 0.7564968466758728
	New best fitness!
Generation 6
	Population diversity: 0.04597574844956398
	Best fitness: 0.7784373760223389
	New best fitness!
Generation 7
	Population diversity: 0.054064493626356125
	Best fitness: 0.7865030169487
	New best fitness!
Generation 8
	Population diversity: 0.06376917660236359
	Best fitness: 0.7847586274147034
	No improvement for 1 generations.
Generation 9

KeyboardInterrupt: 

In [19]:
visualize(best_grids, fps=10, scale=5)

  self.pid = _fork_exec(
