# Self-classifying MNIST Digits

In [None]:
%pip install -U "cax[examples]"

## Import

In [5]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.conv_perceive import ConvPerceive
from cax.core.update.nca_update import NCAUpdate
from cax.nn.pool import Pool
from cax.types import Input, State
from datasets import load_dataset
from flax import nnx
from tqdm.auto import tqdm

## Configuration

In [None]:
# Single seed shared by the JAX PRNG key and the Flax NNX RNG container created below.
seed = 0

# Model hyper-parameters.
channel_size = 20  # channels per cell; the state array is built with this many channels
perception_size = 80  # passed to both ConvPerceive and NCAUpdate
perceive_hidden_layers_sizes = ()
update_hidden_layers_sizes = (80,)
cell_dropout_rate = 0.5  # forwarded to NCAUpdate as `cell_dropout_rate`

# Training hyper-parameters.
pool_size = 1_024  # number of states kept in the training pool
batch_size = 16  # states sampled from the pool per training step
num_steps = 20  # CA steps unrolled per training rollout
learning_rate = 1e-3  # initial value of the linear learning-rate schedule

key = jax.random.key(seed)
rngs = nnx.Rngs(seed)

## Dataset

In [4]:
# Load MNIST from the Hugging Face Hub.
ds = load_dataset("ylecun/mnist")

# Preview the first eight training digits. Rows are sliced *before* the column is
# selected so that only those eight images are decoded, rather than materializing
# the whole "image" column just to keep its first eight entries.
mediapy.show_images(ds["train"][:8]["image"], width=128, height=128)

In [23]:
# RGB palette scaled to [0, 1] and indexed by class id: rows 0-9 are the ten digit
# classes, row 10 is the fallback for digit pixels and row 11 is the background.
color_lookup = jnp.array(
	[
		(128, 0, 0),
		(230, 25, 75),
		(70, 240, 240),
		(210, 245, 60),
		(250, 190, 190),
		(170, 110, 40),
		(170, 255, 195),
		(165, 163, 159),
		(0, 128, 128),
		(128, 128, 0),
		(0, 0, 0),  # Default for digits
		(255, 255, 255),  # Background
	]
) / 255


def create_cells_label(image, label):
	"""Build per-cell one-hot classification targets.

	Cells whose first-channel intensity is at least 0.1 receive the one-hot
	encoding (10 classes) of the image's label; every other cell is all zeros.

	The result is computed with a broadcast multiply instead of boolean-mask
	indexing, so the function can be traced by ``jax.jit`` / ``jax.vmap`` and
	does not need a full-size scratch array.

	Args:
		image: Array of shape ``(..., height, width, channels)``.
		label: Integer label array, one entry per image (a scalar for a single
			unbatched image).

	Returns:
		Float array of shape ``image.shape[:-1] + (10,)``.
	"""
	# Cells that belong to the digit.
	mask = image[..., 0] >= 0.1

	# One-hot vector per image, with two singleton axes so it broadcasts over
	# the spatial dimensions.
	one_hot = jax.nn.one_hot(label, 10)[..., None, None, :]

	# Multiplying by the boolean mask zeroes the target on non-digit cells.
	return one_hot * mask[..., None]


def label_to_color(image, cells_label):
	"""Turn per-cell class scores into RGB colors via ``color_lookup``.

	Args:
		image: Array whose first channel (``image[..., 0]``) is thresholded at
			0.1 to decide which cells belong to the digit.
		cells_label: Per-cell class scores; the last axis is the class axis.

	Returns:
		Array of ``color_lookup`` rows, one per cell.
	"""
	foreground = (image[..., 0] > 0.1).astype(jnp.float32)
	background = 1.0 - foreground

	# Class scores only count on digit cells.
	masked_scores = cells_label * foreground[..., None]

	# Two extra low-weight classes appended after the class scores: the first wins
	# on digit cells where no class score exceeds 0.01, the second wins on
	# background cells (where all class scores were zeroed above).
	fallback = 0.01 * jnp.stack([foreground, background], axis=-1)
	scores = jnp.concatenate([masked_scores, fallback], axis=-1)

	return color_lookup[jnp.argmax(scores, axis=-1)]


def find_different_digits(images, labels, cells_labels):
	"""Collect one color-coded example image per digit class found in ``labels``.

	For each digit 0-9 that occurs in ``labels``, the first matching image is
	rendered with ``label_to_color``. Digits that never occur are skipped.

	Returns:
		List of colored images, ordered by digit.
	"""
	examples = []
	for digit in range(10):
		is_digit = labels == digit
		if not jnp.any(is_digit):
			continue
		# argmax over a boolean array yields the index of its first True entry.
		first = jnp.argmax(is_digit)
		examples.append(label_to_color(images[first], cells_labels[first]))
	return examples


