# CelebA experiment for VAE models

## Imports

In [None]:
from models import vae_hyp_celeba, vae_eucl_celeba
import geoopt
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from hypmath import poincareball
from hypmath import metrics
from tqdm import tqdm
import time
from torch.utils.data import SubsetRandomSampler

# Keep autograd anomaly detection off: it adds per-op overhead and is only
# useful when hunting down NaN/inf gradients.
torch.autograd.set_detect_anomaly(False)
# The autograd profilers (torch.autograd.profiler.profile / emit_nvtx) are only
# active while used as `with` context managers, so nothing needs to be called
# here to keep them disabled; constructing them without entering them is a no-op.

# cuDNN autotuner: benchmark convolution algorithms for the input shapes seen
# during training and reuse the fastest one (input size is fixed below).
torch.backends.cudnn.benchmark = True

## CUDA check

In [None]:
# Run on a CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device  # notebook cell output: show which device was selected

## Training and validation data

In [None]:
image_size = 64
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])

# CelebA data can be downloaded at https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
# Labels produced by ImageFolder are ignored by the training/validation loops.
trainset = datasets.ImageFolder('data', transform=transform)
# Use (at most) the first 102400 images. Clamp to the dataset length so a
# smaller local copy does not yield out-of-range Subset indices.
num_data = list(range(0, min(102400, len(trainset))))
trainset_1 = torch.utils.data.Subset(trainset, num_data)
size = len(trainset_1)
# Hold out 20% for validation. The train size is computed as the remainder so
# the two lengths always sum to `size`; truncating both halves independently
# can leave them one short (e.g. size=7 -> 5 + 1), which random_split rejects.
val_size = int(size * 0.2)
train_size = size - val_size
train_data, val_data = torch.utils.data.random_split(trainset_1, [train_size, val_size])
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64,
                                          num_workers=1, pin_memory=True, shuffle=True)
valloader = torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=True, num_workers=1, pin_memory=True)


In [None]:
reconstruction_function = nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """
    VAE objective: summed squared reconstruction error plus the KL divergence
    of the diagonal Gaussian N(mu, exp(logvar)) from a standard normal prior.

    Both terms are summed (not averaged) over every element of the batch.
    KL term follows https://arxiv.org/abs/1312.6114 (Appendix B):
        KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    """
    reconstruction_term = reconstruction_function(recon_x, x)

    # Per-element 1 + log(sigma^2) - mu^2 - sigma^2, grouped as
    # (1 - (mu^2 + sigma^2)) + log(sigma^2).
    kl_terms = (1 - (mu.pow(2) + logvar.exp())) + logvar
    kl_term = -0.5 * torch.sum(kl_terms)

    return reconstruction_term + kl_term

def train_epoch(vae, dataloader, optimizer, *, verbose=True):
    """
    Run one training epoch of ``vae`` over ``dataloader``.

    Args:
        vae: model called as ``vae(x)`` that returns ``(recon_x, mu, logvar)``.
        dataloader: yields ``(x, label)`` pairs; labels are ignored
            (unsupervised training).
        optimizer: optimizer over the model's parameters, stepped once per batch.
        verbose: if True (default), print the loss of every batch.

    Returns:
        The sum of per-batch losses divided by ``len(dataloader.dataset)``.
        Because ``loss_function`` sums over the batch, this is the average
        loss per sample.
    """
    # Put the whole model (encoder and decoder) in training mode.
    vae.train()
    train_loss = 0.0
    for x, _ in dataloader:
        x = x.to(device)
        # Reset gradients to None (same effect as `param.grad = None` for every
        # parameter of the model; cheaper than zero-filling).
        vae.zero_grad(set_to_none=True)

        recon_x, mu, logvar = vae(x)
        loss = loss_function(recon_x, x, mu, logvar)

        loss.backward()
        optimizer.step()

        # Fetch the scalar once: every .item() copies the value to the host,
        # which is a synchronisation point on GPU.
        batch_loss = loss.item()
        if verbose:
            print('\t partial train loss (single batch): %f' % (batch_loss))
        train_loss += batch_loss

    return train_loss / len(dataloader.dataset)

In [None]:
def test_epoch(vae, dataloader):
    """
    Evaluate ``vae`` on ``dataloader`` without tracking gradients.

    The model is switched to eval mode. Returns the summed batch losses
    divided by ``len(dataloader.dataset)``, i.e. the mean loss per sample.
    Labels yielded by the loader are ignored.
    """
    vae.eval()
    batch_losses = []
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for images, _labels in dataloader:
            images = images.to(device)
            reconstruction, mean, log_variance = vae(images)
            batch_losses.append(
                loss_function(reconstruction, images, mean, log_variance).item()
            )
    return sum(batch_losses) / len(dataloader.dataset)

