# Variational Autoencoders

In [None]:
import subprocess

# Install npm dependencies in ../widgets and in widget-wrappers, then run the
# build_wrapped_widgets.sh script in widget-wrappers.
# check=True makes each call raise CalledProcessError on a non-zero exit status,
# so a failed step stops the cell instead of continuing silently.
subprocess.run(["npm", "i", "--no-progress"], cwd="../widgets", check=True)
subprocess.run(["npm", "i", "--no-progress"], cwd="widget-wrappers", check=True)
subprocess.run(["bash", "build_wrapped_widgets.sh"], cwd="widget-wrappers", check=True)

In [None]:
import math

import torch
import numpy as np
import torch.nn as nn
from tqdm.notebook import trange
import matplotlib.pyplot as plt

from vaewidgets import *
from dataset import generate_dataset
from util import map_tuple, BatchIterator, plot_losses, onnx_export
from constants import size_range, hue_range, sidelength, latent_dim

## Loss function

In [None]:
def kl_divergence(mu, logvar):
    """
    Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), one value per sample.

    Computes 0.5 * sum_j (sigma_j^2 + mu_j^2 - 1 - log sigma_j^2), summing over
    the latent dimension (dim 1).

    Args:
        mu: [batch_size, latent_dim] posterior means.
        logvar: [batch_size, latent_dim] posterior log-variances.

    Returns:
        [batch_size] tensor of KL divergences.
    """
    variance = logvar.exp()
    per_dim = variance + mu.pow(2) - 1 - logvar
    return per_dim.sum(dim=1) * 0.5


def log_normal_diag_spherical(x, mu, sigma2):
    """
    Per-sample log-density of an isotropic Gaussian, log N(x; mu, sigma2 * I).

    All axes after the first (batch) axis are treated as one flat event. The
    function therefore accepts images [B, C, H, W] as well as already-flattened
    data [B, D], which is the form approximate_elbo passes in.

    Args:
        x: [B, ...] observed values.
        mu: tensor with the same shape as x (the Gaussian means).
        sigma2: positive scalar, the shared fixed variance of every dimension.

    Returns:
        log_probs: [B] tensor of per-sample log-likelihoods.
    """
    B = x.size(0)
    # Event size from the shape rather than x[0], so an empty batch yields an
    # empty result instead of raising IndexError.
    d = math.prod(x.shape[1:])
    # reshape (not view) also works when the difference is non-contiguous,
    # e.g. channels_last inputs. An explicit d (not -1) keeps B == 0 valid.
    squared_error = (x - mu).pow(2).reshape(B, d).sum(dim=1)

    const_term = -0.5 * d * math.log(2 * math.pi)
    log_sigma_term = -0.5 * d * math.log(sigma2)
    quad_term = -0.5 / sigma2 * squared_error
    return const_term + log_sigma_term + quad_term


def approximate_elbo(xi, mu_z, mu_xi, logvar_xi, sigma2, beta=1.0):
    """
    Per-sample (optionally beta-weighted) ELBO estimate.

        ELBO(x) ~= log p(x | z) - beta * KL(q(z | x) || p(z))

    The reconstruction term is evaluated at the decoder mean mu_z. In the
    training loop in this notebook, mu_z comes from a single reparameterized
    sample z. beta = 1.0 gives the standard ELBO. Other values give a
    beta-VAE style objective.

    Args:
        xi: [batch_size, input_dim], the true input data (flattened).
        mu_z: [batch_size, input_dim], the decoder output (mean of p(x|z)).
        mu_xi, logvar_xi: [batch_size, latent_dim], the encoder outputs
            (mean and log-variance of q(z|x)).
        sigma2: positive scalar, the fixed reconstruction variance of p(x|z).
        beta: weight on the KL term (default 1.0, the standard ELBO).

    Returns:
        elbo: [batch_size], the per-sample ELBO.

    Raises:
        AssertionError: if the shapes are not as described or sigma2 <= 0.
            These are asserts, so they are skipped under `python -O`.
    """
    assert len(xi.shape) == 2, "xi must be a 2D tensor"
    assert len(mu_z.shape) == 2, "mu_z must be a 2D tensor"
    assert len(mu_xi.shape) == 2, "mu_xi must be a 2D tensor"
    assert len(logvar_xi.shape) == 2, "logvar_xi must be a 2D tensor"

    assert xi.shape == mu_z.shape, "xi and mu_z must have the same shape"
    assert mu_xi.shape == logvar_xi.shape, "mu_xi and logvar_xi must have the same shape"
    assert sigma2 > 0, "sigma2 must be positive"

    recon_term = log_normal_diag_spherical(xi, mu_z, sigma2)  # log p(x | z)
    kl_term = kl_divergence(mu_xi, logvar_xi)  # KL(q(z | x) || p(z))
    return recon_term - beta * kl_term

