In [13]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
import os
from torchvision import datasets, transforms
from torch import optim, nn, unsqueeze
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Lambda, Compose
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import torchvision.utils as vutils
import torch.nn as nn
from torch.distributions.log_normal import LogNormal
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# VAE

### FID

In [15]:
# import numpy
# from numpy import cov
# from numpy import trace
# from numpy import iscomplexobj
# from numpy import asarray
# from numpy.random import shuffle
# from scipy.linalg import sqrtm
# from keras.applications.inception_v3 import InceptionV3
# from keras.applications.inception_v3 import preprocess_input
# from keras.datasets.mnist import load_data
# from skimage.transform import resize
# from torchvision.datasets import SVHN
 
# # scale an array of images to a new size
# def scale_images(images, new_shape):
# 	images_list = list()
# 	for image in images:
# 		# resize with nearest neighbor interpolation
# 		new_image = resize(image, new_shape, 0)
# 		# store
# 		images_list.append(new_image)
# 	return asarray(images_list)
 
# # calculate frechet inception distance
# def calculate_fid(model, images1, images2):
# 	# calculate activations
# 	act1 = model.predict(images1)
# 	act2 = model.predict(images2)
# 	# calculate mean and covariance statistics
# 	mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
# 	mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
# 	# calculate sum squared difference between means
# 	ssdiff = numpy.sum((mu1 - mu2)**2.0)
# 	# calculate sqrt of product between cov
# 	covmean = sqrtm(sigma1.dot(sigma2))
# 	# check and correct imaginary numbers from sqrt
# 	if iscomplexobj(covmean):
# 		covmean = covmean.real
# 	# calculate score
# 	fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
# 	return fid
 
# # prepare the inception v3 model
# model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))
# # load cifar10 images
# (images1, _), (images2, _) = SVHN.load_data()
# shuffle(images1)
# images1 = images1[:10000]
# print('Loaded', images1.shape, images2.shape)
# # convert integer to floating point values
# images1 = images1.astype('float32')
# images2 = images2.astype('float32')
# # resize images
# images1 = scale_images(images1, (299,299,3))
# images2 = scale_images(images2, (299,299,3))
# print('Scaled', images1.shape, images2.shape)
# # pre-process images
# images1 = preprocess_input(images1)
# images2 = preprocess_input(images2)
# # calculate fid
# fid = calculate_fid(model, images1, images2)
# print('FID: %.3f' % fid)

## MNIST

In [16]:
BATCH_SIZE = 64

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]) 
# transform = transforms.Compose([transforms.Resize((32, 32)), 
#                                 transforms.ToTensor(),
#                                 transforms.Lambda(lambda x: x.repeat(3,1,1)),
#                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#load in full train set
trainsetfull = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform)
# type(trainsetfull)

# data loader for final run 
trainfullloader = torch.utils.data.DataLoader(trainsetfull, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
#split the set 
trainset, valset = torch.utils.data.random_split(trainsetfull, [55000, 5000])
#load in test set
testset = torchvision.datasets.MNIST(root='./data/mnist', train=False, transform=transform,download=True)


# data loader for training
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,shuffle=True, num_workers=2)
# data loader for validation
valiloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE,shuffle=True, num_workers=2)
# data loader for testing
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,shuffle=False, num_workers=2)

## Model

In [18]:
#FROM SLIDES

