In [None]:
import numpy as np
import torch
from torch import distributions
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T
from torchvision.utils import make_grid, save_image
import time
from PIL import Image
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
%matplotlib inline

%pip install pytorch-ignite
%pip install --pre pytorch-ignite
%pip install torchsummary

from ignite.metrics import FID
from torchsummary import summary

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs the notebook uses (weight init, random_split, DataLoader shuffling,
# prior sampling) so a Restart & Run All gives the same splits and samples.
# torch.manual_seed seeds all devices; GPU kernels may still be non-deterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Converts a (C, H, W) float tensor in [0, 1] into a PIL image for saving sample grids.
to_pil_image = T.ToPILImage()

# Definition of VAE with Gaussian Prior and Flow-Based Prior (RealNVP)

In [None]:
class Encoder(nn.Module):
    """MLP encoder for a diagonal-Gaussian approximate posterior q(z | x).

    The last linear layer emits 2 * latent_size values, split into the mean and the
    log-variance; `forward` returns a reparameterised sample together with them.
    """

    def __init__(self, input_size, hidden_size=300, latent_size=100):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, 2 * latent_size),
        ]
        self.encoder = nn.Sequential(*layers)

    def sample(self, mu, sigma):
        """Reparameterisation trick: return mu + sigma * eps with eps ~ N(0, I)."""
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Encode `x`; return (z, mu, sigma) where z is a sample from q(z | x)."""
        mean, log_variance = self.encoder(x).chunk(2, dim=1)
        # Standard deviation from the log-variance; exp keeps it strictly positive.
        std = (0.5 * log_variance).exp()
        return self.sample(mean, std), mean, std

class Decoder(nn.Module):
    """MLP decoder mapping a latent vector to `output_size` values in (0, 1).

    The Sigmoid output is paired with the BCE reconstruction term of the loss.
    """

    def __init__(self, output_size, hidden_size=300, latent_size=100):
        super().__init__()
        layers = [
            nn.Linear(latent_size, hidden_size), nn.LeakyReLU(),
            nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(),
            nn.Linear(hidden_size, output_size), nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a batch of latents `x` into flattened outputs."""
        decoded = self.decoder(x)
        return decoded

class VAE(nn.Module):
    """Variational autoencoder built from an `Encoder` and a `Decoder`.

    `hidden_size` and `latent_size` are forwarded to both halves. Their defaults
    (300 / 100) equal the Encoder/Decoder defaults, so `VAE(input_size)` builds the
    same network as before. `latent_size` must match the prior's dimensionality.
    """

    def __init__(self, input_size, hidden_size=300, latent_size=100):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_size, hidden_size=hidden_size, latent_size=latent_size)
        self.decoder = Decoder(input_size, hidden_size=hidden_size, latent_size=latent_size)

    def forward(self, x):
        """Return (reconstruction, z, mu, sigma) for a batch `x`."""
        z, mu, sigma = self.encoder(x)

        return self.decoder(z), z, mu, sigma
    
class GaussianPrior(nn.Module):
    """Standard normal prior N(0, I) over a `latent_size`-dimensional latent space.

    `log_prob` returns per-dimension log-densities with the same shape as `z`;
    summing over the last dimension gives log p(z).
    """

    # log(2*pi) as a Python float. Combining it with `z` keeps z's dtype and device;
    # the previous version built a tensor on the global `device` on every call, which
    # failed when `z` lived elsewhere (or when `device` was not defined).
    _LOG_2PI = float(np.log(2.0 * np.pi))

    def __init__(self, latent_size=100):
        super(GaussianPrior, self).__init__()
        self.latent_size = latent_size

    def sample(self, batch_size, device=None):
        """Draw a (batch_size, latent_size) sample.

        `device=None` keeps PyTorch's default device (the previous behaviour); pass the
        model's device to sample directly where the decoder lives.
        """
        z = torch.randn((batch_size, self.latent_size), device=device)
        return z

    def log_prob(self, z):
        """Elementwise log N(z; 0, 1); the result has the same shape as `z`."""
        return -0.5 * self._LOG_2PI - 0.5 * z ** 2.

