In [1]:
import os
# Set before torch is imported below. CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel
# launches synchronous, so a failing kernel is reported at the call that caused it.
# This is a debugging aid; it has no effect on CPU-only runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
import torch.optim

# CODE FILES HERE
from examples.vae.vae import Encoder, Decoder, Vae, PATH
from solver import Solver
from dataloader import DataLoader
from plot import plot_losses, plot_gaussian_distributions, plot_rl_kl, plot_latent_space, plot_latent_space_no_labels, \
plot_latent_manifold, plot_faces_grid, plot_faces_samples_grid

%matplotlib inline
#plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
#plt.rcParams['image.interpolation'] = 'nearest'
#plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

# suppress cluttering warnings in solutions
# NOTE(review): called without a category, this hides every warning for the rest of
# the session, including deprecation warnings raised by the libraries used below.
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Use the first GPU if CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA
if device.type == 'cuda':
    # torch.cuda.memory_cached is deprecated in favour of torch.cuda.memory_reserved.
    # Prefer the new name and only fall back to the old one on PyTorch versions
    # that do not provide it yet.
    memory_reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached

    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    # Byte counts are converted to GB (1024**3 bytes) and rounded to one decimal.
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(memory_reserved(0)/1024**3,1), 'GB')

Using device: cpu



In [3]:
# Choose the dataset and tune hyperparameters here!
dataset = "FF"
batch_size = 128
optimizer = torch.optim.Adam

# TODO: set individually for each dataset
warmup_epochs = 0
beta = 1

# Per-dataset hyperparameters. Each entry provides the values that the cells
# below read as module-level names: epochs, hidden_dim, z_dim, step_config and
# optim_config.
# NOTE(review): step_size = -1 presumably disables learning-rate stepping in
# Solver; confirm against solver.py.
dataset_configs = {
    "MNIST": {
        "epochs": 100,
        "hidden_dim": 500,
        "z_dim": 20,
        "step_config": {
            "step_size": -1,
            "gamma": 0.1,  # or 0.75
        },
        "optim_config": {
            "lr": 1e-3,
            "weight_decay": 1e-4,
        },
    },
    "LFW": {
        "epochs": 500,
        "hidden_dim": 700,
        "z_dim": 20,
        "step_config": {
            "step_size": 30,
            "gamma": 0.1,
        },
        "optim_config": {
            "lr": 1e-1,
            "weight_decay": 1e-4,
        },
    },
    "FF": {
        "epochs": 100,
        "hidden_dim": 200,
        "z_dim": 20,
        "step_config": {
            "step_size": -1,
            "gamma": 0.1,
        },
        "optim_config": {
            "lr": 1e-3,
            "weight_decay": 1e-4,
        },
    },
}

# Fail here with a clear message instead of letting a typo in `dataset` surface
# later as a NameError on `epochs` / `hidden_dim` / ... in another cell.
if dataset not in dataset_configs:
    raise ValueError(
        "Unknown dataset {!r}; expected one of {}".format(dataset, sorted(dataset_configs))
    )

selected_config = dataset_configs[dataset]
epochs = selected_config["epochs"]
hidden_dim = selected_config["hidden_dim"]
z_dim = selected_config["z_dim"]
step_config = selected_config["step_config"]
optim_config = selected_config["optim_config"]

In [4]:
# Build the data loader first: the networks take their input size from it.
data_loader = DataLoader(PATH, batch_size, dataset, z_dim)
input_dim = data_loader.input_dim

# The encoder goes input_dim -> hidden_dim -> z_dim and the decoder takes the
# same sizes in reverse order; the Vae wraps both.
encoder = Encoder(input_dim, hidden_dim, z_dim)
decoder = Decoder(z_dim, hidden_dim, input_dim)
model = Vae(encoder, decoder)

In [None]:
# Bundle the model, data and training hyperparameters into a Solver and train.
# Arguments are passed positionally, so their order must match Solver's signature.
solver = Solver(
    model,
    data_loader,
    optimizer,
    z_dim,
    epochs,
    step_config,
    optim_config,
    warmup_epochs,
    beta,
    batch_size,
)
solver.run()

+++++ START RUN +++++
====> Epoch: 1 train set loss avg: 407.7662
====> Test set loss avg: 387.9781
0.533930778503418 seconds for epoch 1
====> Epoch: 2 train set loss avg: 400.7962
====> Test set loss avg: 389.4409
0.47655320167541504 seconds for epoch 2
====> Epoch: 3 train set loss avg: 396.7981
====> Test set loss avg: 392.4545
0.4870600700378418 seconds for epoch 3
====> Epoch: 4 train set loss avg: 394.1032
====> Test set loss avg: 394.9397
0.563291072845459 seconds for epoch 4
====> Epoch: 5 train set loss avg: 392.3239
====> Test set loss avg: 388.3488
0.5656392574310303 seconds for epoch 5
====> Epoch: 6 train set loss avg: 390.5347
====> Test set loss avg: 389.8269
0.5661039352416992 seconds for epoch 6
====> Epoch: 7 train set loss avg: 389.1505
====> Test set loss avg: 388.4996
0.4705629348754883 seconds for epoch 7
====> Epoch: 8 train set loss avg: 387.7406
====> Test set loss avg: 386.9929
0.479968786239624 seconds for epoch 8
====> Epoch: 9 train set loss avg: 386.4575