## Dataset explanation

In [None]:
# Dataset explanation visual from vaewidgets. It is left as the cell's bare
# final expression so Jupyter displays its return value (presumably a widget;
# confirm in vaewidgets).
dataset_explanation()

## Train/validation set split

In [None]:
# Interactive area selector over the size (x) / hue (y) ranges that defines the
# held-out validation region. Later cells read its current
# .x/.y/.width/.height, so adjust the selection before running them.
# NOTE(review): the trailing 0.6, 0.4, 0.3, 0.3 look like the initial
# x, y, width, height. TODO confirm order and units against AreaSelectionWidget.
valset_selection = AreaSelectionWidget(size_range, hue_range, "Size", "Hue", 0.6, 0.4, 0.3, 0.3)
# Bare final expression so Jupyter displays the widget.
valset_selection

In [None]:
# Validation region currently selected in the widget, as (low, high) intervals
# along the size axis and the hue axis.
_val_size_range = (valset_selection.x, valset_selection.x + valset_selection.width)
_val_hue_range = (valset_selection.y, valset_selection.y + valset_selection.height)

trainset_coords, valset_coords, trainset, valset = generate_dataset(
    size_range=size_range,
    hue_range=hue_range,
    valset_size_range=_val_size_range,
    valset_hue_range=_val_hue_range,
    num_samples=2000,
)

In [None]:
# Visualize the generated train/validation split. The *_coords values come
# from generate_dataset, presumably (size, hue) per sample (confirm).
# Left as the bare final expression so Jupyter displays the result.
dataset_visualization(trainset_coords, valset_coords, trainset, valset)

## Model