class VAE(nn.Module):
    def __init__(self, D, M):
        super(VAE, self).__init__()
        self.D = D
        self.M = M

        self.enc1 = nn.Linear(in_features=self.D, out_features=300)
        self.enc2 = nn.Linear(in_features=300, out_features=self.M*2)

        self.dec1 = nn.Linear(in_features=self.M, out_features=300)
        self.dec2 = nn.Linear(in_features=300, out_features=self.D)

    def reparameterize(self, mu, log_std): 
        std = torch.exp(log_std)
        eps = torch.randn_like(std)
        z = mu + (eps * std)
        return z

    def forward(self, x): # encoder
        # x = nn.functional.relu(self.enc1(x))
        # x = self.enc2(x).view(-1, 2, self.M)
        
        # encoder
        x = self.enc1(x)
        x = nn.functional.relu(x)
        x = self.enc2(x).view(-1, 2, self.M)
        
        # get mean and log-std
        mu = x[:, 0, :]
        log_std = x[:, 1, :]
        
        # reparameterization
        z = self.reparameterize(mu, log_std)

        # # decoder
        # x_hat = nn.functional.relu(self.dec1(z)) 
        # x_hat = self.dec2(x)
        # return x_hat, mu, log_std

        # decoder
        x_hat = nn.functional.relu(self.dec1(z))
        x_hat = torch.sigmoid(self.dec2(x_hat))
        return x_hat, mu, z, log_std

    def generate(self, z):
        x_hat = nn.functional.relu(self.dec1(z))
        x_hat = torch.sigmoid(self.dec2(x_hat))
        return x_hat

    def elbo(self, x, x_hat, z, mu, log_std): 
        # reconstruction error
        # RE = nn.loss.mse(x, x_hat)

        RE = F.binary_cross_entropy(x_hat, x)

        # kl-regularization
        # We assume here that log_normal is implemented
        # KL = LogNormal(z, mu, log_std) - LogNormal(z, 0, 1)

        KL = -0.5 * torch.sum(1 + log_std - mu.pow(2) - log_std.exp())
        KL /= (784 * x_hat.size(0))

        # REMEMBER! We maximize ELBO, but optimizers minimize. # Therefore, we need to take the negative sign!
        return -(RE - KL)


In [19]:
def fit(model, dataloader, optimizer):
    model.train()
    running_loss = 0.0

    for i, data in enumerate(dataloader):

        data = data[0].to(device)
        data = data.view(data.size(0),-1)

        optimizer.zero_grad()
        data_hat, mu,z, log_std = model(data)
        loss = model.elbo(x = data, x_hat = data_hat,z=z, mu = mu, log_std = log_std)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

    #from google.colab import files
def validate(model, valdataloader, epoch):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(valdataloader):
            #data, _ = data
            data = data[0].to(device)
            data = data.view(data.size(0),-1)
            data_hat, mu, z, log_std = model(data)
            loss = model.elbo(data, data_hat,z, mu, log_std)
            running_loss += loss.item()

            if (epoch%10 == 0) and (i == len(valiloader)-1):
                # save the last 8 samples input and output of every 10th epoch
                num_rows = 8
                both = torch.cat((data.view(len(data), 1, 28, 28)[:8], 
                                  data_hat.view(len(data), 1, 28, 28)[:8]))
                # both = torch.cat((data.view(len(data), 3, 32, 32)[:8], 
                #                   data_hat.view(len(data), 3, 32, 32)[:8]))
                torchvision.utils.save_image(both.cpu(), f"./output{epoch}.png", nrow=num_rows)
    
    #val_loss = running_loss/len(valiloader.dataset)
    return running_loss/len(valiloader.dataset)

def train(model, trainloader, valiloader, optimizer, epochs):
  train_loss, val_loss = [], []

  for epoch in range(epochs):
      print(f"Epoch {epoch+1} of {epochs}")
      train_epoch_loss = fit(model, trainloader, optimizer)
      val_epoch_loss = validate(model, valiloader, epoch)
      train_loss.append(train_epoch_loss)
      val_loss.append(val_epoch_loss)
      print(f"Train Loss: {train_epoch_loss:.4f}")
      print(f"Val Loss: {val_epoch_loss:.4f}")

  return model, train_loss, val_loss


In [20]:
#MNIST
net = VAE(D = 28*28*1, M = 16)
# net = VAE(D = 32*32*3, M = 16)
net.to(device)
optimizer = optim.Adam(net.parameters(), lr = 0.005)
net, train_loss, val_loss= train(net, trainloader, valiloader, optimizer, 50 )