====> Epoch: 70 train set loss avg: 354.5770
====> Test set loss avg: 357.1261
0.7017498016357422 seconds for epoch 70
====> Epoch: 71 train set loss avg: 354.3965
====> Test set loss avg: 356.0880
0.6069254875183105 seconds for epoch 71
====> Epoch: 72 train set loss avg: 354.1720
====> Test set loss avg: 355.5947
0.5505437850952148 seconds for epoch 72
====> Epoch: 73 train set loss avg: 354.0570
====> Test set loss avg: 356.4599
0.5645318031311035 seconds for epoch 73
====> Epoch: 74 train set loss avg: 353.9031
====> Test set loss avg: 355.9669
0.5556015968322754 seconds for epoch 74
====> Epoch: 75 train set loss avg: 353.7036
====> Test set loss avg: 355.1690
0.5636999607086182 seconds for epoch 75
====> Epoch: 76 train set loss avg: 353.5197
====> Test set loss avg: 355.8105
0.5658524036407471 seconds for epoch 76
====> Epoch: 77 train set loss avg: 353.4280
====> Test set loss avg: 355.6719
0.5635559558868408 seconds for epoch 77
====> Epoch: 78 train set loss avg: 353.2346
===

In [None]:
# To load a previously saved solver instead of training one, insert the model's path here, e.g. "../models/VAE_MNIST_train_loss=151.39_z=2.pt"
#solver = torch.load("../models/VAE_MNIST_train_loss=97.15_z=20.pt")
#solver.model.eval()

In [None]:
# Plotting train and test losses for all epochs
# NOTE(review): the meaning of the second argument (4) is defined by
# plot.plot_losses, which is not visible here; confirm before changing it.
plot_losses(solver, 4)

In [None]:
# Plot the Gaussian distributions from the trained solver.
# NOTE(review): presumably the latent distributions learned by the encoder;
# confirm against plot.plot_gaussian_distributions.
plot_gaussian_distributions(solver)

In [None]:
# Monitor the two loss terms recorded during training: reconstruction loss and
# KL divergence. Set DEBUG to a truthy value to also print a per-epoch table.
DEBUG = 0
if DEBUG:
    history = solver.train_loss_history
    per_epoch_records = zip(
        history["epochs"],
        history["train_loss_acc"],
        solver.test_loss_history,
        history["recon_loss_acc"],
        history["kl_diverg_acc"],
    )
    for epoch, train_loss, test_loss, rl, kl in per_epoch_records:
        summary = "epoch: {}, train_loss: {:.2f}, test_loss: {:.2f}, recon. loss: {:.2f}, KL div.: {:.2f}".format(
            epoch, train_loss, test_loss, rl, kl)
        print(summary)
        # The absolute train/test gap is printed as a rough overfitting indicator.
        gap = abs(test_loss-train_loss)
        print("overfitting: {:.2f}".format(gap))
plot_rl_kl(solver, 4)

In [None]:
# Visualize q(z), the latent space. This is only done for a two-dimensional z.
if solver.z_dim != 2:
    print("Plot of latent space not possible as dimension of z is not 2")
elif solver.loader.dataset == "FF":
    # "FF" goes through the label-free variant of the plot.
    plot_latent_space_no_labels(solver)
else:
    plot_latent_space(solver)

In [None]:
# Visualize the learned data manifold. This is only done for a two-dimensional
# latent space.
if solver.z_dim != 2:
    print("Plot is not possible as dimension of z is not 2")
else:
    trained_dataset = solver.loader.dataset
    if trained_dataset == "MNIST":
        plot_latent_manifold(solver, "bone")
    if trained_dataset in ("LFW", "FF"):
        # The face datasets use a gray colormap with an explicit grid size and figure size.
        plot_latent_manifold(solver, "gray", n=10, fig_size=(10, 8))

In [None]:
# Show a grid of real faces, then a grid of samples (face datasets only).
# NOTE(review): 225 = 15 * 15, presumably 225 images laid out 15 per row;
# confirm against plot.plot_faces_grid.
if solver.loader.dataset in ("LFW", "FF"):
    plot_faces_grid(225, 15, solver)
    plot_faces_samples_grid(225, 15, solver)

In [None]:
# Persist the whole solver (model plus recorded histories) under a name that
# encodes the dataset, the final training loss and the latent dimension.
last_train_loss = solver.train_loss_history["train_loss_acc"][-1]

# torch.save does not create missing parent directories. Create the target
# directory first so a finished training run is not lost to a missing folder.
model_dir = "../models"
os.makedirs(model_dir, exist_ok=True)

model_path = "{}/VAE_{}_train_loss={:.2f}_z={}.pt".format(
    model_dir, solver.loader.dataset, last_train_loss, str(solver.z_dim))

# NOTE(review): this pickles the entire Solver object, so loading it later
# requires the same class definitions to be importable. Only torch.load files
# you trust, since unpickling can execute arbitrary code.
torch.save(solver, model_path)