In [None]:
def reparameterize(mu, logvar):
    """
    Sample z ~ N(mu, exp(logvar)) with the reparameterization trick.

    z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2). The
    randomness sits in eps, so gradients reach mu and logvar.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


class Encoder(nn.Module):
    """
    Convolutional encoder mapping an RGB image to the parameters of q(z|x).

    Three 3x3, stride-2 convolutions (3 -> 32 -> 64 -> 128 channels, ReLU after
    each) are followed by a flatten. The linear heads expect a 128 x 4 x 4
    feature map, which a 32 x 32 input (the Decoder's output size) produces.

    Args:
        latent_dim: size of the latent space.

    forward(x) returns (mu, logvar), each of shape [batch, latent_dim].
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        channels = (3, 32, 64, 128)
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [nn.Conv2d(c_in, c_out, 3, stride=2, padding=1), nn.ReLU()]
        layers.append(nn.Flatten())
        self.conv = nn.Sequential(*layers)

        flat_features = channels[-1] * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        features = self.conv(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar


class Decoder(nn.Module):
    """
    Transposed-convolution decoder mapping a latent code to an RGB image mean.

    A linear layer lifts z to a 128 x 4 x 4 feature map. Three 4x4, stride-2
    transposed convolutions (128 -> 64 -> 32 -> 3 channels, each preceded by a
    ReLU) double the spatial size each time to 32 x 32. A final sigmoid maps
    every output into [0, 1].

    Args:
        latent_dim: size of the latent space.

    forward(z) takes z of shape [batch, latent_dim] and returns
    [batch, 3, 32, 32].
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)
        channels = (128, 64, 32, 3)
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [nn.ReLU(), nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1)]
        layers.append(nn.Sigmoid())
        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        grid = self.fc(z).view(-1, 128, 4, 4)
        image = self.deconv(grid)
        return image


class VAE(nn.Module):
    """
    Variational autoencoder combining an Encoder and a Decoder.

    forward(x) encodes x into (mu, logvar), draws one reparameterized latent
    sample z, and decodes it. It returns (mu_x, logvar_x, z, mu_z), where mu_z
    is the decoder output for z.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        posterior_mean, posterior_logvar = self.encoder(x)
        latent = reparameterize(posterior_mean, posterior_logvar)
        reconstruction = self.decoder(latent)
        return posterior_mean, posterior_logvar, latent, reconstruction

## Training

In [None]:
# Pick the first available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Use the shared latent_dim constant (not a literal) so the checkpoint saved
# here matches the VAE(latent_dim=latent_dim) that reloads it for export.
vae = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

train_losses = []
val_losses = []
val_mses = []
best_val_loss = float("inf")

batch_size = 256
num_epochs = 100
recon_sigma2 = 1.0  # fixed variance of the Gaussian likelihood p(x|z)
num_pixels = sidelength * sidelength * 3  # flattened size of one RGB image


def _negative_elbo(model, batch):
    """Return the batch-mean negative ELBO of `model` on one batch of images.

    Pixels are divided by 255 (assumes raw 0-255 values; TODO confirm against
    BatchIterator). Inputs and reconstructions are flattened to
    [batch, num_pixels], as approximate_elbo requires.
    """
    x = (batch / 255.0).to(device)
    mu_x, logvar_x, _, mu_z = model(x)
    return -approximate_elbo(
        x.view(x.shape[0], num_pixels),
        mu_z.view(mu_z.shape[0], num_pixels),
        mu_x,
        logvar_x,
        sigma2=recon_sigma2,
    ).mean()


pbar = trange(num_epochs)
for epoch in pbar:
    vae.train()
    per_batch_train_losses = []
    for batch in BatchIterator(trainset, batch_size):
        loss = _negative_elbo(vae, batch)
        per_batch_train_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()  # loss is already a scalar batch mean
        optimizer.step()
    train_losses.append(np.mean(per_batch_train_losses))

    # NOTE: VAE.forward always samples z via reparameterize, so the validation
    # loss is a stochastic estimate even in eval mode.
    vae.eval()
    with torch.no_grad():
        per_batch_val_losses = [
            _negative_elbo(vae, batch).item()
            for batch in BatchIterator(valset, batch_size)
        ]
    epoch_val_loss = np.mean(per_batch_val_losses)
    val_losses.append(epoch_val_loss)
    pbar.set_description(
        f"Train Loss: {train_losses[-1]:.4f}, Val Loss: {epoch_val_loss:.4f}"
    )

    # Keep the best-validation model from the last quarter of training. The
    # final epoch is always eligible, so vae.pth is written even when
    # num_epochs is too small for any epoch to pass the 75% threshold.
    checkpoint_eligible = epoch > num_epochs * 0.75 or epoch == num_epochs - 1
    if checkpoint_eligible and epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(vae.state_dict(), "vae.pth")

plot_losses(train_losses, val_losses)

In [None]:
# Load the best model again and export to ONNX so we can use it in the browser.
vae = VAE(latent_dim=latent_dim)
# The checkpoint was saved from the training device (possibly CUDA/MPS). Load
# it straight onto the CPU, which also works in a session without that device.
vae.load_state_dict(torch.load("vae.pth", map_location="cpu"))
vae.eval()
encoder, decoder = onnx_export(vae)

In [None]:
# Interactive mapping view driven by the exported ONNX encoder/decoder
# (behaviour assumed from the name; confirm in vaewidgets). The last argument
# is the validation region as [[size_lo, size_hi], [hue_lo, hue_hi]], read from
# the same selection widget used for the train/validation split.
# Left as the bare final expression so Jupyter displays the result.
mapping(
    encoder,
    decoder,
    [
        [valset_selection.x, valset_selection.x + valset_selection.width],
        [valset_selection.y, valset_selection.y + valset_selection.height],
    ],
)

In [None]:
# Decoding widget from vaewidgets, given the exported ONNX encoder/decoder.
# Left as the bare final expression so Jupyter displays the result.
decoding(encoder, decoder)