def state_to_color(state):
	"""Render CA state(s) as RGB images.

	The first 10 channels of ``state`` are read as per-class scores and the
	channels from index ``channel_size - 1`` onwards as the input image. The
	coloring itself (threshold at 0.1, fallback classes, palette lookup) is
	delegated to ``label_to_color`` so both functions stay in agreement.

	Args:
		state: Array whose last axis is the channel axis.

	Returns:
		Array of ``color_lookup`` rows, one per cell.
	"""
	probs = state[..., :10]
	image = state[..., channel_size - 1 :]

	# ``label_to_color`` thresholds channel 0 of the image it is given, so pass
	# the last image channel as a single-channel array.
	return label_to_color(image[..., -1:], probs)

In [24]:
# Convert the images to float32, divide by 255 (presumably mapping 8-bit pixels to
# [0, 1] - confirm against the dataset) and append a trailing channel axis.
image_train = (jnp.array(ds["train"]["image"], dtype=jnp.float32) / 255)[..., None]
label_train = jnp.array(ds["train"]["label"])

# Same preprocessing for the test split.
image_test = (jnp.array(ds["test"]["image"], dtype=jnp.float32) / 255)[..., None]
label_test = jnp.array(ds["test"]["label"])

In [25]:
# Per-cell one-hot targets for every training and test image.
cells_label_train = create_cells_label(image_train, label_train)
cells_label_test = create_cells_label(image_test, label_test)

# One color-coded example for each digit class present in the training labels.
colored_digits = find_different_digits(image_train, label_train, cells_label_train)

In [26]:
# Display the color-coded example digits.
mediapy.show_images(colored_digits, width=64, height=64)

In [27]:
def init_state(key):
	"""Create an initial CA state holding a randomly chosen training image.

	Args:
		key: JAX PRNG key used to pick the training image.

	Returns:
		``(state, image_index)``: a state with ``channel_size`` channels that is
		zero everywhere except for the trailing channel, which holds the chosen
		image, and the index of that image in ``image_train``.
	"""
	image_index = jax.random.choice(key, image_train.shape[0])
	image = image_train[image_index]

	# All channels but the last start at zero; the image fills the last one.
	height, width = image_train.shape[1:3]
	blank = jnp.zeros((height, width, channel_size - 1))
	state = jnp.concatenate([blank, image], axis=-1)
	return state, image_index

## Model

In [28]:
# Perception and update modules of the CA; both are built with the shared `rngs`.
perceive = ConvPerceive(channel_size, perception_size, perceive_hidden_layers_sizes, rngs)
update = NCAUpdate(channel_size, perception_size, update_hidden_layers_sizes, rngs, cell_dropout_rate=cell_dropout_rate)


class SelfClassifyingCA(CA):
	"""Cellular automaton that keeps the image channel(s) of its state fixed."""

	@nnx.jit
	def step(self, state: State, input: Input = None) -> State:
		"""Advance the state by one step while preserving the image channel(s).

		Args:
			state: Current state; channels from index ``channel_size - 1`` onwards
				hold the image.
			input: Optional input forwarded to the update module.

		Returns:
			The updated state with its image channel(s) restored to their
			pre-step values.
		"""
		frozen_image = state[..., channel_size - 1 :]

		perception = self.perceive(state)
		updated = self.update(state, perception, input)

		# Drop whatever the update wrote into the image channel(s) and put the
		# original image back.
		return jnp.concatenate([updated[..., : channel_size - 1], frozen_image], axis=-1)


# Assemble the cellular automaton from the perception and update modules.
ca = SelfClassifyingCA(perceive, update)

In [29]:
# Count the scalar entries across all trainable parameter arrays of the CA.
params = nnx.state(ca, nnx.Param)
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))
print("Number of params:", num_params)

Number of params: 22500


## Train

### Pool

In [30]:
# Reserve a fresh subkey for building the pool.
key, subkey = jax.random.split(key)

# One independent key - hence one independently sampled initial state - per pool slot.
keys = jax.random.split(subkey, pool_size)
state, image_index = jax.vmap(init_state)(keys)

# The pool stores each state together with the index of the training image it holds.
pool = Pool.create({"state": state, "image_index": image_index})

### Optimizer

In [31]:
# Learning rate goes linearly from `learning_rate` to 1% of it over 100k steps.
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=0.01 * learning_rate,
	transition_steps=100_000,
)

# Clip the global gradient norm to 1.0, then apply Adam with the schedule above.
optimizer = nnx.Optimizer(
	ca,
	optax.chain(
		optax.clip_by_global_norm(1.0),
		optax.adam(learning_rate=lr_sched),
	),
)

### Loss

In [32]:
def l2(state, cells_label):
	"""Half the summed squared error per sample, averaged over the batch.

	Args:
		state: Array whose first 10 channels are compared against the targets.
		cells_label: Targets with 10 entries on the last axis.

	Returns:
		Scalar: the squared error is summed over the last three axes, halved,
		then averaged over the remaining leading axes.
	"""
	residual = state[..., :10] - cells_label
	per_sample = jnp.square(residual).sum(axis=(-1, -2, -3)) / 2
	return per_sample.mean()


