# Variational AutoEncoder with JAX

## Setup

In [17]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np
import matplotlib.pyplot as plt
from flax.training import train_state
import tensorflow_datasets as tfds

In [18]:
# Load and preprocess the MNIST dataset
def load_mnist():
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    ds = ds_builder.as_dataset(split='train')
    mnist_digits = np.array([np.array(x['image'], dtype=np.float32) / 255.0 for x in ds])
    return mnist_digits.reshape(-1, 28, 28, 1)

## Create a sampling layer

In [19]:
class Sampling(nn.Module):
    """Uses (z_mean, z_log_var) to sample z."""

    @nn.compact
    def __call__(self, z_mean, z_log_var):
        batch_size = z_mean.shape[0]
        dim = z_mean.shape[1]
        epsilon = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim))  # Use a different key each time
        return z_mean + jnp.exp(0.5 * z_log_var) * epsilon


## Build the encoder

In [32]:
class Encoder(nn.Module):
    latent_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Conv(32, (3, 3), strides=(2, 2), padding='SAME')(x))
        x = nn.relu(nn.Conv(64, (3, 3), strides=(2, 2), padding='SAME')(x))
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = nn.relu(nn.Dense(16)(x))
        z_mean = nn.Dense(self.latent_dim)(x)
        z_log_var = nn.Dense(self.latent_dim)(x)
        z = Sampling()(z_mean, z_log_var)  # Sample z here
        return z_mean, z_log_var, z

## Build the decoder

In [38]:
class Decoder(nn.Module):
    latent_dim: int

    @nn.compact
    def __call__(self, z):
        x = nn.relu(nn.Dense(7 * 7 * 64)(z))
        x = x.reshape((x.shape[0], 7, 7, 64))  # Reshape back
        x = nn.relu(nn.ConvTranspose(64, (3, 3), strides=(2, 2), padding='SAME')(x))
        x = nn.relu(nn.ConvTranspose(32, (3, 3), strides=(2, 2), padding='SAME')(x))
        return nn.sigmoid(nn.ConvTranspose(1, (3, 3), padding='SAME')(x))

## Define the VAE as a `Model` with a custom `train_step`

In [39]:
class VAE(nn.Module):
    latent_dim: int

    @nn.compact
    def __call__(self, x):
        z_mean, z_log_var, z = Encoder(self.latent_dim)(x)
        reconstruction = Decoder(self.latent_dim)(z)
        return z_mean, z_log_var, reconstruction


In [48]:
# # Loss Function
# def vae_loss(recon_x, x, z_mean, z_log_var):
#     reconstruction_loss = jnp.sum(
#         jax.nn.binary_cross_entropy(recon_x, x, from_logits=False), axis=(1, 2) # AttributeError: module 'jax.nn' has no attribute 'binary_cross_entropy'
#     )
#     kl_loss = -0.5 * jnp.sum(1 + z_log_var - jnp.square(z_mean) - jnp.exp(z_log_var), axis=1)
#     return jnp.mean(reconstruction_loss + kl_loss)


# Loss Function
def vae_loss(recon_x, x, z_mean, z_log_var):
    recon_x = recon_x.reshape(x.shape)  # Reshape recon_x to match x
    # Custom binary cross-entropy
    reconstruction_loss = -jnp.sum(x * jnp.log(recon_x + 1e-10) + (1 - x) * jnp.log(1 - recon_x + 1e-10), axis=(1, 2))
    kl_loss = -0.5 * jnp.sum(1 + z_log_var - jnp.square(z_mean) - jnp.exp(z_log_var), axis=1)
    return jnp.mean(reconstruction_loss + kl_loss)



## Train the VAE

In [49]:
# Training Step
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        recon_x, z_mean, z_log_var = state.apply_fn({'params': params}, batch)
        return vae_loss(recon_x, batch, z_mean, z_log_var)

    grads = jax.grad(loss_fn)(state.params['params'])  # Accessing the correct level of parameters
    state = state.apply_gradients(grads=grads)
    return state

In [50]:
# Initialize and Train
def train_vae():
    dataset = load_mnist()
    model = VAE(latent_dim=2)
    rng = jax.random.PRNGKey(0)
    init_params = model.init(rng, jnp.ones((1, 28, 28, 1)))  # Input shape for MNIST
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_params,
        tx=optax.adam(learning_rate=1e-3),
    )

    # Training Loop
    for epoch in range(30):  # Train for 30 epochs
        for i in range(0, len(dataset), 128):  # Assuming batch size of 128
            batch = dataset[i:i + 128]
            state = train_step(state, batch)

if __name__ == "__main__":
    train_vae()

TypeError: cannot reshape array of shape (128, 2) (size 256) into shape (128, 28, 28, 1) (size 100352)

## Display a grid of sampled digits

In [8]:
# Plot Functions
def plot_latent_space(vae, n=30, figsize=15):
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size:(i + 1) * digit_size, j * digit_size:(j + 1) * digit_size] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

## Display how the latent space clusters different digit classes

In [51]:
def plot_label_clusters(vae, data, labels):
    z_mean, _, _ = vae.encoder(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()