# Recurrent Residual Convolutional Neural Network

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, you can install JAX with CUDA 12 GPU support using:

In [None]:
%pip install -U "jax[cuda12]"

Then install CAX from PyPI:

In [None]:
%pip install -U "cax[examples]"

## Imports

In [None]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.conv_perceive import ConvPerceive
from cax.core.state import state_to_alive
from cax.core.update.residual_update import ResidualUpdate
from datasets import load_dataset
from flax import nnx
from tqdm.auto import tqdm

## Configuration

In [16]:
# A single seed drives both the explicit JAX key (data sampling, noise) and the
# nnx RNG streams handed to the modules (parameter init and, presumably, cell dropout).
seed = 0

# Cellular-automaton architecture.
num_spatial_dims = 2
channel_size = 16
perception_size = 64
update_hidden_layer_sizes = (128, 128)
cell_dropout_rate = 0.5

# Training: examples per batch, CA steps per rollout, base learning rate.
batch_size = 8
num_steps = 64
learning_rate = 1e-3

# Random state.
key = jax.random.key(seed)
rngs = nnx.Rngs(seed)

## Dataset

In [17]:
ds = load_dataset("ylecun/mnist")


def _as_float_images(split):
	"""Return every image of `split` as float32 divided by 255, with a trailing channel axis.

	Assumes the stored pixels are 8-bit (0..255) so the result lies in [0, 1] — TODO confirm.
	"""
	pixels = jnp.array(ds[split]["image"], dtype=jnp.float32)
	return (pixels / 255)[..., jnp.newaxis]


image_train = _as_float_images("train")
image_test = _as_float_images("test")

mediapy.show_images(image_train[:8], width=128, height=128)

## Initial state

In [18]:
def add_noise(target, alpha, key):
	"""Blend `target` with standard-normal noise and clip the result to [0, 1].

	Computes ``clip((1 - alpha) * target + alpha * noise, 0, 1)`` where ``noise``
	is drawn from N(0, 1) with the same shape as ``target``. With ``alpha = 0``
	this is the clipped target; with ``alpha = 1`` it is clipped pure noise.

	Args:
		target: Array to corrupt.
		alpha: Mixing weight of the noise (scalar or broadcastable array).
		key: JAX PRNG key used to draw the noise.

	Returns:
		The noisy array, same shape as ``target``, with values in [0, 1].
	"""
	gaussian = jax.random.normal(key, target.shape)
	keep = 1 - alpha
	return jnp.clip(keep * target + alpha * gaussian, 0.0, 1.0)


def init_state(key):
	"""Build one training example: an all-zero CA state seeded with a noisy digit.

	A random training image is drawn and corrupted with `add_noise`, using a
	noise level alpha ~ U[0, 1). The noisy image is written into the last
	channel of an otherwise all-zero state.

	Args:
		key: JAX PRNG key, split into independent keys for choosing the image,
			the noise level and the noise itself.

	Returns:
		A tuple ``(state, image_index)``. ``state`` has shape
		``image_train.shape[1:3] + (channel_size,)``. ``image_index`` indexes
		the clean target image in ``image_train``.
	"""
	sample_key, alpha_key, noise_key = jax.random.split(key, 3)

	# Sample a clean target image.
	image_index = jax.random.choice(sample_key, image_train.shape[0])
	image = image_train[image_index]

	# Corrupt it. Going through add_noise (instead of re-deriving the blend inline)
	# clips the input back into [0, 1], the same range as the clean targets.
	alpha = jax.random.uniform(alpha_key)
	noisy_target = add_noise(image, alpha, noise_key)

	state = jnp.zeros(image_train.shape[1:3] + (channel_size,))
	return state.at[..., -1:].set(noisy_target), image_index

## Model

In [28]:
# Both modules presumably draw their initial parameters from the shared, stateful
# `rngs`, so `perceive` stays constructed before `update` to keep the same init.
perceive = ConvPerceive(
	rngs=rngs,
	channel_size=channel_size,
	perception_size=perception_size,
)
update = ResidualUpdate(
	rngs=rngs,
	num_spatial_dims=num_spatial_dims,
	channel_size=channel_size,
	perception_size=perception_size,
	hidden_layer_sizes=update_hidden_layer_sizes,
	cell_dropout_rate=cell_dropout_rate,
)
# The cellular automaton composes the perception and update modules.
ca = CA(perceive, update)

