# Open-Ended Cellular Automata

## Import

In [1]:
from functools import partial

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import traverse_util
from flax.training.train_state import TrainState
import optax

from common import cell

In [2]:
# Model hyperparameters (passed to the CA constructor below)
cell_state_size = 16  # number of channels stored per cell
n_perceive_free = 0  # extra learnable perception filters per channel; 0 skips the "perceive_free" conv
update_size = 256  # hidden width of the 1x1 update convolution
fire_rate = 0.5  # a cell is updated when its uniform [0, 1) sample is <= this value

# Training hyperparameters
n_iterations = 64  # CA steps unrolled per training step
batch_size = 32  # independent rollouts per training step
learning_rate = 2e-3  # Adam learning rate

## CA

In [3]:
class CA(nn.Module):
	"""Cellular-automaton update rule: depthwise 3x3 perception followed by a per-cell two-layer 1x1 network.

	The "perceive_frozen" convolution is meant to hold fixed identity/Sobel
	filters (see set_kernel); get_perceive_mask marks its parameters so they
	can be treated separately from the rest.
	"""
	# Number of channels per cell (last axis of the state array).
	cell_state_size: int
	# Number of additional learnable perception filters per channel (0 disables them).
	n_perceive_free: int
	# Hidden width of the 1x1 update network.
	update_size: int
	# Threshold on a uniform [0, 1) sample; cells whose sample is <= this value are updated.
	fire_rate: float

	@nn.compact
	def __call__(self, random_key, x, step_size=1.0):
		"""Advance the cell states by one stochastic update step.

		Args:
			random_key: PRNG key used to sample which cells update this step.
			x: cell states with channels on the last axis (the visible callers
				pass an unbatched (height, width, cell_state_size) array).
			step_size: scalar multiplier applied to the computed update.

		Returns:
			The updated states, multiplied by the combined living mask
			(alive both before and after the update, as reported by
			cell.get_living_mask) and clipped to [0, 1].
		"""
		pre_life_mask = cell.get_living_mask(x)

		# Perceive with depthwise convolution: 3 filters per channel (identity + 2 gradients once set_kernel is applied)
		y = nn.Conv(features=3*self.cell_state_size, kernel_size=(3, 3), padding="SAME", feature_group_count=self.cell_state_size, use_bias=False, name="perceive_frozen")(x)
		if self.n_perceive_free > 0:
			# Optional extra perception filters that keep their learned values
			y_free = nn.Conv(features=self.n_perceive_free*self.cell_state_size, kernel_size=(3, 3), padding="SAME", feature_group_count=self.cell_state_size, use_bias=False, name="perceive_free")(x)
			y = jnp.concatenate([y, y_free], axis=-1)

		# Update: per-cell network built from 1x1 convolutions
		dx = nn.relu(nn.Conv(features=self.update_size, kernel_size=(1, 1))(y))
		dx = nn.Conv(features=self.cell_state_size, kernel_size=(1, 1))(dx) * step_size  # not initialized to zeros
		# One Bernoulli draw per cell (trailing axis of size 1 broadcasts over channels)
		update_mask = jax.random.uniform(random_key, shape=(*x.shape[:-1], 1), minval=0., maxval=1.) <= self.fire_rate
		x = x + dx * update_mask

		post_life_mask = cell.get_living_mask(x)
		life_mask = pre_life_mask & post_life_mask
		# Bounds are passed positionally: the a_min/a_max keyword names are
		# deprecated in recent JAX releases, positional bounds work everywhere.
		return jnp.clip(x * life_mask, 0., 1.)

	@partial(jax.jit, static_argnames=("self",))
	def _get_kernel(self, angle):
		"""Build the fixed perception kernel rotated by `angle` (radians).

		Returns:
			Array of shape (3, 3, 1, 3 * cell_state_size): for every channel an
			identity filter followed by the two rotated Sobel gradient filters.
		"""
		identity = jnp.array([0., 1., 0.])
		identity = jnp.outer(identity, identity)  # 3x3 with a single 1 in the centre
		dx = jnp.outer(jnp.array([1., 2., 1.]), jnp.array([-1., 0., 1.])) / 8.0  # Sobel filter
		dy = dx.T
		c, s = jnp.cos(angle), jnp.sin(angle)
		# Rotate the gradient pair by `angle`, then add the singleton input-channel axis
		kernel = jnp.stack([identity, c*dx-s*dy, s*dx+c*dy], axis=-1)[:, :, None, :]
		# Repeat the 3 filters once per channel to match the depthwise conv layout
		kernel = jnp.tile(kernel, (1, 1, 1, self.cell_state_size))
		return kernel

	def set_kernel(self, params, angle=0.):
		"""Return `params` with the "perceive_frozen" kernel replaced by the fixed filters.

		The input tree is not modified; an updated copy is returned as a plain
		nested dict, so immutable (frozen) parameter dicts are accepted too.

		Args:
			params: variable tree containing params/perceive_frozen/kernel.
			angle: rotation of the gradient filters in radians.

		Raises:
			KeyError: if `params` has no params/perceive_frozen/kernel entry.
		"""
		kernel_path = ("params", "perceive_frozen", "kernel")
		flat_params = traverse_util.flatten_dict(params)
		if kernel_path not in flat_params:
			raise KeyError("/".join(kernel_path))
		flat_params[kernel_path] = self._get_kernel(angle)
		return traverse_util.unflatten_dict(flat_params)

	def get_perceive_mask(self, params):
		"""Return a tree shaped like `params` whose leaves are booleans.

		A leaf is True when its "/"-joined path contains "perceive_frozen"
		and False otherwise.
		"""
		flat_params = traverse_util.flatten_dict(params, sep="/")
		flat_mask = {key: "perceive_frozen" in key for key in flat_params}
		return traverse_util.unflatten_dict(flat_mask, sep="/")

