# VQ-VAE Training

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before code is executed (handy while editing the
# local `models` package alongside this notebook).
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures directly in the notebook output.
%matplotlib inline

In [None]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from livelossplot import PlotLosses
from tqdm import tqdm
from models.foster import VectorQuantizedVariationalAutoencoder

In [None]:
# --- Optimisation ------------------------------------------------------------
NUM_EPOCHS = 200        # number of passes of the training loop
BATCH_SIZE = 512        # batch size for both data loaders
LEARNING_RATE = 0.01    # learning rate handed to the Adam optimizer

# --- Network -----------------------------------------------------------------
Z_DIM = 2                           # passed to the model as `embedding_dim`
DROPOUT_P = 0.25                    # passed to the model as `dropout_p`
LEAKY_RELU_NEGATIVE_SLOPE = 0.1     # passed as `leaky_relu_negative_slope`

# --- Vector quantiser --------------------------------------------------------
NUM_EMBEDDINGS = 20     # passed to the model as `num_embeddings`
COMMITMENT_COST = 0.25  # passed to the model as `commitment_cost`
DECAY = 0.99            # passed to the model as `decay`

In [None]:
# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# MNIST train split (train=True) and test split (train=False, used here for
# validation). Images are converted with ToTensor(); files are downloaded into
# data/ when they are not already present.
train_dataset, validation_dataset = (
    MNIST(root='data/', train=use_train_split, transform=transforms.ToTensor(), download=True)
    for use_train_split in (True, False)
)

In [None]:
# One shuffled DataLoader per phase, keyed by the phase names used in the
# training loop. Both loaders share BATCH_SIZE.
dataloaders = {
    phase_name: DataLoader(phase_dataset, batch_size=BATCH_SIZE, shuffle=True)
    for phase_name, phase_dataset in (
        ('train', train_dataset),
        ('validation', validation_dataset),
    )
}

In [None]:
# Label describing how the embedding vectors are initialised.
# NOTE(review): within this notebook the value is only interpolated into the
# checkpoint filenames written by the training loop; it is not passed to the
# model constructor. Confirm the model really uses the matching initialisation,
# otherwise the saved filenames will be misleading.
initialise_embedding_vectors = "normal" # "uniform" or "normal"

In [None]:
# Instantiate the VQ-VAE from the hyperparameter constants defined earlier.
model = VectorQuantizedVariationalAutoencoder(
    # vector-quantiser settings
    embedding_dim=Z_DIM,
    num_embeddings=NUM_EMBEDDINGS,
    decay=DECAY,
    commitment_cost=COMMITMENT_COST,
    # network settings
    dropout_p=DROPOUT_P,
    leaky_relu_negative_slope=LEAKY_RELU_NEGATIVE_SLOPE,
)

In [None]:
# Optionally warm-start the encoder and decoder from a previously trained
# (variational) autoencoder instead of training them from scratch.
initialise_from = "vae" # "none" or "ae" or "vae"

# All warm-start checkpoints are read from the checkpoints/ directory.
# They are loaded onto the CPU first so that checkpoints saved on a GPU machine
# can still be read on a CPU-only one; the whole model is moved to `device`
# in the next cell.
# NOTE(review): the files appear to hold whole pickled modules (they are
# assigned directly to model.encoder / model.decoder). torch.load unpickles
# arbitrary objects, so only load checkpoints you trust. Recent PyTorch
# releases default to weights_only=True, which would require passing
# weights_only=False here — confirm against the installed version.
if initialise_from == "ae":
    model.encoder = torch.load('checkpoints/ae_encoder_only.pt', map_location='cpu')
    model.decoder = torch.load('checkpoints/ae_decoder_only.pt', map_location='cpu')
elif initialise_from == "vae":
    model.encoder = torch.load('checkpoints/vae_mu_encoder_only.pt', map_location='cpu')
    model.decoder = torch.load('checkpoints/vae_decoder_only.pt', map_location='cpu')
elif initialise_from != "none":
    # A typo would otherwise silently train from scratch while the checkpoint
    # filenames (which embed this value) claim a warm start.
    raise ValueError(
        f'initialise_from must be "none", "ae" or "vae", got {initialise_from!r}'
    )

In [None]:
# Move the model to the selected device.
model = model.to(device)

# Smoke-test the forward pass on a random MNIST-shaped batch (1x28x28 images).
# NOTE(review): this runs with the model in whatever mode it is currently in
# and with autograd enabled; if the quantiser updates internal state during
# training-mode forward passes, this random batch may affect it — confirm.
model(torch.rand((BATCH_SIZE, 1, 28, 28)).to(device=device))

# Display the model summary for a batch of the same shape.
# NOTE(review): this tensor is created on the CPU while the model lives on
# `device`; presumably `summary` copes with that — verify.
model.summary(input_data=torch.rand(BATCH_SIZE, 1, 28, 28))

In [None]:
# Adam over every model parameter, using the notebook-wide learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Mean-squared error, used for the reconstruction term of the loss.
criterion = torch.nn.MSELoss()

# Learning-rate scheduling is currently disabled:
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [None]:
# Training loop.
#
# Every epoch runs a 'train' phase followed by a 'validation' phase. For each
# phase the reconstruction loss, VQ loss, total loss and perplexity are
# averaged over the dataset (each batch weighted by its size), pushed to the
# live plot, and the whole model is checkpointed.
liveloss = PlotLosses()
logs = {}
for epoch in tqdm(range(NUM_EPOCHS)):
    for phase in ['train', 'validation']:
        is_training = phase == 'train'
        # train(True) == train(), train(False) == eval()
        model.train(is_training)

        running_vq_loss = 0.0
        running_recon_loss = 0.0
        running_loss = 0.0
        running_perplexity = 0.0

        # Labels are unused: the reconstruction target is the input itself.
        for inputs, _ in dataloaders[phase]:
            inputs = inputs.to(device)

            # Only record the autograd graph while training; during validation
            # nothing is back-propagated, so tracking gradients would just
            # waste memory and compute.
            with torch.set_grad_enabled(is_training):
                outputs, vq_loss, perplexity = model(inputs)
                # (prediction, target) order; MSE is symmetric, so the value
                # is the same either way.
                recon_loss = criterion(outputs, inputs)
                loss = recon_loss + vq_loss

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight each batch statistic by the batch size so that the final
            # (possibly smaller) batch does not skew the epoch average.
            # Accumulators stay tensors until the end of the phase.
            n_in_batch = inputs.size(0)
            running_recon_loss += recon_loss.detach() * n_in_batch
            running_vq_loss += vq_loss.detach() * n_in_batch
            running_loss += loss.detach() * n_in_batch
            running_perplexity += perplexity.detach() * n_in_batch

        n_in_dataset = len(dataloaders[phase].dataset)
        epoch_recon_loss = running_recon_loss / n_in_dataset
        epoch_vq_loss = running_vq_loss / n_in_dataset
        epoch_loss = running_loss / n_in_dataset
        epoch_perplexity = running_perplexity / n_in_dataset

        # Validation metrics are logged under a 'val_' prefix next to the
        # un-prefixed training metrics.
        prefix = 'val_' if phase == 'validation' else ''
        logs[prefix + 'recon_loss'] = epoch_recon_loss.item()
        logs[prefix + 'vq_loss'] = epoch_vq_loss.item()
        logs[prefix + 'loss'] = epoch_loss.item()
        logs[prefix + 'perplexity'] = epoch_perplexity.item()

    # Refresh the live plot with this epoch's train and validation metrics.
    liveloss.update(logs)
    liveloss.send()

    # Per-epoch checkpoint of the full model object.
    torch.save(model, f'checkpoints/vqvae_from_{initialise_from}_using_{initialise_embedding_vectors}_epoch_{epoch}.pt')

# Final model after the last epoch.
torch.save(model, f'checkpoints/vqvae_from_{initialise_from}_using_{initialise_embedding_vectors}.pt')