class FlowPrior(nn.Module):
    """RealNVP-style flow prior: log p(x) = log N(f(x); 0, I) + log|det df/dx|.

    Parameters
    ----------
    nets, nett : zero-argument callables returning fresh modules for the scale (s) and
        translation (t) networks of each coupling layer. Each is applied to the first
        half of the features and must return a tensor the size of the second half
        (assumes D is even so `torch.chunk` splits evenly — TODO confirm for odd D).
    num_flows : number of affine coupling layers.
    D : dimensionality of the latent space.
    """

    _LOG_2PI = float(np.log(2.0 * np.pi))

    def __init__(self, nets, nett, num_flows, D=2):
        super(FlowPrior, self).__init__()

        self.D = D

        self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
        self.s = torch.nn.ModuleList([nets() for _ in range(num_flows)])
        self.num_flows = num_flows

    def coupling(self, x, index, forward=True):
        """Apply affine coupling layer `index`.

        The first half `xa` passes through unchanged and conditions s and t.
        forward=True  (data -> noise): yb = (xb - t) * exp(-s)
        forward=False (noise -> data): yb = exp(s) * xb + t   (exact inverse)
        Returns the concatenated result and s (the per-dimension log-scale).
        """
        (xa, xb) = torch.chunk(x, 2, 1)

        s = self.s[index](xa)
        t = self.t[index](xa)

        if forward:
            yb = (xb - t) * torch.exp(-s)
        else:
            yb = torch.exp(s) * xb + t

        return torch.cat((xa, yb), 1), s

    def permute(self, x):
        """Reverse the feature order so the next coupling transforms the other half.

        Flipping is its own inverse, so `f_inv` reuses it unchanged.
        """
        return x.flip(1)

    def f(self, x):
        """Map data `x` to noise `z`; return (z, log|det dz/dx|) with log-det shape (B,)."""
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in range(self.num_flows):
            z, s = self.coupling(z, i, forward=True)
            z = self.permute(z)
            # The forward coupling scales xb by exp(-s): contributes -sum(s).
            log_det_J = log_det_J - s.sum(dim=1)

        return z, log_det_J

    def f_inv(self, z):
        """Invert `f`: undo the layers in reverse order (permute, then inverse coupling)."""
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x, _ = self.coupling(x, i, forward=False)

        return x

    def sample(self, batch_size, device=None):
        """Draw `batch_size` samples of shape (batch_size, D) from the prior.

        By default the base noise is drawn on the device of the flow's parameters, so
        sampling works after the prior has been moved to the GPU.
        """
        if device is None:
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else None
        z = torch.randn(batch_size, self.D, device=device)
        x = self.f_inv(z)
        return x.view(-1, self.D)

    def log_prob(self, x):
        """Per-dimension log-density terms of `x`, shape (B, D).

        Summing over the last dimension gives the joint
        log p(x) = sum_d log N(z_d; 0, 1) + log|det J|. The log-determinant is one
        scalar per sample, so it is spread evenly over the D columns; this keeps the
        per-dimension return shape that GaussianPrior.log_prob also has. (Previously
        the result was negated and log|det J| was added to every column, i.e.
        counted D times after summation.)
        """
        z, log_det_J = self.f(x)

        log_standard_normal = -0.5 * self._LOG_2PI - 0.5 * z ** 2.

        return log_standard_normal + log_det_J.unsqueeze(1) / z.shape[1]
    
    
class ELBO():
    """Negative-ELBO loss for a VAE with a pluggable latent prior.

    Despite the name, calling it returns the quantity to *minimise*:
    mean over the batch of (BCE reconstruction error + single-sample KL estimate).

    `prior.log_prob(z)` may return either per-dimension terms with the same shape as
    `z` (summed over the last dim here) or one joint log-density per sample (B,).
    """

    def __init__(self, prior):
        self.prior = prior
        self.reconstruction_error = nn.BCELoss(reduction='none')

    def kullback_Leibler_divergence(self, z, mu, sigma):
        """Single-sample Monte-Carlo estimate of KL(q(z|x) || p(z)), shape (B,).

        Uses log q(z) - log p(z) evaluated at the sampled `z`.
        """
        q = torch.distributions.Normal(mu, sigma)

        log_qz = q.log_prob(z).sum(-1)
        log_pz = self.prior.log_prob(z)
        if log_pz.dim() == z.dim():
            # Per-dimension prior terms: reduce to one log-density per sample.
            log_pz = log_pz.sum(-1)

        return log_qz - log_pz

    def __call__(self, inputs, outputs, z, mu, sigma):
        """Return the batch-mean negative ELBO (scalar tensor)."""
        re = self.reconstruction_error(outputs, inputs).sum(-1)
        kl = self.kullback_Leibler_divergence(z, mu, sigma)

        elbo = (re + kl)
        return elbo.mean()