In [4]:
# Root PRNG key for the notebook; it is split before every use
random_key = jax.random.PRNGKey(0)

ca = CA(cell_state_size=cell_state_size, n_perceive_free=n_perceive_free, update_size=update_size, fire_rate=fire_rate)

# Generate random cells_states
random_key, random_subkey = jax.random.split(random_key)
fake_cells_state = jax.random.uniform(random_subkey, (128, 128, cell_state_size), minval=0., maxval=1.)
# Scale the first three channels by channel 3 (presumably RGB premultiplied by alpha -- confirm against cell.to_rgba)
fake_cells_state = fake_cells_state.at[..., :3].set(fake_cells_state[..., :3] * fake_cells_state[..., 3:4])

# random_subkey_1 seeds the parameter initializers; random_subkey_2 is the random_key argument of CA.__call__
random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
params = ca.init(random_subkey_1, random_subkey_2, fake_cells_state)
# Replace the randomly initialized "perceive_frozen" kernel with the fixed filters (default angle 0)
params = ca.set_kernel(params)
# Total number of scalar entries over all parameter arrays
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
print("Number of parameters in CA: ", param_count)

Number of parameters in CA:  17088


## Train state

In [32]:
# Boolean mask from cell.make_ellipse_mask; only the commented-out initialization in the next cell refers to it.
# NOTE(review): presumably centre (64, 64) on a 128x128 grid with radii 32 -- confirm against the helper's signature.
mask = cell.make_ellipse_mask((64, 64), 128, 128, 32, 32)

In [36]:
# Train state
# Optimizer: clip gradients to a global norm of 1.0, then apply Adam
tx = optax.chain(
	optax.clip_by_global_norm(1.0),
	optax.adam(learning_rate=learning_rate),)

# Define cells_states
# All-zero 128x128 grid with every channel of the single cell at (64, 64) set to 1
cells_states = jnp.zeros((128, 128, cell_state_size))
cells_states = cells_states.at[64, 64, :].set(1.)
# Disabled alternative: random states restricted to the ellipse mask
# random_key, random_subkey = jax.random.split(random_key)
# mask = cell.make_ellipse_mask((64, 64), 128, 128, 32, 32)
# cells_states = jax.random.uniform(random_subkey, (128, 128, cell_state_size), minval=0., maxval=1.)
# cells_states = cells_states.at[..., :3].set(cells_states[..., :3] * mask[..., None])
# cells_states = cells_states.at[..., 3].set(mask)

# NOTE: the quantity being optimized ("params" of the train state) is the
# initial cell-state grid itself, not the CA weights; the CA weights come from
# the notebook-level `params` used inside scan_apply.
train_state = TrainState.create(
	apply_fn=ca.apply,
	params=cells_states,
	tx=tx)

# Train
@jax.jit
def loss_f(cell_states_before, cell_states_after):
	"""Mean squared difference between the RGBA views of two cell-state arrays.

	Both inputs are converted with cell.to_rgba; the squared difference is
	averaged over the last three axes, so any leading axes are preserved.
	"""
	rgba_delta = cell.to_rgba(cell_states_after) - cell.to_rgba(cell_states_before)
	return jnp.square(rgba_delta).mean(axis=(-1, -2, -3))

# Container intended for logged losses; nothing in the visible cells appends to it
loss_log = []

