# Variational Autoencoder

## Import

In [1]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.nn.vae import VAE, vae_loss
from datasets import load_dataset
from flax import nnx

## Configuration

In [None]:
# Single seed used for both the explicit JAX key and the NNX rng container below.
seed = 42

# Arguments forwarded to the VAE constructor in the Model cell.
# NOTE(review): `features` presumably lists channel counts, starting with the
# single input channel — confirm against cax.nn.vae.VAE.
spatial_dims = (28, 28)
features = (1, 32, 32)
latent_size = 8

# Training hyper-parameters: minibatch size and the initial value of the
# learning-rate schedule.
batch_size = 32
learning_rate = 1e-2

# `key` is split explicitly in the training and visualization cells;
# `rngs` is handed to the VAE constructor.
key = jax.random.PRNGKey(seed)
rngs = nnx.Rngs(seed)

## Dataset

In [3]:
# Download (or load from the local cache) the MNIST train/test splits.
ds = load_dataset("ylecun/mnist")

# Convert each split to a float32 array, divide by 255 to rescale the 8-bit
# pixel values, and append a trailing axis of size 1.
image_train = jnp.expand_dims(jnp.array(ds["train"]["image"], dtype=jnp.float32) / 255, axis=-1)
image_test = jnp.expand_dims(jnp.array(ds["test"]["image"], dtype=jnp.float32) / 255, axis=-1)

# Preview the first eight training digits. Rows are sliced *before* the column
# is selected so that only those eight images are decoded, instead of
# materialising the whole training column a second time just to keep eight.
mediapy.show_images(ds["train"][:8]["image"], width=64, height=64)

## Model

In [6]:
# Instantiate the model from the values set in the Configuration cell;
# `rngs` is passed to the constructor (presumably for parameter initialisation — confirm).
vae = VAE(spatial_dims, features, latent_size, rngs)

In [7]:
# Extract only the trainable parameters (nnx.Param) from the model state and
# report the total number of scalar entries across all of their leaves.
params = nnx.state(vae, nnx.Param)
print("Number of params:", sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))

Number of params: 2518513


## Train

### Optimizer

In [8]:
# Learning rate moves linearly from `learning_rate` to 1% of it over 8,192
# optimizer steps (the same count as the main training loop).
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=0.01 * learning_rate,
	transition_steps=8_192,
)

# Gradients are clipped to a global norm of 1.0 before the Adam update; the
# resulting transformation is wrapped in an NNX optimizer bound to the model.
optimizer = nnx.Optimizer(
	vae,
	optax.chain(
		optax.clip_by_global_norm(1.0),
		optax.adam(learning_rate=lr_sched),
	),
)

### Loss

In [9]:
@nnx.jit
def loss_fn(vae, image, key):
	"""Run the VAE on `image` and return the loss computed by `vae_loss`.

	Args:
		vae: Model to evaluate; called as `vae(image, key)`.
		image: Input passed to the model and used as the reconstruction target.
		key: PRNG key forwarded to the model call.

	Returns:
		The value of `vae_loss(reconstruction, image, mean, logvar)`.
	"""
	outputs = vae(image, key)
	recon, mu, log_var = outputs
	return vae_loss(recon, image, mu, log_var)

### Train step

In [10]:
@nnx.jit
def train_step(vae, optimizer, key):
	"""Perform one optimisation step on a randomly drawn minibatch.

	Reads `image_train` and `batch_size` from module scope.

	Args:
		vae: Model being trained.
		optimizer: Optimizer whose `update` is called with the gradients.
		key: PRNG key; split into one key for batch sampling and one for the loss.

	Returns:
		The loss evaluated on the minibatch before the parameter update.
	"""
	batch_key, model_key = jax.random.split(key)

	# Draw `batch_size` row indices from the training set and gather the batch.
	num_train = image_train.shape[0]
	batch_index = jax.random.choice(batch_key, num_train, shape=(batch_size,))
	batch = image_train[batch_index]

	# Differentiate the loss with respect to the model (first argument).
	grad_fn = nnx.value_and_grad(loss_fn)
	loss, grad = grad_fn(vae, batch, model_key)
	optimizer.update(grad)

	return loss

### Main loop

In [None]:
# Run 8,192 training steps, reporting the loss on every 128th step.
for i in range(8_192):
	# Carry `key` forward and hand a fresh subkey to each step.
	key, subkey = jax.random.split(key)
	loss = train_step(vae, optimizer, subkey)
	if i % 128:
		continue
	print(f"Step {i}: loss = {loss}")

## Visualize

In [34]:
# Draw one latent vector of length `latent_size` from a standard normal,
# pass it to `vae.generate`, and display the result.
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (latent_size,))
image = vae.generate(z)

mediapy.show_image(image, width=128, height=128)

In [27]:
# Pick eight test images at random.
key, subkey = jax.random.split(key)
image_index = jax.random.choice(subkey, image_test.shape[0], shape=(8,))
image = image_test[image_index]

# One PRNG key per selected image, so the model can be mapped over the batch
# axis with an independent key for each example.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, image.shape[0])
image_recon, _, _ = nnx.vmap(vae)(image, keys)

# Top row: originals. Bottom row: model outputs passed through a sigmoid
# before display.
mediapy.show_images(image, width=64, height=64)
mediapy.show_images(jax.nn.sigmoid(image_recon), width=64, height=64)