In [29]:
params = nnx.state(ca, nnx.Param)
# Total number of scalar entries across every nnx.Param leaf of the CA.
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))
print("Number of params:", num_params)

Number of params: 109840


## Train

### Optimizer

In [30]:
# Decay the learning rate linearly to 10% of its initial value over the first
# 2,000 steps; optax holds it at that end value afterwards.
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=0.1 * learning_rate,
	transition_steps=2_000,
)

# Clip the global gradient norm to 1.0 before each Adam step.
tx = optax.chain(
	optax.clip_by_global_norm(1.0),
	optax.adam(learning_rate=lr_sched),
)

# Only Params whose path contains "update" are optimized; other Params
# (presumably those of the perception module) are left untouched.
update_params = nnx.All(nnx.Param, nnx.PathContains("update"))
optimizer = nnx.Optimizer(ca, tx, wrt=update_params)

### Loss

In [31]:
def mse(state, target):
	"""Mean squared error between ``state_to_alive(state)`` and ``target``.

	The mean is taken over every element, including any batch axes.
	"""
	residual = state_to_alive(state) - target
	return jnp.square(residual).mean()

In [32]:
@nnx.jit
def loss_fn(ca, state, target):
	"""Roll the CA forward `num_steps` on each batch element and score it with `mse`.

	Args:
		ca: The cellular-automaton module.
		state: Batch of initial states; the leading axis is vmapped over.
		target: Clean images compared against the final alive channel.

	Returns:
		Scalar mean squared error over the whole batch.
	"""
	# NOTE(review): `ca` is captured by closure here, so unlike the visualize cell
	# no nnx.split_rngs / StateAxes is used; confirm that dropout RNG handling
	# across the batch is as intended.
	def rollout(single_state):
		return ca(single_state, num_steps=num_steps)

	final_state = nnx.vmap(rollout)(state)
	return mse(final_state, target)

### Train step

In [33]:
@nnx.jit
def train_step(ca, optimizer, key):
	"""Run one optimization step on a freshly sampled batch.

	Samples `batch_size` noisy initial states with `init_state` and evaluates
	`loss_fn` against the clean images. It then applies the gradient, taken
	only with respect to the `update_params` subset of `ca`, through `optimizer`.

	Returns:
		The scalar batch loss computed before the parameter update.
	"""
	sample_keys = jax.random.split(key, batch_size)
	initial_state, image_index = jax.vmap(init_state)(sample_keys)
	clean_images = image_train[image_index]

	# Differentiate w.r.t. argument 0 (`ca`), restricted to the filtered Params.
	grad_fn = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, update_params))
	loss, grads = grad_fn(ca, initial_state, clean_images)
	optimizer.update(grads)
	return loss

### Main loop

In [None]:
num_train_steps = 8_192
print_interval = 128

losses = []
pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")
for i in pbar:
	key, subkey = jax.random.split(key)
	loss = train_step(ca, optimizer, subkey)
	losses.append(loss)

	# Refresh the progress-bar average every `print_interval` steps and on the last step.
	should_report = (i % print_interval == 0) or (i == num_train_steps - 1)
	if not should_report:
		continue
	recent = losses[-print_interval:]
	avg_loss = sum(recent) / len(recent)
	pbar.set_postfix({"Average Loss": f"{avg_loss:.3e}"})

## Visualize

In [42]:
# Roll out the trained CA on 8 fresh noisy inputs for twice the training horizon.
key, subkey = jax.random.split(key)

keys = jax.random.split(subkey, 8)
state, target_index = jax.vmap(init_state)(keys)

# vmap axes for the module: RNG state is mapped along axis 0, everything else
# (parameters, etc.) is broadcast to every batch element.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
# split_rngs(splits=8) gives each of the 8 vmapped examples its own RNG stream.
# all_steps=True presumably returns the whole trajectory; the display cell
# treats the result as video frames.
state = nnx.split_rngs(splits=8)(
	nnx.vmap(
		lambda ca, state: ca(state, num_steps=2 * num_steps, all_steps=True),
		in_axes=(state_axes, 0),
	)
)(ca, state)

In [43]:
# Clean target digits, followed by the CA's alive channel over time for each example.
target_images = image_train[target_index]
mediapy.show_images(target_images, width=128, height=128)

alive_frames = jnp.squeeze(state_to_alive(state))
mediapy.show_videos(alive_frames, width=128, height=128, codec="gif")