In [None]:
def plot_ae_outputs(encoder, decoder, n, dataset=None):
    """
    Plot the first ``n`` dataset images (top row) above their reconstructions
    (bottom row).

    Args:
        encoder: called as ``encoder(img)``; the first element of its returned
            3-tuple is passed to ``decoder``.
        decoder: maps that latent to a reconstructed image tensor.
        n: number of images (columns) to show.
        dataset: indexable dataset of ``(image, label)`` pairs. Defaults to the
            module-level ``trainset_1`` subset.

    Both modules run in eval mode for the forward passes and are restored to
    their previous train/eval mode afterwards.
    """
    if dataset is None:
        dataset = trainset_1

    # Remember the current modes so plotting mid-training leaves them unchanged.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    plt.figure(figsize=(10, 4.5))
    try:
        for i in range(n):
            img = dataset[i][0].unsqueeze(0).to(device)
            with torch.no_grad():
                z, _, _ = encoder(img)
                rec_img = decoder(z)

            # Top row: original image (CHW -> HWC for imshow).
            ax = plt.subplot(2, n, i + 1)
            plt.imshow(img.cpu().squeeze().permute(1, 2, 0).numpy())
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if i == n // 2:
                ax.set_title('Original images')

            # Bottom row: reconstruction. imshow expects float RGB in [0, 1];
            # clamp so out-of-range decoder outputs don't trigger clipping warnings.
            ax = plt.subplot(2, n, i + 1 + n)
            plt.imshow(rec_img.clamp(0, 1).cpu().squeeze().permute(1, 2, 0).numpy())
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if i == n // 2:
                ax.set_title('Reconstructed images')
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)
    plt.show()

## Initialize and train model

In [None]:
model = vae_eucl_celeba.VariationalAutoencoder(nc=3, ndf=64, ngf=64, latent_dims=500, device=device)
#model = vae_hyp_celeba.VariationalAutoencoder(nc=3, ndf=64, ngf=64, latent_dims=100, device=device)
model.to(device)
print(model)
epochs = 5
lr = 5e-4
#lr = 0.01
#optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=lr)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


t_loss = []
v_loss = []
epoch_values = []
# Training loop
for epoch in range(epochs):
    train_loss = train_epoch(model, trainloader, optimizer)
    val_loss = test_epoch(model, valloader)
    t_loss.append(train_loss)
    v_loss.append(val_loss)
    # Store the same 1-based epoch number that is printed below, so the
    # loss-curve x-axis agrees with the training log.
    epoch_values.append(epoch + 1)
    print('\n EPOCH {}/{} \t train loss {:.3f} \t validation loss {:.3f}'.format(epoch + 1, epochs, train_loss, val_loss))
    plot_ae_outputs(model.encoder, model.decoder, n=4)


# save model checkpoint
# torch.save({
#             'epoch': epochs,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': loss_function,
#             }, 'outputs/resnet_model.pth')

## Curve Plotting

In [None]:
# Single-axes figure with per-epoch train and validation losses.
fig, ax0 = plt.subplots()
ax0.plot(epoch_values, t_loss, 'bo-', label='train')
ax0.plot(epoch_values, v_loss, 'ro-', label='val')
ax0.set(title='Loss Curves', xlabel='Epochs', ylabel='Losses')
ax0.legend()

fig.suptitle('no. of epochs = {}, lr = {}, batch size = 64'.format(epochs, lr))
fig.tight_layout()

## Resuming Model Training

In [None]:
# # load the trained model
# model_resume = model.to(device) # initilize the model
# # initialize optimizer  before loading optimizer state_dict
# epochs_new = 5
# learning_rate_new = 5e-4
# optimizer_new = optim.Adam(model_resume.parameters(), lr=learning_rate_new)


# checkpoint = torch.load('outputs/model.pth')

# # load model weights state_dict
# model.load_state_dict(checkpoint['model_state_dict'])
# print('Previously trained model weights state_dict loaded...')

# # load trained optimizer state_dict
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# print('Previously trained optimizer state_dict loaded...')

# epochs = checkpoint['epoch']
# # load the criterion
# loss_function = checkpoint['loss']
# print('Trained model loss function loaded...')
# print(f"Previously trained for {epochs} number of epochs...")

# # train for more epochs
# epochs = epochs_new
# print(f"Train for {epochs} more epochs...")


# #New Training loop
# for epoch in range(epochs):
#    train_loss = train_epoch(model_resume, trainloader, optimizer_new)
#    print('\n EPOCH {}/{} \t train loss {:.3f}'.format(epoch + 1, epochs, train_loss))
#    plot_ae_outputs(model.encoder, model.decoder,n=4)

# # save model checkpoint
# torch.save({
#             'epoch': epochs,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': loss_function,
#             }, 'outputs/model.pth')

#5,3,3,3,5