In [None]:
def train(net, prior, train_data, val_data, img_dim, batch_size=10, learning_rate=0.0001, epochs=20, nr_test_samples=64, img_dir='None', img_size=32):
    """Train a VAE together with its latent prior by minimising the negative ELBO.

    After every epoch a grid of `nr_test_samples` images decoded from prior samples
    is saved to `{img_dir}/epoch_{epoch}.jpeg`, and that epoch's mean train and
    validation losses are printed.

    Parameters
    ----------
    net : VAE on `device`; `net(x)` returns (outputs, z, mu, sigma) and `net.decoder`
        maps latents to flattened images.
    prior : latent prior with `sample(n)` and `log_prob(z)`. It is moved to `device`
        in place, and any trainable parameters it has (e.g. a flow prior) are
        optimised together with `net`.
    train_data, val_data : datasets yielding (flattened image, label) pairs.
    img_dim : number of image channels, used to reshape decoder output for saving.
    batch_size, learning_rate, epochs : Adam optimisation settings.
    nr_test_samples : number of prior samples decoded into the per-epoch grid.
    img_dir : output directory for the sample grids; created if it does not exist.
    img_size : image height/width in pixels.

    Returns
    -------
    (running_loss, running_val_loss) : lists of Python floats, one entry per batch,
        concatenated over all epochs.
    """
    # nn.Module.to works in place; the prior must live on the same device as the VAE.
    prior.to(device)
    # Optimise the prior too: otherwise a flow prior stays at its random init.
    # (GaussianPrior has no parameters, so nothing changes for it.)
    params = list(net.parameters()) + list(prior.parameters())
    optimizer = optim.Adam(params, lr=learning_rate)
    criterion = ELBO(prior)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Order is irrelevant for an averaged validation loss.
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
    os.makedirs(img_dir, exist_ok=True)

    running_loss = []
    running_val_loss = []

    for epoch in range(epochs):
        print("Epoch:" , epoch+1)
        net.train()
        prior.train()

        for inputs, _ in tqdm(train_loader, total=len(train_loader)):
            inputs = inputs.to(device)
            optimizer.zero_grad()

            outputs, z, mu, sigma = net(inputs)
            loss = criterion(inputs, outputs, z, mu, sigma)
            loss.backward()
            optimizer.step()

            running_loss.append(loss.item())

        net.eval()
        prior.eval()
        with torch.no_grad():
            # Decode prior samples on the model's device, then move to CPU for PIL.
            sample = prior.sample(nr_test_samples).to(device)
            generated_img = net.decoder(sample).view(nr_test_samples, img_dim, img_size, img_size)
            to_pil_image(make_grid(generated_img).cpu()).save(f"{img_dir}/epoch_{epoch}.jpeg")

            for inputs, _ in val_loader:
                inputs = inputs.to(device)
                outputs, z, mu, sigma = net(inputs)
                loss = criterion(inputs, outputs, z, mu, sigma)
                # Store a float: tensors (possibly on the GPU) break np.mean later.
                running_val_loss.append(loss.item())

        # The lists hold one entry per *batch*; average only this epoch's batches.
        epoch_train_loss = np.mean(running_loss[len(running_loss) - len(train_loader):])
        epoch_val_loss = np.mean(running_val_loss[len(running_val_loss) - len(val_loader):])
        print(f'Train Loss: {epoch_train_loss} | Validation Loss: {epoch_val_loss}')
    return running_loss, running_val_loss

In [None]:
def plot_interpolated(net, img_dim, n=10, latent_prior=None, img_size=32):
    """Show an n x n grid of images decoded from a plane in latent space.

    Three latents s, e1, e2 are drawn from the prior. The cell in row i, column j
    shows decode(s + (e1 - s) * x_j + (e2 - s) * y_i) with x_j, y_i taken from
    linspace(0, 1, n); row i = 0 is drawn at the bottom.

    Parameters
    ----------
    net : VAE whose `decoder` maps latents to img_dim * img_size * img_size values.
    img_dim : number of image channels (1 is shown in grayscale).
    n : grid resolution along each axis.
    latent_prior : prior to draw the anchor latents from; defaults to the
        notebook-global `prior` (the previous behaviour).
    img_size : image height/width in pixels.
    """
    if latent_prior is None:
        latent_prior = prior
    w = img_size

    with torch.no_grad():
        s, e1, e2 = latent_prior.sample(3).to(device)
        coords = torch.linspace(0, 1, n, device=device)
        ys = coords.view(n, 1, 1)
        xs = coords.view(1, n, 1)
        # (n, n, latent) grid with z[i, j] at (y_i, x_j); decoded in a single batch.
        z = s + (e1 - s) * xs + (e2 - s) * ys
        images = net.decoder(z.reshape(n * n, -1)).view(n, n, img_dim, w, w)

    # Flip rows so y grows upward, then tile (row, h, col, w, C) -> (n*w, n*w, C).
    grid = images.flip(0).permute(0, 3, 1, 4, 2).reshape(n * w, n * w, img_dim)
    img = grid.cpu().numpy()
    if img_dim == 1:
        # imshow rejects (H, W, 1) arrays; draw single-channel images in grayscale.
        plt.imshow(img[..., 0], cmap='gray', vmin=0, vmax=1)
    else:
        plt.imshow(img)

