# Growing Unsupervised Neural Cellular Automata

## Imports

In [1]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.depthwise_conv_perceive import DepthwiseConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.state import state_to_alive
from cax.core.update.nca_update import NCAUpdate
from cax.nn.pool import Pool
from cax.nn.vae import Encoder
from datasets import load_dataset
from flax import nnx

## Configuration

In [None]:
# Random seed shared by the JAX PRNG key and the Flax NNX Rngs below.
seed = 0

# Encoder: input spatial size, conv feature sizes (the leading 1 is presumably
# the single MNIST input channel — confirm against cax's Encoder), latent width.
spatial_dims = (28, 28)
features = (1, 32, 32)
latent_size = 8

# Cellular automaton: channels per cell, perception kernels per channel,
# hidden width of the update MLP, and the cell dropout rate given to NCAUpdate.
channel_size, num_kernels, hidden_size = 32, 3, 256
cell_dropout_rate = 0.5

# Training: pool capacity, batch size, CA steps per rollout, peak learning rate.
pool_size = 1024
batch_size = 8
num_steps = 64
learning_rate = 0.001

key = jax.random.PRNGKey(seed)
rngs = nnx.Rngs(seed)

## Dataset

In [4]:
ds = load_dataset("ylecun/mnist")


def _images_to_array(split):
	"""Stack the images of ``split`` as float32 divided by 255, adding a trailing channel axis."""
	pixels = jnp.array(ds[split]["image"], dtype=jnp.float32) / 255
	return pixels[..., None]


image_train = _images_to_array("train")
image_test = _images_to_array("test")

mediapy.show_images(image_train[:8], width=64, height=64)

## Initial state

In [11]:
def init_state(key):
	"""Build one seed state and draw a random training-target index.

	The state has the spatial size of the training images and ``channel_size``
	channels. It is all zeros except for a single 1.0 in the last channel at
	the centre cell.

	Args:
		key: JAX PRNG key used to pick the target index.

	Returns:
		``(state, target_index)``, where ``target_index`` indexes ``image_train``.
	"""
	height, width = image_train.shape[1:3]
	state = jnp.zeros((height, width, channel_size))
	state = state.at[height // 2, width // 2, -1].set(1.0)
	target_index = jax.random.choice(key, image_train.shape[0])
	return state, target_index

## Model

In [12]:
# Build the perception, update and encoder modules.
# NOTE(review): all three constructors receive the shared, stateful `rngs`.
# Presumably each draws its initial parameters from it in call order, so keep
# this statement order fixed to preserve the initial weights.
perceive = DepthwiseConvPerceive(channel_size, rngs)
# The second argument is presumably the update network's per-cell input width:
# perception features (num_kernels filters per channel) plus the latent code.
# Confirm against NCAUpdate's signature.
update = NCAUpdate(
	channel_size, latent_size + num_kernels * channel_size, (hidden_size,), rngs, cell_dropout_rate=cell_dropout_rate
)
encoder = Encoder(spatial_dims, features, latent_size, rngs)


class UnsupervisedCA(CA):
	"""Cellular automaton that also owns a VAE-style encoder for its conditioning latent.

	The perceive and update modules are forwarded to the ``CA`` base class.
	This subclass adds only the ``encoder`` attribute and :meth:`encode`.
	"""

	encoder: Encoder

	def __init__(self, perceive, update, encoder):
		super().__init__(perceive, update)
		self.encoder = encoder

	def encode(self, target, key):
		"""Sample a latent code for ``target`` with the reparameterisation trick.

		Args:
			target: Batch of images passed to ``self.encoder``.
			key: JAX PRNG key for the Gaussian noise.

		Returns:
			``mean + eps * exp(0.5 * logvar)``, with ``eps`` standard normal of ``mean``'s shape.
		"""
		mean, logvar = self.encoder(target)
		std = jnp.exp(0.5 * logvar)
		eps = jax.random.normal(key, mean.shape)
		return mean + eps * std

In [13]:
# Fixed perception filters: the identity kernel followed by the 2-D gradient
# kernels. The set is repeated once per channel along the last axis, and a
# singleton input axis is inserted for the depthwise convolution.
kernel = jnp.concatenate([identity_kernel(ndim=2), grad_kernel(ndim=2)], axis=-1)
kernel = jnp.tile(kernel, channel_size)[..., None, :]
perceive.depthwise_conv.kernel = nnx.Param(kernel)

In [14]:
# Assemble the conditioned CA from the modules configured above.
ca = UnsupervisedCA(perceive=perceive, update=update, encoder=encoder)

In [15]:
# Count every trainable-type (nnx.Param) array in the model, including the
# perception kernel.
params = nnx.state(ca, nnx.Param)
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print("Number of params:", num_params)

Number of params: 2530832


## Train

### Pool

In [17]:
key, subkey = jax.random.split(key)

# One independent key per pool slot. Every slot starts from the centre seed
# built by init_state, with its own randomly drawn training target. init_state
# already takes a single key, so it is vmapped directly (no lambda wrapper),
# as in the visualization cells.
keys = jax.random.split(subkey, pool_size)
state, target_index = jax.vmap(init_state)(keys)

pool = Pool.create({"state": state, "target_index": target_index})

### Optimizer

In [18]:
# Linear decay from `learning_rate` to 10% of it over 50k optimizer steps.
# NOTE(review): the main loop below runs only 8192 steps, so the schedule stops
# at roughly 0.85 * learning_rate and never reaches its end value. Confirm this
# is intended.
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=0.1 * learning_rate,
	transition_steps=50_000,
)

