## Variational Autoencoder

This notebook presents a model that can be used for data augmentation. Variational autoencoder can learn multivariate latent distribution of the input data to further reconstruct it as accurately as possible. In order to build a better organized latent space there was implemented [Cyclical Annealing Schedule](https://aclanthology.org/N19-1021.pdf).

### 0. Prerequisites

In [1]:
import os 
import time
import torch
import wandb
import optuna
import joblib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from scipy import stats
from tqdm import tqdm
from dotenv import load_dotenv
from torch import nn
from torch.optim import lr_scheduler
from collections import OrderedDict

from src.metrics import pearson_metric
from src.data import Dataset, load_data, build_torch_dataloaders

In [2]:
load_dotenv()

True

In [3]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mubiquant-experiments[0m (use `wandb login --relogin` to force relogin)


True

In [4]:
DEVICE = "cuda:0"
PROJECT = "VAE"
ENTITY = "parmezano"
EXPERIMENT = "baseline + beta scheduling"

In [5]:
config = {
    "Epochs": 50,
    "Latent vector dim.": 10,
    "Layers": [302, 256, 128, 64, 32],
    "Optimizer": "AdamW",
    "Learning rate": 1e-3,
    "Weight decay": 5e-4,
    "Batch size": 25000,
    "Scheduler": "CosineAnnealingLR",
    "Split data": True,    
}

In [6]:
wandb.init(project=PROJECT, entity=ENTITY, name=EXPERIMENT, config=config)

[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [7]:
weights_dir = os.path.join("weights", PROJECT)
if not os.path.exists(weights_dir):
    os.makedirs(weights_dir)

model_dir = os.path.join("models", PROJECT)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

### 1. Data preparation

In [8]:
df = load_data(use_feather=True)

Loading took 10.42 seconds


In [9]:
trainloader, testloader = build_torch_dataloaders(df, split_data=config['Split data'])

### 2. Building a Model.

In [10]:
class BlockAE(nn.Module):
    def __init__(self, input_dim, output_dim, activation):
        super(BlockAE, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            activation()
        )
        
    def forward(self, x):
        out = self.main(x)
        return out

In [11]:
class VAE(nn.Module):
    
    def __init__(self, layers, dim, activation=nn.ReLU):
        super(VAE, self).__init__()
        self.layers = layers
        self.dim = dim
        self.activation = activation
        
        self.encoder = self.__build_encoder()
        self.decoder = self.__build_decoder()

    def __build_encoder(self):
        layers = self.layers
        activation = self.activation
        encoder_layers = []
        for idx in range(len(layers) - 1):
            layer = BlockAE(layers[idx], layers[idx + 1], activation)
            encoder_layers.append(layer)
        encoder_layers.append(nn.Linear(layers[-1], self.dim * 2))
        return torch.nn.Sequential(*encoder_layers)
    
    def __build_decoder(self):
        layers = self.layers[::-1]
        activation = self.activation
        decoder_layers = [nn.Linear(self.dim, layers[0])]
        for idx in range(len(layers) - 2):
            layer = BlockAE(layers[idx], layers[idx + 1], activation)
            decoder_layers.append(layer)
        decoder_layers.append(nn.Linear(layers[-2], layers[-1]))
        return torch.nn.Sequential(*decoder_layers)
            
    
    def reparametrize(self, mu, sigma):
        if self.training:
            std = sigma.mul(0.5).exp_()
            eps = std.new_empty(std.size()).normal_()
            return eps.mul_(std).add_(mu)
        return mu
    
    def forward(self, x):
        bottleneck = self.encoder(x).view(-1, 2, self.dim)
        mu, sigma = bottleneck[:, 0, :], bottleneck[:, 1, :]
        z = self.reparametrize(mu, sigma)
        out = self.decoder(z)
        return out, mu, sigma


In [12]:
def compute_beta(epoch):
    return min((epoch % 50) / 40, 1.)

In [13]:
def vae_loss(x_hat, x, mu, sigma, beta=1):
    bce = nn.L1Loss()(x_hat, x)
    kld = 0.5 * torch.sum(sigma.exp() - sigma - 1 + mu.pow(2))
    return bce + beta * kld

### 3. Training the Model.

In [14]:
def make():
    model = VAE(layers=config["Layers"], dim=config['Latent vector dim.'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['Learning rate'], weight_decay=config['Weight decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['Epochs'])
    criterion = vae_loss
    return model, optimizer, scheduler, criterion

In [15]:
def train(model, criterion, loader, optimizer, device='cpu'):
    model.to(device)
    model.train()
    train_loss = 0.0
    for i, (_x, _y) in enumerate(loader):
        optimizer.zero_grad()
        x = torch.cat((_x, torch.unsqueeze(_y, 1)), dim=1)
        x = x.to(device)
        x_hat, mu, sigma = model(x)
        beta = compute_beta(i)
        loss = criterion(x_hat, x, mu, sigma, beta=beta)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    return train_loss / len(loader.dataset)

In [16]:
def test(model, criterion, loader, device='cpu'):
    model.to(device)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for i, (_x, _y) in enumerate(loader):
            x = torch.cat((_x, torch.unsqueeze(_y, 1)), dim=1)
            x = x.to(device)
            x_hat, mu, sigma = model(x)
            loss = criterion(x_hat, x, mu, sigma)
            test_loss += loss.item()
    return test_loss / len(loader.dataset)

In [17]:
def export_onnx(model):
    dummy_input = torch.randn(302, device="cpu")
    torch.onnx.export(model.cpu(), dummy_input, f"{model_dir}.onnx")
    wandb.save(f"{model_dir}.onnx")

In [18]:
def save_weights(test_loss, losses, epoch, model, optimizer):
    losses.append(test_loss)
    if test_loss <= min(losses):
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss, 
        }, os.path.join(weights_dir, f"{epoch}.pt"))
    return losses

In [19]:
def pipeline(trainloader, testloader, log=False):
    model, optimizer, scheduler, criterion = make()
    wandb.watch(model, criterion, log="all", log_freq=10)
    losses = []
    
    for epoch in range(config['Epochs']):
        train_loss = train(model, criterion, trainloader, optimizer, device=DEVICE)
        test_loss = test(model, criterion, testloader, device=DEVICE)
        scheduler.step()

        losses = save_weights(test_loss, losses, epoch, model, optimizer)

        wandb.log({"Train loss": train_loss, "Test loss": test_loss})
        if log and epoch % 5 == 0:
            print(f"Epoch: {epoch+1:02d} | Train: {train_loss:.8f} | Test: {test_loss:.8f}")
    
    export_onnx(model)

In [20]:
pipeline(trainloader, testloader, log=True)

Epoch: 01 | Train: 0.02634314 | Test: 0.00764904
Epoch: 06 | Train: 0.00025144 | Test: 0.00039427
Epoch: 11 | Train: 0.00019361 | Test: 0.00026698
Epoch: 16 | Train: 0.00017735 | Test: 0.00018584
Epoch: 21 | Train: 0.00017690 | Test: 0.00016340
Epoch: 26 | Train: 0.00015494 | Test: 0.00016623
Epoch: 31 | Train: 0.00015473 | Test: 0.00017369
Epoch: 36 | Train: 0.00015392 | Test: 0.00015936
Epoch: 41 | Train: 0.00015407 | Test: 0.00015964
Epoch: 46 | Train: 0.00015405 | Test: 0.00015971
Epoch: 51 | Train: 0.00015395 | Test: 0.00015963
Epoch: 56 | Train: 0.00015390 | Test: 0.00015946
Epoch: 61 | Train: 0.00015375 | Test: 0.00015929
Epoch: 66 | Train: 0.00015375 | Test: 0.00015939
Epoch: 71 | Train: 0.00015370 | Test: 0.00015938
Epoch: 76 | Train: 0.00015366 | Test: 0.00015913
Epoch: 81 | Train: 0.00015358 | Test: 0.00015909
Epoch: 86 | Train: 0.00015354 | Test: 0.00015923
Epoch: 91 | Train: 0.00015352 | Test: 0.00015935
Epoch: 96 | Train: 0.00015346 | Test: 0.00015898
Epoch: 101 | Train: 

### 4. Inference

In [21]:
checkpoint = torch.load(os.path.join(weights_dir, "49.pt"))
inferenced, _, _, _ = make()
inferenced.load_state_dict(checkpoint['model_state_dict'])
inferenced.eval()

VAE(
  (encoder): Sequential(
    (0): BlockAE(
      (main): Sequential(
        (0): Linear(in_features=302, out_features=256, bias=True)
        (1): ReLU()
      )
    )
    (1): BlockAE(
      (main): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): ReLU()
      )
    )
    (2): BlockAE(
      (main): Sequential(
        (0): Linear(in_features=128, out_features=64, bias=True)
        (1): ReLU()
      )
    )
    (3): BlockAE(
      (main): Sequential(
        (0): Linear(in_features=64, out_features=32, bias=True)
        (1): ReLU()
      )
    )
    (4): Linear(in_features=32, out_features=20, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): BlockAE(
      (main): Sequential(
        (0): Linear(in_features=32, out_features=64, bias=True)
        (1): ReLU()
      )
    )
    (2): BlockAE(
      (main): Sequential(
        (0): Linear(in_features=64, out_features=128, bias=Tru

In [22]:
z = torch.randn((100, 10))
sample = inferenced.decoder(z)
print(torch.max(torch.abs(sample[:, -1])))

tensor(0.1476, grad_fn=<MaxBackward1>)
