In [20]:
!pip install torch torchvision torchaudio matplotlib numpy

[0m

In [21]:
# Show the contents and location of the working directory (checkpoints/, data_mnist/ and logging_info/ live here).
!ls
!pwd

__notebook_source__.ipynb  data_mnist	       logging_info
checkpoints		   figures_experiment  state.db
/kaggle/working


In [23]:
import os
import utils
from time import time
import logging
import sys

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
import random as random
import os
from vaeee import VAE

from plot_figures import *

In [None]:
# Google Colab only: unzips the project into the Google Drive folder "dd2434-vae". It only needs to be run once; keep the line below commented out afterwards.
# !unzip "/content/drive/My Drive/dd2434-vae/main.zip" -d "/content/drive/My Drive/dd2434-vae"

In [5]:
current_platform = sys.platform  # OS the kernel runs on, e.g. 'linux' on Kaggle
current_platform

'linux'

In [24]:
# nvidia-smi is only available when the runtime has an NVIDIA GPU attached; on a CPU-only runtime this command fails.
# Note: `!python --version` reports the shell's python, which may differ from the kernel's interpreter.
!nvidia-smi
!python --version

Sun Jan  1 08:24:08 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    35W / 250W |    769MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [25]:
# Device
# Workaround that lets the process continue when more than one OpenMP runtime gets loaded
# (NOTE(review): this can hide a genuine library conflict).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)

# paths (relative to the working directory)
PATH_TRAINING = "checkpoints"   # model/optimizer checkpoints written by train()
DATA_PATH = "data_mnist"        # passed to utils.create_dataset
LOGGING_PATH = "logging_info"   # log files written via utils.set_logger
# exist_ok avoids the check-then-create race and makes the cell idempotent on re-run.
os.makedirs(PATH_TRAINING, exist_ok=True)
os.makedirs(LOGGING_PATH, exist_ok=True)


Device: cuda:0


In [28]:
def train(model, optimizer, loss_fn, train_loader, test_loader, params, train_settings, state):
    """Train the VAE for epochs ``train_settings.start_epoch .. params.NUM_EPOCHS`` (inclusive).

    After every mini-batch the training loss divided by the batch size is recorded, together with the
    same quantity measured on a single test batch. After every epoch the training loss summed over the
    epoch and divided by the dataset size is recorded, as well as the same average over the whole test set.
    Checkpoints (model, optimizer and all loss histories so far) are written to ``PATH_TRAINING``
    according to the ``save_*`` settings.

    Args:
        model: VAE whose forward pass returns
            ``(_, z_mean, z_log_var, x_mean, x_log_var, reconstructed_x)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: reconstruction criterion forwarded to ``calculate_loss``.
        train_loader: DataLoader yielding ``(x, label)`` training batches.
        test_loader: DataLoader yielding ``(x, label)`` test batches.
        params: hyper-parameters; ``NUM_EPOCHS`` is read here.
        train_settings: run settings; reads ``start_epoch``, ``load_checkpoint``, ``save_checkpoint``,
            ``save_rate_epoch``, ``save_rate_iter`` and ``track_rate``.
        state: checkpoint dict (see ``resume_training``); its loss histories are continued when
            ``train_settings.load_checkpoint`` is set.

    Returns:
        ``(training_losses, datapoint_training_losses, test_losses, datapoint_test_losses)``.
    """
    # The number of steps/epochs is not given in the paper, so it has to be found experimentally.
    # MNIST has 60k training images (batch size 100 in main()).
    training_losses = []
    datapoint_training_losses = []
    test_losses = []
    datapoint_test_losses = []

    if train_settings.load_checkpoint:
        # Continue the loss histories stored in the checkpoint instead of starting new ones.
        datapoint_training_losses = state['datapoint_training_losses']
        training_losses = state['training_losses']
        datapoint_test_losses = state['datapoint_test_losses']
        test_losses = state['test_losses']

    logging.info("Start training...")
    for epoch in range(train_settings.start_epoch, params.NUM_EPOCHS + 1):
        running_loss = 0

        for i, (x, _) in enumerate(train_loader):
            # eval() (called at the end of every iteration below) switches the model to evaluation
            # mode, so training mode has to be re-enabled before every forward pass, not once per epoch.
            model.train()

            # Forward and back prop
            x = x.to(DEVICE)
            _, z_mean, z_log_var, x_mean, x_log_var, reconstructed_x = model(x)
            loss = calculate_loss(z_mean, z_log_var, reconstructed_x, x, loss_fn)
            optimizer.zero_grad()
            loss.backward()
            # update model parameters
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss

            # Save a checkpoint with the model, optimizer and the loss histories collected so far.
            # NOTE: on resume, main() uses the saved epoch as start_epoch, so this whole epoch is run again.
            if (train_settings.save_checkpoint
                    and epoch % train_settings.save_rate_epoch == 0
                    and i % train_settings.save_rate_iter == 0):
                checkpoint = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'datapoint_training_losses': datapoint_training_losses,
                    'training_losses': training_losses,
                    'datapoint_test_losses': datapoint_test_losses,
                    'test_losses': test_losses,
                }
                torch.save(checkpoint, os.path.join(PATH_TRAINING, f'checkpoint_epoch_{epoch}_iter_{i}.pt'))

            if i % train_settings.track_rate == 0:
                progress = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, i * len(x), len(train_loader.dataset),
                    100. * i / len(train_loader),
                    batch_loss / len(x))
                logging.info(progress)
                print(progress)

            datapoint_training_losses.append(batch_loss / len(x))
            # Loss on a single test batch (divided by its size), tracked after every training step.
            datapoint_test_losses.append(eval(model, test_loader, loss_fn, True))

        training_loss = running_loss / len(train_loader.dataset)
        training_losses.append(training_loss)
        logging.info('====> Epoch: {} Average training loss: {:.4f}'.format(
            epoch, training_loss))

        test_loss = eval(model, test_loader, loss_fn, None)
        test_losses.append(test_loss)
        logging.info('====> Epoch: {} Average testing loss: {:.4f}'.format(
            epoch, test_loss))

    logging.info(f"params: {str(params.__dict__)}, train settings: {str(train_settings.__dict__)}")

    return training_losses, datapoint_training_losses, test_losses, datapoint_test_losses