@jax.jit
def scan_apply(carry, random_key):
	"""Scan body: advance the carried cell states by one CA step.

	Reads the notebook-level `params` (CA weights) and `train_state.apply_fn`.
	Returns the new states as the next carry and an empty tuple as the
	per-step output.
	"""
	next_states = train_state.apply_fn(params, random_key, carry)
	return next_states, ()

@partial(jax.jit, static_argnames=("n_iterations",))
def train_step(random_key, train_state, n_iterations):
	"""Run one gradient step on the initial cell states held in `train_state.params`.

	The states are copied `batch_size` times, each copy is rolled out for
	`n_iterations` CA steps with its own random keys, and the loss is the
	batch mean of loss_f between the initial and final states.

	Args:
		random_key: PRNG key split into one key per (rollout, step).
		train_state: TrainState whose `params` are the cell states being optimized.
		n_iterations: number of CA steps per rollout (static under jit).

	Returns:
		Tuple of (updated train_state, scalar loss, batch of final cell states).
	"""
	def loss_fn(cells_states_before):
		random_keys = jax.random.split(random_key, batch_size*n_iterations)
		# Only the leading axis is regrouped; the trailing key shape is kept
		# as-is so that both raw uint32 keys (trailing key-data axis) and typed
		# key arrays (no trailing axis) give one valid key per (rollout, step).
		random_keys = jnp.reshape(random_keys, (batch_size, n_iterations, *random_keys.shape[1:]))
		cells_states_before = jnp.repeat(cells_states_before[None, ...], repeats=batch_size, axis=0)

		def rollout(initial_states, step_keys):
			# Unroll the CA for n_iterations steps from one initial state
			return jax.lax.scan(scan_apply, initial_states, step_keys, length=n_iterations)

		cells_states_after, _ = jax.vmap(rollout)(cells_states_before, random_keys)
		return loss_f(cells_states_before, cells_states_after).mean(), cells_states_after

	# The final states are returned as auxiliary output alongside the loss
	(loss, cells_states_after,), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state.params)
	train_state = train_state.apply_gradients(grads=grads)

	return train_state, loss, cells_states_after

In [37]:
from common.utils import jnp2pil


# Optimize the initial cell states for up to 8000 steps, reporting every 100 steps
for i in range(8000):
    # Three-way split; only random_subkey_2 is consumed below (random_subkey_1 is unused)
    random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
    # The returned batch of final states rebinds the notebook-level cells_states
    train_state, loss, cells_states = train_step(random_subkey_2, train_state, int(n_iterations))

    if i % 100 == 0:
        print("Loss:", loss)
        # train_state.params holds the current optimized initial cell states
        image = cell.to_rgb(train_state.params)
        # save image
        # NOTE(review): presumably the output directory must already exist -- confirm
        jnp2pil(image).save("/project/output/image_{}.png".format(i))

Loss: 0.0068814633
Loss: 0.051341347
Loss: 0.32576028
Loss: 0.36982852
Loss: 0.22431275
Loss: 0.14005688
Loss: 0.087659374
Loss: 0.05493854
Loss: 0.034628578
Loss: 0.022136461
Loss: 0.014506729
Loss: 0.009813476
Loss: 0.0069492217
Loss: 0.005126725
Loss: 0.0039826455
Loss: 0.0032079953
Loss: 0.002634244
Loss: 0.0022537443
Loss: 0.0019717831
Loss: 0.001772495
Loss: 0.0016112648
Loss: 0.0014501424
Loss: 0.0013363919
Loss: 0.0012354015
Loss: 0.0011600284
Loss: 0.001076618
Loss: 0.001021172
Loss: 0.0009484771
Loss: 0.00088432024
Loss: 0.00081746094
Loss: 0.0007799427
Loss: 0.00075953826
Loss: 0.0007199351
Loss: 0.0006797118
Loss: 0.000642028
Loss: 0.0006243314
Loss: 0.0005899579
Loss: 0.00058605237
Loss: 0.0005426289
Loss: 0.0005248283
Loss: 0.0004880961
Loss: 0.000492913
Loss: 0.000465012
Loss: 0.0004721097
Loss: 0.00044435123
Loss: 0.0004253455
Loss: 0.00042433495
Loss: 0.0004025597
Loss: 0.0003865855
Loss: 0.00035615283
Loss: 0.0003601134
Loss: 0.00035608985
Loss: 0.00033351386
Loss: 0.

KeyboardInterrupt: 

In [31]:
from common.cell import make_ellipse_mask
from matplotlib import pyplot as plt


# Bare expression so the notebook displays the resulting boolean mask array
make_ellipse_mask((64, 64), 128, 128, 32, 32)

Array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]], dtype=bool)