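# Variational Autoencoders

This notebook trains a variational autoencoder (VAE) on a generated image dataset parameterized by size and hue. A user-selected size/hue region is held out as the validation set. The best model is then exported to ONNX so its encoder and decoder can be explored in interactive browser widgets.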

In [None]:
import subprocess

# Install the JavaScript dependencies and build the wrapped widgets used later in this notebook.
# Each step is (command, working directory); check=True aborts the cell on the first failure.
widget_build_steps = [
    (["npm", "i", "--no-progress"], "../widgets"),
    (["npm", "i", "--no-progress"], "widget-wrappers"),
    (["bash", "build_wrapped_widgets.sh"], "widget-wrappers"),
]
for step_command, step_dir in widget_build_steps:
    subprocess.run(step_command, cwd=step_dir, check=True)

In [None]:
import numpy as np
import torch
from tqdm.notebook import trange

from constants import hue_range, latent_dim, sidelength, size_range
from dataset import generate_dataset
from elbo import approximate_elbo
from model import VAE
from util import BatchIterator, onnx_export, onnx_export_to_files, plot_losses
from vaewidgets import (
    AreaSelectionWidget,
    dataset_explanation,
    dataset_visualization,
    decoding,
    mapping,
)

## Dataset explanation

In [None]:
# Display the dataset explanation provided by `vaewidgets` (its return value is the cell output).
dataset_explanation()

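## Train/validation set split

Select a rectangular region of the size/hue space below. Samples generated inside that region form the validation set, and the remaining samples form the training set.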

In [None]:
# Interactive widget for choosing the validation region in (size, hue) space.
# NOTE(review): the four trailing numbers appear to be the initial x, y, width and height of the
# selected area (later cells read .x/.y/.width/.height) — confirm against AreaSelectionWidget.
initial_valset_area = (0.6, 0.4, 0.3, 0.3)
valset_selection = AreaSelectionWidget(
    size_range, hue_range, "Size", "Hue", *initial_valset_area
)
valset_selection

In [None]:
# Translate the widget's current rectangle into (low, high) ranges on each axis; samples inside
# these ranges are used for the validation set.
valset_size_range = (valset_selection.x, valset_selection.x + valset_selection.width)
valset_hue_range = (valset_selection.y, valset_selection.y + valset_selection.height)

trainset_coords, valset_coords, trainset, valset = generate_dataset(
    size_range=size_range,
    hue_range=hue_range,
    valset_size_range=valset_size_range,
    valset_hue_range=valset_hue_range,
    num_samples=2000,
)

In [None]:
# Visualize the generated training and validation sets.
# NOTE(review): the meaning of the trailing positional `True` is not visible from here —
# consider passing it by keyword so readers can tell what it toggles.
dataset_visualization(trainset_coords, valset_coords, trainset, valset, True)

## Training

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Number of values in one flattened RGB image (sidelength x sidelength pixels x 3 channels).
num_pixel_values = sidelength * sidelength * 3


def batch_loss(model: VAE, batch: torch.Tensor) -> torch.Tensor:
    """Return the mean negative ELBO of `model` on a single batch.

    The batch is divided by 255 (presumably uint8 pixel values — confirm against BatchIterator)
    and moved to `device`. The input and the model's reconstruction output are flattened to
    `num_pixel_values` features per sample before being passed to `approximate_elbo`.
    `reshape` is used instead of `view` so a non-contiguous decoder output cannot raise.
    """
    x = (batch / 255.0).to(device)
    mu_x, logvar_x, _, mu_z = model(x)
    return -approximate_elbo(
        x.reshape(x.shape[0], num_pixel_values),
        mu_z.reshape(mu_z.shape[0], num_pixel_values),
        mu_x,
        logvar_x,
        sigma2=1.0,
    ).mean()


# Use the shared `latent_dim` constant so this model matches the one rebuilt when reloading
# the checkpoint (a hard-coded 2 would break `load_state_dict` if the constant differs).
vae = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

train_losses: list[float] = []
val_losses: list[float] = []
best_val_loss: float = np.inf

batch_size = 256
num_epochs = 100
# Only epochs in the final quarter of training are eligible for checkpointing.
checkpoint_fraction = 0.75

pbar = trange(num_epochs)
for epoch in pbar:
    vae.train()
    per_batch_train_losses = []
    for batch in BatchIterator(trainset, batch_size):
        loss = batch_loss(vae, batch)
        per_batch_train_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()  # type: ignore[no-untyped-call]
        optimizer.step()
    train_losses.append(float(np.mean(per_batch_train_losses)))

    vae.eval()
    with torch.no_grad():
        per_batch_val_losses = [
            batch_loss(vae, batch).item() for batch in BatchIterator(valset, batch_size)
        ]
    epoch_val_loss = float(np.mean(per_batch_val_losses))
    val_losses.append(epoch_val_loss)
    pbar.set_description(
        f"Train Loss: {train_losses[-1]:.4f}, Val Loss: {epoch_val_loss:.4f}"
    )

    # The last epoch is always eligible so that `vae.pth` is written even when num_epochs is
    # small enough that no epoch exceeds the fractional threshold; otherwise the reload cell
    # would fail or silently pick up a stale checkpoint from an earlier run.
    checkpoint_eligible = epoch > num_epochs * checkpoint_fraction or epoch == num_epochs - 1
    if checkpoint_eligible and epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(vae.state_dict(), "vae.pth")

plot_losses(train_losses, val_losses)

In [None]:
# Load the best model again and export to ONNX so we can use it in the browser.
# The checkpoint may have been saved from CUDA/MPS tensors; map it to the CPU so loading does not
# touch (or require) the training device, since the model rebuilt here lives on the CPU.
vae = VAE(latent_dim=latent_dim)
vae.load_state_dict(torch.load("vae.pth", map_location="cpu"))
vae.eval()

encoder, decoder = onnx_export(vae.encoder, vae.decoder)

# TODO Remove
# NOTE(review): this additionally writes encoder.onnx / decoder.onnx to the working directory.
onnx_export_to_files(vae.encoder, vae.decoder, "encoder.onnx", "decoder.onnx")

In [None]:
# Show the mapping widget for the exported encoder/decoder, passing the validation region as
# ((size_low, size_high), (hue_low, hue_high)).
# NOTE(review): this re-reads the widget's *current* selection; if it was moved after the dataset
# was generated, the region shown here will not match the actual validation set.
selected_size_range = (valset_selection.x, valset_selection.x + valset_selection.width)
selected_hue_range = (valset_selection.y, valset_selection.y + valset_selection.height)
mapping(encoder, decoder, (selected_size_range, selected_hue_range))

In [None]:
# Show the decoding widget, passing it the ONNX encoder and decoder exported above.
decoding(encoder, decoder)