# Recurrent Residual Convolutional Neural Network

## Import

In [1]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.conv_perceive import ConvPerceive
from cax.core.state import state_to_alive
from cax.core.update.residual_update import ResidualUpdate
from datasets import load_dataset
from flax import nnx

## Configuration

In [None]:
# Single seed shared by the raw JAX key and the Flax NNX RNG streams below.
seed = 0

# --- Cellular automaton architecture ---
num_spatial_dims = 2
channel_size = 16  # number of channels in each cell of the state grid
perception_size = 64  # passed to both ConvPerceive and ResidualUpdate
num_kernels = 3  # NOTE(review): not referenced by the cells shown here - confirm it is still needed
perceive_hidden_layer_sizes = (perception_size,) * 2
update_hidden_layer_sizes = (128,) * 2
cell_dropout_rate = 0.5  # forwarded to ResidualUpdate as `cell_dropout_rate`

# --- Training hyperparameters ---
batch_size = 8  # number of initial states sampled per training step
num_steps = 128  # CA steps applied per rollout during training
learning_rate = 1e-3  # initial value of the learning-rate schedule

# --- Random number generation ---
rngs = nnx.Rngs(seed)  # consumed by the module constructors
key = jax.random.PRNGKey(seed)  # split explicitly in the training / visualization cells

## Dataset

In [None]:
# Fetch the "ylecun/mnist" dataset through `datasets.load_dataset`.
ds = load_dataset("ylecun/mnist")


def _split_to_array(split):
	"""Return the images of `split` as a float32 array with a trailing channel axis.

	Pixel values are divided by 255 (presumably mapping 8-bit intensities to
	[0, 1] - confirm against the dataset card) and a size-1 axis is appended.
	"""
	pixels = jnp.array(ds[split]["image"], dtype=jnp.float32)
	return jnp.expand_dims(pixels / 255, axis=-1)


image_train = _split_to_array("train")
image_test = _split_to_array("test")

# Preview the first eight training images.
mediapy.show_images(image_train[:8], width=64, height=64)

## Init state

In [6]:
def add_noise(target, alpha, key):
	"""Blend `target` with standard-normal noise and clamp the result to [0, 1].

	Args:
		target: Array to corrupt.
		alpha: Mixing weight of the noise; `1 - alpha` is the weight of `target`.
		key: JAX PRNG key used to draw the noise.

	Returns:
		An array with the shape of `target` whose values lie in [0, 1].
	"""
	gaussian = jax.random.normal(key, target.shape)
	blended = alpha * gaussian + (1 - alpha) * target
	return jnp.clip(blended, 0.0, 1.0)


def init_state(key):
	"""Build one initial CA state from a randomly chosen, noise-corrupted training image.

	The state is all zeros except for its last channel, which holds a blend of
	the chosen image and standard-normal noise. The blend weight is drawn
	uniformly at random. Unlike `add_noise`, the blend is not clipped.

	Args:
		key: JAX PRNG key; split into three independent subkeys.

	Returns:
		A `(state, image_index)` tuple, where `image_index` is the index of the
		clean image in `image_train`.
	"""
	# One independent subkey per random draw.
	index_key, mix_key, gaussian_key = jax.random.split(key, 3)

	# Pick the clean target image.
	image_index = jax.random.choice(index_key, image_train.shape[0])
	clean = image_train[image_index]

	# Corrupt it with a random amount of Gaussian noise.
	alpha = jax.random.uniform(mix_key)
	gaussian = jax.random.normal(gaussian_key, clean.shape)
	noisy_target = alpha * gaussian + (1 - alpha) * clean

	# Empty grid with the image's spatial size and `channel_size` channels.
	grid_shape = image_train.shape[1:3] + (channel_size,)
	blank = jnp.zeros(grid_shape)

	return blank.at[..., -1:].set(noisy_target), image_index

## Model

In [7]:
# Perception module. It is constructed before the update module because both
# draw their initial parameters from the shared `rngs`.
perceive = ConvPerceive(
	channel_size,
	perception_size,
	perceive_hidden_layer_sizes,
	rngs,
)