# Clip the global gradient norm to 1.0 before each Adam step.
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(learning_rate=lr_sched))

# Optimise only Params whose path contains "update" or "encoder"; the
# perception kernel set above is not selected.
grad_params = nnx.All(nnx.Param, nnx.Any(nnx.PathContains("update"), nnx.PathContains("encoder")))
optimizer = nnx.Optimizer(ca, tx, wrt=grad_params)

### Loss

In [19]:
def mse(state, target):
	"""Mean squared error between ``state_to_alive(state)`` and ``target``, averaged over all elements."""
	residual = state_to_alive(state) - target
	return jnp.mean(jnp.square(residual))

In [20]:
@nnx.jit
def loss_fn(ca, state, target, key):
	"""Roll the CA out from ``state`` conditioned on ``target`` and score one random late step.

	Each target is encoded to a latent code. The CA then runs ``num_steps``
	steps per sample with ``all_steps=True``, and the step axis (axis 1 of the
	result) is indexed. Each sample gets its own random step in
	``[num_steps // 2, num_steps)``.

	Returns:
		``(loss, state)``. ``loss`` is ``mse`` at the selected steps, and
		``state`` holds the selected state per sample.
	"""
	enc_key, randint_key = jax.random.split(key)
	latent = ca.encode(target, enc_key)

	def rollout(cell_state, cell_latent):
		return ca(cell_state, cell_latent, num_steps=num_steps, all_steps=True)

	trajectory = nnx.vmap(rollout)(state, latent)

	batch = trajectory.shape[0]
	step_index = jax.random.randint(randint_key, (batch,), num_steps // 2, num_steps)
	picked = trajectory[jnp.arange(batch), step_index]

	return mse(picked, target), picked

### Train step

In [21]:
@nnx.jit
def train_step(ca, optimizer, pool, key):
	"""Run one optimisation step on a batch sampled from the pool.

	The sampled batch is ordered from highest to lowest current loss. The
	highest-loss entry is replaced by a fresh seed with a new random target.
	Parameters selected by ``grad_params`` are then updated, and the resulting
	states are written back to the same pool slots.

	Returns:
		``(loss, pool)``: the batch loss and the updated pool.
	"""
	sample_key, init_state_key, loss_key = jax.random.split(key, 3)

	slots, batch = pool.sample(sample_key, batch_size=batch_size)
	states = batch["state"]
	target_ids = batch["target_index"]

	# Order the batch worst-first by its current error against its targets.
	per_sample_loss = jax.vmap(mse)(states, image_train[target_ids])
	order = jnp.argsort(per_sample_loss, descending=True)
	slots, states, target_ids = slots[order], states[order], target_ids[order]

	# Reseed the worst sample (now at position 0) with a brand-new target.
	seed_state, seed_target_id = init_state(init_state_key)
	states = states.at[0].set(seed_state)
	target_ids = target_ids.at[0].set(seed_target_id)

	grad_fn = nnx.value_and_grad(loss_fn, has_aux=True, argnums=nnx.DiffState(0, grad_params))
	(loss, states), grads = grad_fn(ca, states, image_train[target_ids], loss_key)
	optimizer.update(grads)

	pool = pool.update(slots, {"state": states, "target_index": target_ids})
	return loss, pool

### Main loop

In [None]:
# Each iteration draws a fresh subkey and advances the pool by one batch.
for i in range(8192):
	key, subkey = jax.random.split(key)
	loss, pool = train_step(ca, optimizer, pool, subkey)
	# Report every 128 iterations, starting with the first.
	if not i % 128:
		print(f"Step {i}: loss = {loss}")

## Visualize

In [59]:
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 8)
state, image_index = jax.vmap(init_state)(keys)
target = image_train[image_index]

key, subkey = jax.random.split(key)
target_enc = ca.encode(target, subkey)

# Run for 2 * num_steps, i.e. past the rollout length used during training.
state = nnx.vmap(lambda s, z: ca(s, z, num_steps=2 * num_steps, all_steps=True))(state, target_enc)

mediapy.show_images(target, width=128, height=128)
mediapy.show_videos(jnp.squeeze(state_to_alive(state)), width=128, height=128, codec="gif")

### Interpolation

In [57]:
# Pick two training images to interpolate between. Sampling without replacement
# guarantees the two endpoints are different images; the default
# (replace=True) could draw the same index twice.
key, subkey = jax.random.split(key)
image_index = jax.random.choice(subkey, image_train.shape[0], shape=(2,), replace=False)
image = image_train[image_index]

key, subkey = jax.random.split(key)
target_enc = ca.encode(image, subkey)

In [58]:
alphas = jnp.linspace(0.0, 1.0, 8)
# Blend the two latent codes for every alpha in one broadcast expression.
# alpha = 0 gives target_enc[1] and alpha = 1 gives target_enc[0]. alphas are
# reshaped so they broadcast over however many latent dimensions there are.
weights = alphas.reshape((-1,) + (1,) * target_enc[0].ndim)
target_encs = weights * target_enc[0] + (1 - weights) * target_enc[1]

key, subkey = jax.random.split(key)
# One seed per interpolation frame. The count is tied to `alphas` so both vmap
# inputs always share the same batch size.
keys = jax.random.split(subkey, alphas.shape[0])
state, _ = jax.vmap(init_state)(keys)
state = nnx.vmap(lambda s, z: ca(s, z, num_steps=num_steps))(state, target_encs)

mediapy.show_images(state_to_alive(state), width=128, height=128)