@torch.no_grad()
def eval(model, test_loader, loss_fn, stop=None):
    """Evaluate ``model`` on the test data without updating it.

    The name shadows the built-in ``eval``; it is kept because ``train`` calls it by this name.

    Args:
        model: VAE whose forward pass returns
            ``(_, z_mean, z_log_var, x_mean, x_log_var, reconstructed_x)``.
        test_loader: DataLoader yielding ``(x, label)`` batches.
        loss_fn: reconstruction criterion forwarded to ``calculate_loss``.
        stop: if falsy, evaluate the whole test set; if truthy, evaluate only the first batch
            produced by a fresh iterator over ``test_loader`` (cheap, for per-iteration tracking).

    Returns:
        float: the total loss returned by ``calculate_loss`` divided by the number of datapoints
        evaluated (the whole dataset, or the single batch).

    The model's previous train/eval mode is restored before returning, so calling this from inside a
    training loop does not leave the model in evaluation mode.
    """
    was_training = model.training
    model.eval()
    try:
        if stop:  # evaluate model on only one batch
            x, _ = next(iter(test_loader))
            x = x.to(DEVICE)
            _, z_mean, z_log_var, _, _, reconstructed_x = model(x)
            loss = calculate_loss(z_mean, z_log_var, reconstructed_x, x, loss_fn)
            return loss.item() / len(x)

        # evaluate model on whole test dataset (forward pass only, no backprop)
        running_loss = 0
        for x, _ in test_loader:
            x = x.to(DEVICE)
            _, z_mean, z_log_var, _, _, reconstructed_x = model(x)
            running_loss += calculate_loss(z_mean, z_log_var, reconstructed_x, x, loss_fn).item()
        return running_loss / len(test_loader.dataset)
    finally:
        model.train(was_training)