def ce(state, cells_label):
	"""Mean softmax cross-entropy between the class logits and the cell labels.

	``cells_label`` is used directly as the (one-hot) target distribution. Cells
	whose label row is all zeros therefore contribute zero loss, instead of
	being treated as class 0 - which is what taking ``argmax`` of an all-zero
	row would yield.

	Args:
		state: Array whose first 10 channels are the class logits.
		cells_label: Per-cell one-hot targets with 10 entries on the last axis;
			all-zero rows mark cells without a label.

	Returns:
		Scalar: the per-cell cross-entropy averaged over all cells (unlabelled
		cells count as zero in the average).
	"""
	log_probs = jax.nn.log_softmax(state[..., :10], axis=-1)
	per_cell = -jnp.sum(cells_label * log_probs, axis=-1)
	return jnp.mean(per_cell)

In [33]:
@nnx.jit
def loss_fn(ca, state, cells_label):
	"""Unroll the CA on a batch of states and score the result with ``l2``.

	Args:
		ca: The cellular automaton to run.
		state: Batch of initial states; the leading axis is mapped over.
		cells_label: Per-cell targets handed to ``l2``.

	Returns:
		``(loss, state)`` where ``state`` is what the CA returns for each batch
		element after ``num_steps`` steps.
	"""

	def rollout(model, cell_state):
		return model(cell_state, num_steps=num_steps)

	# RNG state is mapped along axis 0 (one split per batch element); all other
	# model state is shared across the batch.
	rng_per_sample = nnx.StateAxes({nnx.RngState: 0, ...: None})
	batched_rollout = nnx.vmap(rollout, in_axes=(rng_per_sample, 0))
	final_state = nnx.split_rngs(splits=batch_size)(batched_rollout)(ca, state)

	return l2(final_state, cells_label), final_state

### Train step

In [34]:
@nnx.jit
def train_step(ca, optimizer, pool, key):
	"""Run one optimization step on a batch sampled from the pool.

	Args:
		ca: The cellular automaton being trained.
		optimizer: Optimizer wrapping ``ca``; updated in place with the gradient.
		pool: Pool of states and their training-image indices.
		key: JAX PRNG key for pool sampling and for the re-initialized states.

	Returns:
		``(loss, pool)``: the batch loss and the pool with the sampled entries
		replaced by their post-rollout states.
	"""
	sample_key, init_state_key = jax.random.split(key)

	# Sample from pool
	pool_index, batch = pool.sample(sample_key, batch_size=batch_size)
	current_state = batch["state"]
	current_image_index = batch["image_index"]

	# A quarter of the batch is replaced with new images. Each replacement gets
	# its own key so the new states hold independently drawn images, rather
	# than one image broadcast over the whole quarter.
	num_new = batch_size // 4
	init_state_keys = jax.random.split(init_state_key, num_new)
	new_state, new_image_index = jax.vmap(init_state)(init_state_keys)
	current_state = current_state.at[:num_new].set(new_state)
	current_image_index = current_image_index.at[:num_new].set(new_image_index)

	# Look up the per-cell targets of the images held by the batch
	current_cells_label = cells_label_train[current_image_index]

	(loss, current_state), grad = nnx.value_and_grad(loss_fn, has_aux=True)(ca, current_state, current_cells_label)
	optimizer.update(grad)

	# Write the evolved states back so later steps continue from them
	pool = pool.update(pool_index, {"state": current_state, "image_index": current_image_index})
	return loss, pool

### Main loop

In [None]:
num_train_steps = 8_192
print_interval = 128

losses = []
pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")

for i in pbar:
	key, subkey = jax.random.split(key)
	loss, pool = train_step(ca, optimizer, pool, subkey)
	losses.append(loss)

	# Refresh the progress-bar readout every `print_interval` steps and on the final step.
	is_last_step = i == num_train_steps - 1
	if i % print_interval != 0 and not is_last_step:
		continue

	# Average over (at most) the last `print_interval` recorded losses.
	recent_losses = losses[-print_interval:]
	avg_loss = sum(recent_losses) / len(recent_losses)
	pbar.set_postfix({"Average Loss": f"{avg_loss:.6f}"})

## Visualize

In [21]:
# Number of freshly initialized digits to roll out and display.
num_examples = 8

key, subkey = jax.random.split(key)

keys = jax.random.split(subkey, num_examples)
new_state, _ = jax.vmap(init_state)(keys)

# Roll out the CA from the freshly sampled states. The RNG state is split once per
# example, so `splits` must equal the number of states mapped over.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
state = nnx.split_rngs(splits=num_examples)(
	nnx.vmap(
		lambda ca, state: ca(state, num_steps=10 * num_steps, all_steps=True),
		in_axes=(state_axes, 0),
	)
)(ca, new_state)

mediapy.show_videos(state_to_color(state), width=128, height=128, codec="gif")