# Variational Autoencoders

In [None]:
# Clone the repo if needed (e.g. on Colab)
import os
import subprocess


def is_correct_repo() -> bool:
    """Report whether the current directory is inside a clone of the VAE repo.

    Runs ``git remote get-url origin`` in the current working directory and
    compares the reported URL with the HTTPS and SSH URLs of
    ``mariogemoll/vae``. Surrounding whitespace, a trailing slash and an
    optional ``.git`` suffix are ignored, so a clone created from
    ``https://github.com/mariogemoll/vae`` is recognised just like one created
    from ``https://github.com/mariogemoll/vae.git``.

    Returns:
        True if ``origin`` points at the expected repository. False if it
        points somewhere else, if git exits with a non-zero status (for
        example outside a repository or without an ``origin`` remote), or if
        the ``git`` executable cannot be found.
    """
    try:
        result = subprocess.run(
            ["git", "remote", "get-url", "origin"], capture_output=True, text=True, check=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False

    # Normalise before comparing: the ".git" suffix and a trailing slash are
    # optional in clone URLs and must not cause a second, nested clone.
    remote_url = result.stdout.strip().rstrip("/")
    if remote_url.endswith(".git"):
        remote_url = remote_url[: -len(".git")]
    return remote_url in (
        "https://github.com/mariogemoll/vae",
        "git@github.com:mariogemoll/vae",
    )


# Clone only when the current directory is not already a checkout of the repo
# (is_correct_repo() inspects the "origin" remote of the working directory).
if not is_correct_repo():
    !git clone https://github.com/mariogemoll/vae.git

# Later cells use paths relative to the notebooks directory (e.g. "../widgets",
# "requirements-build.txt"), so switch there unless the working directory
# already ends with "vae/notebooks".
# NOTE(review): the target is relative to the current directory. If the kernel
# starts in the repository root (or in a checkout under a different directory
# name), "vae/notebooks" may not exist from there - confirm this case.
if not os.getcwd().endswith("vae/notebooks"):
    %cd vae/notebooks

In [None]:
# Install Python dependencies
# %pip (rather than !pip) installs into the environment of the running kernel.
# The requirements file is resolved relative to the current working directory,
# and -q keeps pip's output short.
%pip install -r requirements-build.txt -q

In [None]:
# Build UI widgets
# Install the npm dependencies of both JavaScript packages first. check=True
# makes any failing step raise CalledProcessError and stops the cell.
npm_install = ["npm", "i", "--no-progress"]
for package_dir in ("../widgets", "widget-wrappers"):
    subprocess.run(npm_install, cwd=package_dir, check=True)

# Then run the build script from inside widget-wrappers. This stays the last
# expression of the cell, so its CompletedProcess is what the cell displays.
subprocess.run(["bash", "build_wrapped_widgets.sh"], cwd="widget-wrappers", check=True)

In [None]:
import torch

from constants import hue_range, latent_dim, sidelength, size_range
from dataset import generate_dataset
from grid import make_standard_grid
from image import get_images
from model import VAE
from training import train
from util import get_device, onnx_export, plot_losses
from vaewidgets import (
    AreaSelectionWidget,
    dataset_explanation,
    dataset_visualization,
    decoding,
    evolution,
    mapping,
    sampling,
)

# Ask the project's util.get_device() helper which device to use and warn when
# it is the CPU.
device = get_device()
if device.type == "cpu":
    print("Using CPU for training. You might want to switch to a GPU to speed things up!")

## Dataset explanation

In [None]:
# Call the dataset_explanation helper imported from vaewidgets. The call is the
# last expression of the cell, so whatever it returns is rendered as the cell
# output (presumably an interactive widget - confirm).
dataset_explanation()

## Train/validation set split

In [None]:
# Area selection over the size and hue ranges, labelled "Size" and "Hue".
# NOTE(review): the four trailing numbers are presumably the initial x, y,
# width and height of the selected area - confirm against AreaSelectionWidget.
# Later cells read .x, .y, .width and .height from this object, so the dataset
# depends on the state of the widget at the time those cells run.
valset_selection = AreaSelectionWidget(size_range, hue_range, "Size", "Hue", 0.6, 0.4, 0.3, 0.3)
# Bare last expression: displays the widget as the cell output.
valset_selection

In [None]:
# Turn the rectangle currently held by the selection widget into (start, end)
# pairs: x/width describe the size axis, y/height the hue axis.
valset_size_range = (valset_selection.x, valset_selection.x + valset_selection.width)
valset_hue_range = (valset_selection.y, valset_selection.y + valset_selection.height)

# Generate the data and split it into training and validation sets using the
# selected rectangle as the validation region.
trainset_coords, valset_coords, trainset, valset = generate_dataset(
    size_range=size_range,
    hue_range=hue_range,
    valset_size_range=valset_size_range,
    valset_hue_range=valset_hue_range,
    num_samples=2000,
)

In [None]:
# Visualise the coordinates and samples of the train/validation split returned
# by generate_dataset.
# NOTE(review): the meaning of the trailing positional True is defined by
# dataset_visualization - confirm, and consider passing it by keyword.
dataset_visualization(trainset_coords, valset_coords, trainset, valset, True)

## Training

In [None]:
device = get_device()

# Render one image for every point of the standard grid over the size and hue
# ranges. get_images receives the points as a flat list of 2-tuples.
standard_grid = make_standard_grid(size_range, hue_range)
grid_points = list(map(tuple, standard_grid.reshape(-1, 2).tolist()))
imgs = get_images(sidelength, grid_points)

# Convert to float, divide by 255 and move the batch to the training device.
grid_pixels = torch.from_numpy(imgs).float()
grid_x = (grid_pixels / 255.0).to(device)

# NOTE(review): "vae.pth" is the file a later cell loads the weights from. The
# meaning of the positional 100, 256 and 64 is defined by training.train -
# confirm, and consider naming them.
train_losses, val_losses, grids = train(
    device,
    trainset,
    valset,
    "vae.pth",
    100,
    256,
    64,
    grid_x,
)
plot_losses(train_losses, val_losses)

In [None]:
# Load the best model again and export to ONNX so we can use it in the browser
vae = VAE(latent_dim=latent_dim)
# map_location="cpu" lets the checkpoint load even when it holds tensors saved
# from an accelerator that is not available in the current runtime (e.g. after
# switching from a GPU to a CPU session). The freshly constructed model is not
# moved to any device here, so the loaded weights end up where they did before.
vae.load_state_dict(torch.load("vae.pth", map_location="cpu"))
# Switch to evaluation mode before exporting.
vae.eval()
# The exported encoder and decoder are what the widget cells below consume.
encoder, decoder = onnx_export(vae.encoder, vae.decoder)

In [None]:
# Call the sampling helper from vaewidgets with the exported decoder only. As
# the cell's last expression, its return value is rendered as the cell output.
sampling(decoder)

In [None]:
# train() returned `grids` alongside the losses; fail early with a clear
# AssertionError if it is None instead of an AttributeError below.
assert grids is not None
# Tensor.numpy() raises for tensors that live on an accelerator or are attached
# to the autograd graph. detach().cpu() covers both cases and leaves a plain
# CPU tensor's values unchanged.
evolution(train_losses, val_losses, grids.detach().cpu().numpy())

In [None]:
# Validation rectangle read from the selection widget, as a pair of
# (start, end) ranges: first the x/width axis, then the y/height axis.
valset_region = (
    (valset_selection.x, valset_selection.x + valset_selection.width),
    (valset_selection.y, valset_selection.y + valset_selection.height),
)
# Last expression of the cell, so the helper's return value is rendered.
mapping(encoder, decoder, valset_region)

In [None]:
# Call the decoding helper from vaewidgets with the exported encoder and
# decoder. As the cell's last expression, its return value is rendered as the
# cell output.
decoding(encoder, decoder)