def test_image(model, x, num_epochs, latent_dim, stop=None, loss_fn=None):
    """Reconstruct ``x``, compute its loss and save the first input next to its reconstruction.

    Args:
        model: VAE whose forward pass returns
            ``(_, z_mean, z_log_var, x_mean, x_log_var, reconstructed_x)``.
        x: input batch; it is moved to ``DEVICE`` and only ``x[0]`` is saved. The reconstruction is
            reshaped to ``x[0].shape``, so presumably a batch of one image is expected — TODO confirm.
        num_epochs: forwarded to ``save_test_image``.
        latent_dim: forwarded to ``save_test_image``.
        stop: unused; kept for call compatibility.
        loss_fn: reconstruction criterion; defaults to summed MSE.
    """
    if not loss_fn:
        loss_fn = torch.nn.MSELoss(reduction='sum')

    x = x.to(DEVICE)
    # Inference only: no autograd graph is needed for the reconstruction or the loss value.
    with torch.no_grad():
        _, z_mean, z_log_var, x_mean, x_log_var, reconstructed_x = model(x)
        loss = calculate_loss(z_mean, z_log_var, reconstructed_x, x, loss_fn)

    first_image = x[0].cpu()
    # .numpy() only works on CPU tensors, so copy the reconstruction back from the GPU first.
    reconstructed_x = reconstructed_x.cpu().numpy().reshape(first_image.shape)
    save_test_image(first_image, reconstructed_x, loss.item(), num_epochs, latent_dim)


def resume_training(model, optimizer, file_name):
    """Restore model and optimizer from a checkpoint written by ``train``.

    Args:
        model: freshly created model with the same architecture as the saved one.
        optimizer: freshly created optimizer over ``model``'s parameters.
        file_name: checkpoint file name inside ``PATH_TRAINING``.

    Returns:
        ``(epoch, optimizer, model, state)``. ``epoch`` is the epoch stored in the checkpoint; main()
        uses it as the start epoch, so that epoch is trained again. ``state`` is the full checkpoint
        dict, including the loss histories.
    """
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa).
    state = torch.load(os.path.join(PATH_TRAINING, file_name), map_location=DEVICE)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    return state['epoch'], optimizer, model, state


def calculate_loss(z_mean, z_log_var, reconstructed_x, x, loss_fn):
    """Return the negative ELBO of a batch: KL(q(z|x) || p(z)) + reconstruction loss.

    Args:
        z_mean: posterior means produced by the encoder.
        z_log_var: posterior log-variances produced by the encoder.
        reconstructed_x: decoder output; compared against ``x`` flattened from dim 1 onwards.
        x: original input batch.
        loss_fn: reconstruction criterion called as ``loss_fn(reconstructed_x, flattened_x)``.

    Returns:
        Scalar tensor: KL term (summed over all elements) plus ``loss_fn``'s result.
    """
    # Closed-form KL divergence between the diagonal Gaussian posterior N(mu, sigma^2) and the
    # standard normal prior N(0, I):  KL = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2).
    # This is +KL (non-negative), which is the term minimised in the negative ELBO.
    variance = z_log_var.exp()
    kl_term = (1 + z_log_var - z_mean.pow(2) - variance).sum() * -0.5

    # Reconstruction term. With a Gaussian decoder, -log p(x|z) is proportional to |x - x'|^2, which
    # motivates MSE:
    # https://stats.stackexchange.com/questions/347378/variational-autoencoder-why-reconstruction-term-is-same-to-square-loss
    # Binary cross-entropy (Bernoulli decoder) tends to give results closer to the paper.
    target = x.flatten(start_dim=1)
    reconstruction_term = loss_fn(reconstructed_x, target)

    return kl_term + reconstruction_term