Epoch 1 of 50
Train Loss: -1.8039
Val Loss: -1.9352
Epoch 2 of 50
Train Loss: -1.9399
Val Loss: -1.9834
Epoch 3 of 50
Train Loss: -1.9812
Val Loss: -2.0166
Epoch 4 of 50
Train Loss: -2.0096
Val Loss: -2.0477
Epoch 5 of 50
Train Loss: -2.0332
Val Loss: -2.0590
Epoch 6 of 50
Train Loss: -2.0413
Val Loss: -2.0668
Epoch 7 of 50
Train Loss: -2.0460
Val Loss: -2.0696
Epoch 8 of 50
Train Loss: -2.0489
Val Loss: -2.0716
Epoch 9 of 50
Train Loss: -2.0507
Val Loss: -2.0747
Epoch 10 of 50
Train Loss: -2.0521
Val Loss: -2.0735
Epoch 11 of 50
Train Loss: -2.0533
Val Loss: -2.0751
Epoch 12 of 50
Train Loss: -2.0539
Val Loss: -2.0761
Epoch 13 of 50
Train Loss: -2.0550
Val Loss: -2.0773
Epoch 14 of 50
Train Loss: -2.0557
Val Loss: -2.0753
Epoch 15 of 50
Train Loss: -2.0561
Val Loss: -2.0779
Epoch 16 of 50
Train Loss: -2.0567
Val Loss: -2.0776
Epoch 17 of 50
Train Loss: -2.0571
Val Loss: -2.0786
Epoch 18 of 50
Train Loss: -2.0572
Val Loss: -2.0785
Epoch 19 of 50
Train Loss: -2.0576
Val Loss: -2.0792
Ep

In [23]:
#values to plot
val_loss

[-1.9352454650878905,
 -1.983359812927246,
 -2.016566542053223,
 -2.047689262390137,
 -2.058992478942871,
 -2.066827844238281,
 -2.069573812866211,
 -2.071612649536133,
 -2.074682974243164,
 -2.0734833557128907,
 -2.0750765380859373,
 -2.0760799865722657,
 -2.0772979888916017,
 -2.0753298553466797,
 -2.0778515838623046,
 -2.0776038970947264,
 -2.0786314392089844,
 -2.078510565185547,
 -2.0791779006958007,
 -2.0772095520019533,
 -2.0757425201416018,
 -2.081010192871094,
 -2.0829079223632814,
 -2.078243975830078,
 -2.079880960083008,
 -2.0813811431884766,
 -2.080195248413086,
 -2.0820124420166017,
 -2.0801575805664063,
 -2.0804015014648436,
 -2.0779665969848633,
 -2.0805175048828124,
 -2.080138983154297,
 -2.080396682739258,
 -2.0800753936767578,
 -2.0822245544433593,
 -2.080994616699219,
 -2.0809822021484377,
 -2.0804029541015625,
 -2.0799052307128907,
 -2.078240316772461,
 -2.0797628295898436,
 -2.0780477447509766,
 -2.08035498046875,
 -2.082277133178711,
 -2.079669839477539,
 -2.07900

# SVHN

In [None]:
BATCH_SIZE = 64

transform2 = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) 

#load in full train set
trainsetfull2 = torchvision.datasets.SVHN(root='./data/svhn', split='train', download=True, transform=transform2)
print(len(trainsetfull2))

# data loader for final run 
trainfullloader2 = torch.utils.data.DataLoader(trainsetfull2, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
#split the set 
trainset2, valset2 = torch.utils.data.random_split(trainsetfull2, [65000,8257])
#load in test set
testset2 = torchvision.datasets.SVHN(root='./data/svhn', split='test', transform=transform2,download=True)


# data loader for training
trainloader2 = torch.utils.data.DataLoader(trainset2, batch_size=BATCH_SIZE,shuffle=True, num_workers=2)
# data loader for validation
valiloader2 = torch.utils.data.DataLoader(valset2, batch_size=BATCH_SIZE,shuffle=True, num_workers=2)
# data loader for testing
testloader2 = torch.utils.data.DataLoader(testset2, batch_size=BATCH_SIZE,shuffle=False, num_workers=2)