# Residual update module, with per-cell dropout configured by `cell_dropout_rate`.
update = ResidualUpdate(
	num_spatial_dims,
	channel_size,
	perception_size,
	update_hidden_layer_sizes,
	rngs,
	cell_dropout_rate=cell_dropout_rate,
)

# The cellular automaton combines the perceive and update modules.
ca = CA(perceive, update)

In [8]:
# Collect every `nnx.Param` of the CA and add up the element counts of its leaves.
params = nnx.state(ca, nnx.Param)
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print("Number of params:", num_params)

Number of params: 109840


## Train

### Optimizer

In [9]:
# Learning rate goes linearly from `learning_rate` to a tenth of it over 2,000 steps.
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=0.1 * learning_rate,
	transition_steps=2_000,
)

# Only `nnx.Param` variables whose path contains "update" are optimized.
update_params = nnx.All(nnx.Param, nnx.PathContains("update"))

# Clip gradients by global norm, then apply Adam with the schedule above.
optimizer = nnx.Optimizer(
	ca,
	optax.chain(
		optax.clip_by_global_norm(1.0),
		optax.adam(learning_rate=lr_sched),
	),
	wrt=update_params,
)

### Loss

In [10]:
def mse(state, target):
	"""Return the mean squared error between `state_to_alive(state)` and `target`.

	Args:
		state: CA state(s); reduced with `state_to_alive` before the comparison
			(presumably to the visible channel - confirm in `cax.core.state`).
		target: Array that must broadcast against `state_to_alive(state)`.

	Returns:
		A scalar: the mean of the element-wise squared differences.
	"""
	residual = state_to_alive(state) - target
	return jnp.square(residual).mean()

In [11]:
@nnx.jit
def loss_fn(ca, state, target):
	"""Roll a batch of states forward with the CA and score the result against `target`.

	Args:
		ca: The cellular automaton to apply.
		state: Batch of initial states; the leading axis is mapped over.
		target: Batch of target images passed to `mse`.

	Returns:
		The scalar MSE loss after `num_steps` CA steps.
	"""

	def rollout(single_state):
		# Advance one (unbatched) state by the configured number of steps.
		return ca(single_state, num_steps=num_steps)

	final_state = nnx.vmap(rollout)(state)
	return mse(final_state, target)

### Train step

In [12]:
@nnx.jit
def train_step(ca, optimizer, key):
	"""Run one optimization step on a freshly sampled batch.

	Args:
		ca: The cellular automaton being trained.
		optimizer: The `nnx.Optimizer` wrapping `ca`.
		key: JAX PRNG key used to sample the batch.

	Returns:
		The loss computed for this batch, before the parameter update.
	"""
	# Sample `batch_size` noisy initial states and look up their clean targets.
	sample_keys = jax.random.split(key, batch_size)
	initial_state, target_index = jax.vmap(init_state)(sample_keys)
	target = image_train[target_index]

	# Differentiate only with respect to the parameters selected by `update_params`.
	grad_fn = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, update_params))
	loss, grad = grad_fn(ca, initial_state, target)

	optimizer.update(grad)
	return loss

### Main loop

In [None]:
num_train_steps = 8_192  # total number of optimization steps
log_every = 128  # print the loss once every this many steps

for i in range(num_train_steps):
	# Use a fresh subkey for every step and keep `key` for later splits.
	key, subkey = jax.random.split(key)
	loss = train_step(ca, optimizer, subkey)
	if not i % log_every:
		print(f"Step {i}: loss = {loss}")

## Visualize

In [69]:
# Fresh subkey for the preview batch.
key, subkey = jax.random.split(key)

# Eight independently sampled noisy initial states.
keys = jax.random.split(subkey, 8)
state, _ = jax.vmap(init_state)(keys)


def _rollout_all_steps(single_state):
	"""Run the CA for twice the training horizon with `all_steps=True`.

	Presumably this returns every intermediate state rather than only the
	last one - confirm against the `CA.__call__` documentation.
	"""
	return ca(single_state, num_steps=2 * num_steps, all_steps=True)


state = nnx.vmap(_rollout_all_steps)(state)

In [70]:
# Reduce the rollouts with `state_to_alive`, drop size-1 axes and render them as GIFs.
frames = jnp.squeeze(state_to_alive(state))
mediapy.show_videos(frames, width=128, height=128, codec="gif")