## Standard VAE on MNIST

In [None]:
# MNIST digits upscaled to 32x32 (to match SVHN) and flattened to 1024-dim vectors.
mnist_transform = T.Compose([
    T.Resize(32),
    T.ToTensor(),
    T.Lambda(torch.flatten),
])
mnist_train = datasets.MNIST(root='data', train=True, download=True, transform=mnist_transform)

In [None]:
def count_parameters(model):
    """Return the total number of scalar parameters in `model` that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Baseline: VAE with a standard Gaussian prior on 32x32 MNIST (1 channel).
input_size = 32 * 32
batch_size = 64

net1 = VAE(input_size).to(device)
prior = GaussianPrior()
print("Number of trainable parameters in VAE:", count_parameters(net1))
print("Number of trainable parameters in Prior:", count_parameters(prior))

train_set, val_set = torch.utils.data.random_split(mnist_train, lengths=[55000, 5000])
train_loss, val_loss = train(
    net1, prior, train_set, val_set, 1,
    batch_size=batch_size, epochs=50, img_dir='VAE/MNIST',
)

In [None]:
# train() logs one loss per batch; average into one value per epoch. Batches per
# epoch come from the actual split sizes (ceil for the partial last batch), and the
# raw per-batch lists are left untouched so re-running this cell is safe.
train_batches = int(np.ceil(len(train_set) / batch_size))
val_batches = int(np.ceil(len(val_set) / batch_size))
train_loss_epoch = np.array([float(v) for v in train_loss]).reshape(-1, train_batches).mean(axis=1)
val_loss_epoch = np.array([float(v) for v in val_loss]).reshape(-1, val_batches).mean(axis=1)

fig, ax = plt.subplots()
ax.plot(train_loss_epoch, label='Training')
ax.plot(val_loss_epoch, label='Validation')
ax.set(xlabel='Epoch', ylabel='Loss', title='Loss for VAE on MNIST')
ax.legend()
plt.show()

In [None]:
# Latent-plane interpolation for the Gaussian-prior VAE on MNIST.
plot_interpolated(net1, img_dim=1)

## Standard VAE on SVHN

In [None]:
# SVHN training split: 32x32 RGB images flattened to 3072-dim vectors.
svhn_transform = T.Compose([
    T.ToTensor(),
    T.Lambda(torch.flatten),
])
svhn_train = datasets.SVHN(root='data', download=True, transform=svhn_transform)

In [None]:
# Baseline: VAE with a standard Gaussian prior on SVHN (3 channels).
input_size = 32 * 32 * 3

net2 = VAE(input_size).to(device)
prior = GaussianPrior()
print("Number of trainable parameters in VAE:", count_parameters(net2))
print("Number of trainable parameters in Prior:", count_parameters(prior))

train_set, val_set = torch.utils.data.random_split(svhn_train, lengths=[65000, 8257])
train_loss, val_loss = train(
    net2, prior, train_set, val_set, 3,
    batch_size=batch_size, epochs=50, img_dir='VAE/SVHN',
)

In [None]:
# train() logs one loss per batch; average into one value per epoch. Batches per
# epoch come from the actual split sizes (ceil for the partial last batch), and the
# raw per-batch lists are left untouched so re-running this cell is safe.
train_batches = int(np.ceil(len(train_set) / batch_size))
val_batches = int(np.ceil(len(val_set) / batch_size))
train_loss_epoch = np.array([float(v) for v in train_loss]).reshape(-1, train_batches).mean(axis=1)
val_loss_epoch = np.array([float(v) for v in val_loss]).reshape(-1, val_batches).mean(axis=1)

fig, ax = plt.subplots()
ax.plot(train_loss_epoch, label='Training')
ax.plot(val_loss_epoch, label='Validation')
ax.set(xlabel='Epoch', ylabel='Loss', title='Loss for VAE on SVHN')
ax.legend()
plt.show()

In [None]:
# Latent-plane interpolation for the Gaussian-prior VAE on SVHN.
plot_interpolated(net2, img_dim=3)

## VAE with RealNVP prior on MNIST

In [None]:
# Flow prior: `num_flows` RealNVP coupling layers over an L-dimensional latent space.
# L must equal the VAE latent size (100); M is the hidden width of the s/t networks.
num_flows = 3
L = 100
M = 300

# Scale network: the final Tanh bounds the per-dimension log-scales to (-1, 1),
# presumably to keep exp(s) well-behaved during training.
nets = lambda: nn.Sequential(nn.Linear(L // 2, M), nn.LeakyReLU(),
                            nn.Linear(M, M), nn.LeakyReLU(),
                            nn.Linear(M, L // 2), nn.Tanh())

# Translation network: unbounded output.
nett = lambda: nn.Sequential(nn.Linear(L // 2, M), nn.LeakyReLU(),
                            nn.Linear(M, M), nn.LeakyReLU(),
                            nn.Linear(M, L // 2))

# The class defined above is `FlowPrior`; `RealNVP` does not exist and raised NameError.
# Put the flow on the same device as the VAE it will be trained with.
prior = FlowPrior(nets, nett, num_flows=num_flows, D=L).to(device)

batch_size = 64

In [None]:
# VAE with the RealNVP flow prior on 32x32 MNIST (1 channel).
input_size = 32 * 32

net3 = VAE(input_size).to(device)
print("Number of trainable parameters in VAE:", count_parameters(net3))
print("Number of trainable parameters in Prior:", count_parameters(prior))

train_set, val_set = torch.utils.data.random_split(mnist_train, lengths=[55000, 5000])
train_loss, val_loss = train(
    net3, prior, train_set, val_set, 1,
    batch_size=batch_size, epochs=50, img_dir='RealNVP/MNIST',
)

In [None]:
# train() logs one loss per batch; average into one value per epoch. Batches per
# epoch come from the actual split sizes (ceil for the partial last batch), and the
# raw per-batch lists are left untouched so re-running this cell is safe.
train_batches = int(np.ceil(len(train_set) / batch_size))
val_batches = int(np.ceil(len(val_set) / batch_size))
train_loss_epoch = np.array([float(v) for v in train_loss]).reshape(-1, train_batches).mean(axis=1)
val_loss_epoch = np.array([float(v) for v in val_loss]).reshape(-1, val_batches).mean(axis=1)

fig, ax = plt.subplots()
ax.plot(train_loss_epoch, label='Training')
ax.plot(val_loss_epoch, label='Validation')
ax.set(xlabel='Epoch', ylabel='Loss', title='Loss for RealNVP on MNIST')
ax.legend()
plt.show()

In [None]:
# Latent-plane interpolation for the flow-prior VAE on MNIST.
plot_interpolated(net3, img_dim=1)

## VAE with RealNVP prior on SVHN

In [None]:
input_size = 32*32*3

net4 = VAE(input_size).to(device)
# Build a fresh flow prior instead of reusing the instance from the MNIST run, so
# the SVHN experiment does not start from a prior fitted alongside net3.
prior = FlowPrior(nets, nett, num_flows=num_flows, D=L).to(device)
print("Number of trainable parameters in VAE:", count_parameters(net4))
print("Number of trainable parameters in Prior:", count_parameters(prior))
train_set_svhn, val_set_svhn = torch.utils.data.random_split(svhn_train, [65000, 8257])
train_loss_svhn, val_loss_svhn = train(net4, prior, train_set_svhn, val_set_svhn, 3, batch_size=batch_size, epochs=50, img_dir='RealNVP/SVHN')

In [None]:
# train() logs one loss per batch; average into one value per epoch. Batches per
# epoch are derived from the actual SVHN splits; the previously hard-coded
# 60000 / 13257 did not match the 65000 / 8257 split, so reshape raised an error.
train_batches_svhn = int(np.ceil(len(train_set_svhn) / batch_size))
val_batches_svhn = int(np.ceil(len(val_set_svhn) / batch_size))
train_loss_epoch = np.array([float(v) for v in train_loss_svhn]).reshape(-1, train_batches_svhn).mean(axis=1)
val_loss_epoch = np.array([float(v) for v in val_loss_svhn]).reshape(-1, val_batches_svhn).mean(axis=1)

fig, ax = plt.subplots()
ax.plot(train_loss_epoch, label='Training')
ax.plot(val_loss_epoch, label='Validation')
ax.set(xlabel='Epoch', ylabel='Loss', title='Loss for RealNVP on SVHN')
ax.legend()
plt.show()

In [None]:
# Latent-plane interpolation for the flow-prior VAE on SVHN.
plot_interpolated(net4, img_dim=3)