# Variational Autoencoder (VAE)

In [None]:
from typing import Sequence
import functools
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.linen as nn
import haiku as hk
import optax
import chex
from tqdm import tqdm
from utils import BatchManager, load_dataset, save_samples

## Configuration and Data Loading

In [None]:
# Per-dataset settings, keyed by dataset name; `dataset_name` selects the
# entry used for this run.
dataset_name = 'checkerboard'
dataset_configs = {
    'checkerboard': {'epochs': 500},
    'gaussian_mixture': {'epochs': 500},
    'pinwheel': {'epochs': 500},
    'spiral': {'epochs': 500},
}

# Hyperparameters shared by every dataset are layered onto the selected entry.
# Note that this mutates the dict stored inside `dataset_configs`.
config = dataset_configs[dataset_name]
config.update(
    batch_size=128,
    learning_rate=1e-3,
    beta=0.002,  # weight of the KL term in the loss
    enc_layer_dim=[128, 64],
    dec_layer_dim=[64, 128],
    latent_dim=20,
    output_dim=2,
)
X_train, X_test = load_dataset(dataset_name)

## Model Definition

In [None]:
class Encoder(nn.Module):
    """MLP encoder producing the mean and log-variance of the latent code.

    The input goes through one swish-activated Dense layer per entry of
    `layer_dim`, then through a final Dense layer of width `2 * latent_dim`
    whose output is split in half along the last axis.
    """
    # Widths of the hidden Dense layers, applied in order.
    layer_dim: Sequence[int]
    # Latent size; the final layer emits twice this many features.
    latent_dim: int

    @nn.compact
    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Encode `x` and return the pair `(mean, log_var)`.

        Both outputs share the leading dimensions of `x` and have
        `latent_dim` features on the last axis.
        """
        hidden = x
        for width in self.layer_dim:
            hidden = nn.Dense(width)(hidden)
            hidden = nn.swish(hidden)
        stats = nn.Dense(self.latent_dim * 2)(hidden)
        # First half of the features is the mean, second half the log-variance.
        mean = stats[..., :self.latent_dim]
        log_var = stats[..., self.latent_dim:]
        return mean, log_var

class Decoder(nn.Module):
    """MLP decoder mapping a latent code back to data space.

    Applies one swish-activated Dense layer per entry of `layer_dim`, then a
    final linear Dense layer of width `output_dim` (no activation).
    """
    # Widths of the hidden Dense layers, applied in order.
    layer_dim: Sequence[int]
    # Number of features in the reconstructed output.
    output_dim: int

    @nn.compact
    def __call__(self, z: jax.Array) -> jax.Array:
        """Decode latent code `z` into a reconstruction."""
        hidden = z
        for width in self.layer_dim:
            hidden = nn.swish(nn.Dense(width)(hidden))
        return nn.Dense(self.output_dim)(hidden)

class VAE(nn.Module):
    """Variational autoencoder composed of an `Encoder` and a `Decoder`."""
    # Hidden layer widths of the encoder.
    enc_layer_dim: Sequence[int]
    # Hidden layer widths of the decoder.
    dec_layer_dim: Sequence[int]
    # Dimensionality of the latent code.
    latent_dim: int
    # Dimensionality of the reconstructed output.
    output_dim: int

    def setup(self):
        """Create the encoder and decoder submodules."""
        self.encoder = Encoder(self.enc_layer_dim, self.latent_dim)
        self.decoder = Decoder(self.dec_layer_dim, self.output_dim)

    def __call__(self, x: jax.numpy.ndarray, key: chex.PRNGKey):
        """Encode `x`, draw a latent sample, and decode it.

        Returns the tuple `(x_recon, mean, log_var)`.
        """
        mean, log_var = self.encoder(x)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I), so the
        # randomness comes only from `key` and gradients reach mean/log_var.
        std = jnp.exp(0.5 * log_var)
        eps = jax.random.normal(key, mean.shape)
        z = mean + std * eps
        return self.decoder(z), mean, log_var
    
# Build the VAE from the hyperparameters in `config`. Arguments are passed
# positionally in field order: enc_layer_dim, dec_layer_dim, latent_dim,
# output_dim.
model = VAE(
    config['enc_layer_dim'],
    config['dec_layer_dim'],
    config['latent_dim'],
    config['output_dim'],
)

## Training Preparation
Set up the optimizer, loss functions, and other training utilities.

## Define the VAE loss

In [None]:
# Adam optimizer using the configured step size.
optimizer = optax.adam(config['learning_rate'])
# Key stream seeded with 0; every `next(prng_seq)` call below draws a new key.
prng_seq = hk.PRNGSequence(jax.random.PRNGKey(seed=0))

def vae_loss(params: chex.ArrayTree, batch: jax.Array, key: chex.PRNGKey):
    """Compute the beta-weighted VAE objective for one batch.

    Args:
        params: Model variables passed to `model.apply`.
        batch: Input data; the KL term sums over axis 1, so a 2-D
            (example, feature) layout is expected.
        key: PRNG key used for the latent sampling inside the model.

    Returns:
        Scalar `recon_loss + config['beta'] * kl_div`.
    """
    batch_recon, mean, log_var = model.apply(params, batch, key)

    # Mean squared error averaged over every element of the batch.
    recon_loss = jnp.mean(jnp.square(batch - batch_recon))

    # Regulariser towards a standard normal prior: summed over the latent
    # axis, averaged over examples.
    # NOTE: the closed-form KL(N(mean, var) || N(0, I)) is 0.5 times this
    # quantity; the 0.5 factor is not applied here, so `beta` absorbs it.
    latent_terms = 1 + log_var - jnp.square(mean) - jnp.exp(log_var)
    kl_div = -jnp.mean(jnp.sum(latent_terms, axis=1))

    return recon_loss + config['beta'] * kl_div

@jax.jit
def do_batch_update(batch: jax.Array, params: chex.ArrayTree, opt_state: optax.OptState, key: chex.PRNGKey) -> tuple[float, chex.ArrayTree, optax.OptState]:
    """Run one jit-compiled optimisation step on a single batch.

    Args:
        batch: Training examples for this step.
        params: Current model variables.
        opt_state: Current optimizer state.
        key: PRNG key forwarded to `vae_loss`.

    Returns:
        `(loss, params, opt_state)`: the loss evaluated at the incoming
        parameters, followed by the updated parameters and optimizer state.
    """
    loss_and_grad_fn = jax.value_and_grad(vae_loss)
    loss, grads = loss_and_grad_fn(params, batch, key)
    param_updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)
    return loss, new_params, new_opt_state

## Training Loop

In [None]:
# Initialise the model variables: the first key seeds parameter initialisation,
# the second is the sampling key consumed by the model's forward pass.
params = model.init(next(prng_seq), X_train, next(prng_seq))
opt_state = optimizer.init(params)

# Batch provider over the training set, given its own PRNG key.
bm = BatchManager(X_train, config['batch_size'], key=next(prng_seq))

# Per-epoch loss histories, filled in by the training loop.
train_losses, test_losses = [], []

In [None]:
for epoch in tqdm(range(config['epochs']), "Epoch"):
    # Accumulate the loss of every optimisation step in this epoch.
    batch_loss = 0
    for _ in range(bm.num_batches):
        batch = next(bm)
        train_loss, params, opt_state = do_batch_update(
            batch, params, opt_state, key=next(prng_seq)
        )
        batch_loss += train_loss

    # Evaluate on the full test set with a fresh sampling key.
    test_loss = vae_loss(params, X_test, next(prng_seq))

    # Record the mean training loss for the epoch alongside the test loss.
    train_losses.append(batch_loss / bm.num_batches)
    test_losses.append(test_loss)

## Training Results Visualization

In [None]:
# Apply a moving-average filter to both loss curves.
window_size = 100
window = np.full(window_size, 1.0 / window_size)
train_losses_f, test_losses_f = (
    np.convolve(losses, window, mode='valid')
    for losses in (train_losses, test_losses)
)
# 'valid' mode shortens the curves; shift the x-axis by half a window so the
# smoothed values sit in the middle of the epochs they average.
x_pos = np.arange(len(train_losses_f)) + window_size // 2

# Raw curves first, then the smoothed ones on top.
for curve, label in ((train_losses, 'Train Losses'), (test_losses, 'Test Losses')):
    plt.plot(curve, label=label)
for curve, label in (
    (train_losses_f, 'Smoothed Train Losses'),
    (test_losses_f, 'Smoothed Test Losses'),
):
    plt.plot(x_pos, curve, label=label)
plt.xlabel('Epoch')
plt.ylabel('VAE Loss')
plt.legend()
plt.show()

## Sample Generation and Visualization

In [None]:
def sample_model(params: chex.ArrayTree, key: chex.PRNGKey, num_samples: int):
    """Generate `num_samples` new points by decoding draws from the prior.

    Previously this ignored `num_samples` and returned reconstructions of
    `X_train`; it now samples latent codes z ~ N(0, I) and runs only the
    decoder.

    Args:
        params: Trained model variables.
        key: PRNG key used to draw the latent codes.
        num_samples: Number of points to generate.

    Returns:
        Array with `num_samples` rows of decoder outputs.
    """
    z = jax.random.normal(key, (num_samples, model.latent_dim))
    # `method` lets `apply` call the decoder submodule directly, bypassing
    # the encoder and the VAE's `__call__`.
    return model.apply(
        params, z, method=lambda vae, latents: vae.decoder(latents)
    )

In [None]:
samples = sample_model(params, next(prng_seq), 2000)

# Model output drawn as small dots; the datasets as translucent circles.
plt.scatter(samples[:, 0], samples[:, 1], marker='.', label='Sampled')
for points, split_label in ((X_train, 'Train'), (X_test, 'Test')):
    plt.scatter(points[:, 0], points[:, 1], alpha=0.2, marker='o', label=split_label)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
# Equal aspect ratio and a fixed [-4, 4] window on both axes.
plt.axis('equal')
axis_limits = [-4, 4]
plt.xlim(axis_limits)
plt.ylim(axis_limits)
plt.legend()
plt.show()

In [None]:
# Save model outputs
# Passes the generated `samples` to the project helper `save_samples`, tagged
# with the model name 'vae' and the selected dataset name.
# NOTE(review): where and in what format the samples are stored is decided
# inside `utils.save_samples`, which is not visible here.
save_samples('vae', dataset_name, samples)