The aim of this notebook is to train a simple MLP VAE to generate MNIST-like images.
It's inspired by https://www.youtube.com/watch?v=VELQT1-hILo&t=1707s.

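For the losses, following the video above, I use `'sum'` reduction instead of the more typical `'mean'` because it gives better
results. That's probably because it strikes a better balance between the reconstruction
and KL-divergence terms. The KL term is summed over latent dimensions and the batch, so summing the
reconstruction loss over pixels and the batch keeps both terms on the same scale. With `'mean'`, the KL term would dominate.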

### Imports & constants

In [1]:
import sys
sys.path.append('../src')

In [2]:
import torch
import torchvision
import numpy as np

from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

from models import VariationalAutoencoder

In [3]:
IMG_SIZE = 28      # side length of a (square) MNIST image, in pixels
HIDDEN_DIM = 256
LATENT_DIM = 64
BATCH_SIZE = 128

NUM_EPOCHS = 5
LR = 1e-4

# Seed every RNG the notebook draws from so that Restart & Run All reproduces
# the same run: torch drives DataLoader shuffling and `torch.randn_like` in the
# sampling cell, numpy drives the random digit pick. (Some CUDA kernels may
# still be nondeterministic.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

### Preparing stuff for training

In [4]:
# MNIST training split, downloaded into ../local on first use; ToTensor()
# converts each image to a float tensor scaled to [0, 1].
dataset = torchvision.datasets.MNIST(
    root='../local',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of BATCH_SIZE images loaded by 2 worker processes;
# drop_last=True discards the final incomplete batch of each epoch.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

In [5]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MLP VAE over flattened IMG_SIZE x IMG_SIZE images.
model = VariationalAutoencoder(
    input_dim=IMG_SIZE ** 2,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Reconstruction loss applied to raw logits, summed (not averaged) over
# pixels and batch -- see the intro for why 'sum' is used.
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='sum')

### Training

In [6]:
for epoch in range(NUM_EPOCHS):
    # Put the model back into training mode on every epoch, so this cell is
    # safe to re-run after a cell that switched it to eval().
    model.train()

    # Labels and sampled latents of the current epoch, kept for the plotting TODO below.
    classes = []
    latents = []

    tqdm_it = tqdm(dataloader, desc=f'Epoch: [{epoch+1}/{NUM_EPOCHS}]', leave=True)

    for x, y in tqdm_it:
        # Flatten each image into a vector. Use the actual batch size rather
        # than BATCH_SIZE so this still works if drop_last is ever disabled.
        x = x.view(x.size(0), -1).to(device)
        x_reconstr, latent_sampled, mu, logvar = model(x)

        # Reconstruction term: BCE on logits, summed over pixels and batch.
        # NOTE(review): assumes model(x) returns logits (the sampling cell has
        # to ask decode() for apply_sigmoid=True explicitly) -- confirm.
        loss_reconstr = loss_bce(x_reconstr, x)
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent
        # dimensions and batch to match the 'sum' reduction above.
        loss_kl = -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2))
        loss = loss_reconstr + loss_kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        classes.append(y.numpy())
        latents.append(latent_sampled.detach().cpu().numpy())
        tqdm_it.set_postfix(loss=loss.item())  # summed over the batch, not per image

    # TODO: some kind of fancy plotting every `n` epochs

Epoch: [1/5]: 100%|██████████| 468/468 [00:29<00:00, 15.77it/s, loss=2.69e+4]
Epoch: [2/5]: 100%|██████████| 468/468 [00:17<00:00, 27.22it/s, loss=2.35e+4]
Epoch: [3/5]: 100%|██████████| 468/468 [00:13<00:00, 34.99it/s, loss=2.17e+4]
Epoch: [4/5]: 100%|██████████| 468/468 [00:13<00:00, 34.16it/s, loss=2.1e+4] 
Epoch: [5/5]: 100%|██████████| 468/468 [00:28<00:00, 16.30it/s, loss=1.96e+4]


### Generating new samples

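To generate examples of a chosen digit, we first encode a randomly selected training image of that digit
to get `mu` and `logvar`. We then draw new latent vectors $z = \mu + \sigma \odot \varepsilon$, with
$\sigma = e^{\mathrm{logvar}/2}$ and $\varepsilon \sim \mathcal{N}(0, I)$, and decode them into images.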

In [10]:
DIGIT = 0        # MNIST class whose variations we want to generate
NUM_SAMPLES = 3  # Number of generated samples per digit

In [11]:
# Pick a random training image of DIGIT and save it for comparison.
# `dataset.targets` / `dataset[idx]` replace the deprecated `train_labels` /
# `train_data` aliases. Indexing the dataset applies the same ToTensor()
# transform used in training, giving a (1, H, W) float image in [0, 1].
digit_indices = torch.nonzero(dataset.targets == DIGIT).flatten()
if digit_indices.numel() == 0:
    raise ValueError(f'No training images found for digit {DIGIT!r}')
rand_idx = int(np.random.choice(digit_indices.numpy()))
img, _ = dataset[rand_idx]
save_image(img, f'{DIGIT}_original.png')

In [12]:
# Encode the chosen image once, then decode NUM_SAMPLES fresh draws from
# N(mu, sigma^2). This is pure inference, so no autograd graph is built.
model.eval()
with torch.no_grad():
    mu, logvar = model.encode(img.view(1, -1).to(device))
    # sigma = sqrt(exp(logvar)) == exp(logvar / 2). This form avoids the larger
    # intermediate exp(logvar), which overflows sooner, and is loop-invariant.
    std = torch.exp(0.5 * logvar)

    for i in range(NUM_SAMPLES):
        eps = torch.randn_like(std)
        latent_sampled = mu + std * eps
        img_sampled = model.decode(latent_sampled, apply_sigmoid=True)
        save_image(img_sampled.view(1, IMG_SIZE, IMG_SIZE), f'{DIGIT}_sampled_v{i+1}.png')