def main():
    """Run the MNIST VAE experiment.

    Steps: seed all RNGs; build data, model, criterion and optimizer; optionally resume from a
    checkpoint; optionally run a hyper-parameter search; train; plot the loss curves.
    All knobs are set in ``Params`` and ``RunningSettings`` below.
    """
    RANDOM_SEED = 123

    # Settings are bound to self so that the built-in .__dict__ can be logged directly.
    class Params:
        def __init__(self):
            self.LEARNING_RATE = 2e-2
            self.BATCH_SIZE = 100
            self.NUM_EPOCHS = 1667
            self.HIDDEN_DIMEN = 500
            self.LATENT_SPACE = 10  # {3,5,10,20,200}

    class RunningSettings:
        def __init__(self):
            self.optimizer = 'adagrad'  # adagrad or adam
            self.criterion = 'bce'  # l1, mse or bce
            self.hyperparam_search = False
            self.train = True
            self.plot = True
            self.save_checkpoint = True
            self.load_checkpoint = False
            # NOTE(review): with save_rate_iter = 600 train() names checkpoints ..._iter_0.pt; check this file name.
            self.load_path = '' if not self.load_checkpoint else "checkpoint_epoch_5_iter_599.pt"
            self.logging_filename = 'train.log'  # 'hyperparameter_search1.log'
            self.track_rate = 300  # log the batch loss every N-th batch
            self.save_rate_epoch = 5  # save checkpoints only in every N-th epoch...
            # ...and within such an epoch at batches i with i % N == 0
            # (60k images / batch size 100 = 600 batches, so 600 means only batch 0)
            self.save_rate_iter = 600
            self.start_epoch = 1  # where training epochs start, if loaded from checkpoint then it will be higher than 1
            self.device = DEVICE

    params = Params()
    running_settings = RunningSettings()

    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)

    # Set up the logger so that logging.info calls are saved to the log file.
    utils.set_logger(os.path.join(LOGGING_PATH, running_settings.logging_filename))

    train_loader, test_loader = utils.create_dataset(data_path=DATA_PATH, batch_size=params.BATCH_SIZE)
    # Drop the sample dimension of dataset.data (presumably (N, 28, 28) for torchvision MNIST).
    MNIST_shape = train_loader.dataset.data.shape[1:]

    model = utils.create_model(mnist_shape=MNIST_shape, hidden_dimen=params.HIDDEN_DIMEN,
                               latent_space=params.LATENT_SPACE, device=DEVICE)
    criterion = utils.create_criterion(running_settings.criterion)
    optimizer = utils.create_optimizer(model=model, learning_rate=params.LEARNING_RATE, type=running_settings.optimizer)

    # load model checkpoint
    state = {}
    if running_settings.load_checkpoint:
        running_settings.start_epoch, optimizer, model, state = resume_training(model, optimizer, running_settings.load_path)

    # Loss histories. Training below replaces them. If training is skipped, they come from the loaded
    # checkpoint (or stay empty), so the plotting step never hits an undefined name.
    training_losses = state.get('training_losses', [])
    datapoint_training_losses = state.get('datapoint_training_losses', [])
    test_losses = state.get('test_losses', [])
    datapoint_test_losses = state.get('datapoint_test_losses', [])

    # hyperparameter search
    if running_settings.hyperparam_search:
        params_to_optimize = {'lr': [0.01, 0.02, 0.1], 'ls': [3, 5, 10, 20, 200]}
        best_learning_rates = utils.learning_rate_hyperparam_search(parameters_to_optimize=params_to_optimize, params=params,
                                                                    train_fn=train, mnist_shape=MNIST_shape,
                                                                    train_loader=train_loader, test_loader=test_loader,
                                                                    running_settings=running_settings)
        logging.info(f"Best learning rates for different latent spaces: {str(best_learning_rates)}")

    # train
    if running_settings.train:
        tic = time()
        training_losses, datapoint_training_losses, test_losses, datapoint_test_losses = \
            train(model=model,
                  optimizer=optimizer,
                  train_loader=train_loader,
                  test_loader=test_loader,
                  loss_fn=criterion,
                  params=params,
                  train_settings=running_settings,
                  state=state)
        final_time = time() - tic
        logging.info('Done (t={:0.2f}m)'.format(final_time / 60))

    # plot evolution of loss along epochs/datapoint
    if running_settings.plot:
        if not (training_losses or datapoint_training_losses):
            logging.warning("No loss history available to plot.")
        else:
            plot_epoch_losses(training_losses, test_losses, params.LATENT_SPACE, params.NUM_EPOCHS)
            plot_datapoint_losses(datapoint_training_losses, datapoint_test_losses, params.LATENT_SPACE, params.NUM_EPOCHS)

        # Run the model on one test image (disabled). test_image takes
        # (model, x, num_epochs, latent_dim, ...) and indexes x[0], so keep a batch dimension of 1:
        # inputs, classes = next(iter(test_loader))
        # idx = np.random.randint(len(inputs))
        # test_image(model, inputs[idx:idx + 1], params.NUM_EPOCHS, params.LATENT_SPACE, stop=None, loss_fn=criterion)
        #
        # # plot manifold for latent dimension of 2
        # plot_manifold(model, DEVICE, n=12)


In [None]:
main()

