In [1]:
# importing modules


import pickle
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

In [2]:
# hyperparameter settings


batch_size = 100
num_threads = 8
lat_dim = 2
lr = 1e-2
n_epoch = 1600

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [3]:
# dataloaders settings


def MNIST_loader():
    train_data = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
    test_data = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_threads)
    return train_loader, test_loader
  
def FreyFace_loader():
    raw_data = scipy.io.loadmat('data/Frey_Face/frey_rawface.mat')
    raw_data = raw_data["ff"].T.reshape((-1, 1, 28, 20))
    raw_data = raw_data.astype('float32')/255
    train_data = raw_data[:-100]
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
    test_data = raw_data[-100:]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
    return train_loader, test_loader


In [4]:
# several utils : mostly are for the visualizations


def plot_ELBO_curve(generated, lat_dim, train_list, test_list, ylim):
    fig = plt.figure()       
    plt.plot(list(range(1, 1600*60000, 60000)), train_list, 'r-', label='train ELBO')
    plt.plot(list(range(1, 1600*60000, 60000)), test_list, 'r--', label='test ELBO')
    plt.legend()
    plt.xscale('log')
    plt.xlim(10**5, 10**8)
    plt.ylim(ylim)
    plt.savefig('plots/'+generated+'/loss_curves_{}.png'.format(lat_dim))
    plt.close()

def plot_test_images(sample_hat, generated, lat_dim, height, width, n=10):
    sample_hat = sample_hat.view(n*n, 1, height, width)
    grid = save_image(sample_hat, fp="plots/"+generated+"/test_images_{}.png".format(lat_dim), nrow=n, padding=0)
    
def plot_latent_space(model, epoch, generated, height, width, intNumX1, intNumX2):
    z1s = np.linspace(0.0, 1.0, intNumX1, dtype=np.float32)
    z2s = np.linspace(0.0, 1.0, intNumX2, dtype=np.float32)
    dist = torch.distributions.normal.Normal(loc=0, scale=1)    
    canvas = np.zeros((height * intNumX1, width * intNumX2))

    for j, z2 in enumerate(z2s):
        for i, z1 in enumerate(z1s):
            z = torch.tensor([[z1, z2]]).to(device)
            z_hat = dist.icdf(z)
            x_hat = model(z_hat).detach().to(torch.device("cpu")).numpy()
            canvas[height*(intNumX2-j-1):height*(intNumX2-j), width*i:width*(i+1)] = x_hat.reshape(height, width)

    plt.figure(figsize=(8, 8))
    plt.imshow(canvas, cmap='gray')
    plt.tight_layout()
    plt.savefig("plots/"+generated+"/latent_spaces_%d.png" % (epoch), bbox_inches='tight')
    plt.close()


In [5]:
# The model class of Variational autoencoder


class VariationalAE(nn.Module):

    def __init__(self, generated, inp_dim, hid_dim, mode='learn'):
        super(VariationalAE, self).__init__()
        self.generated = generated
        self.mode = mode
        
        self.x_to_h = nn.Linear(inp_dim, hid_dim)
        self.h_to_z_mu = nn.Linear(hid_dim, lat_dim)
        self.h_to_z_log_var = nn.Linear(hid_dim, lat_dim)
        
        self.z_to_h = nn.Linear(lat_dim, hid_dim)
        self.h_to_x = nn.Linear(hid_dim, inp_dim)
        self.h_to_x_mu = nn.Linear(hid_dim, inp_dim)
        self.h_to_x_log_var = nn.Linear(hid_dim, inp_dim)
        
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def reparametrize(self, z_mu, z_log_var):
        z_std = torch.exp(0.5 * z_log_var)
        eps = torch.randn(z_std.size()).to(device)
        return z_mu + z_std * eps

    def encoderNet(self, x):
        h = self.tanh(self.x_to_h(x))
        z_mu = self.h_to_z_mu(h)
        z_log_var = self.h_to_z_log_var(h)
        z = self.reparametrize(z_mu, z_log_var)
        self.kl = -0.5 * ((1 + z_log_var) - z_mu * z_mu - torch.exp(z_log_var)).mean(dim=0).sum()
        return z

    def decoderNet(self, z):
        h_ = self.tanh(self.z_to_h(z))
        if self.generated == 'Bernoulli':
            x_p = self.sigmoid(self.h_to_x(h_))
            return x_p

        elif self.generated == 'Gaussian':
            x_mu = self.sigmoid(self.h_to_x_mu(h_))
            x_log_var = self.h_to_x_log_var(h_)
            x_std = torch.exp(0.5 * x_log_var)
            return x_mu, x_std
            
    def forward(self, x):
        if self.mode == 'learn':
            self.z = self.encoderNet(x)
            if self.generated == 'Bernoulli':
                x_p = self.decoderNet(self.z)
                self.recon = -(x * torch.log(x_p) + (1 - x) * torch.log(1 - x_p)).mean(dim=0).sum()
            elif self.generated == 'Gaussian':
                x_mu, x_std = self.decoderNet(self.z)
                self.dist = torch.distributions.normal.Normal(loc=x_mu, scale=x_std)
                self.recon = -self.dist.log_prob(x).mean(dim=0).sum()
            return self.recon, self.kl

        elif self.mode == 'generate':
            z = x
            if self.generated == 'Bernoulli':
                x_p = self.decoderNet(z)
                x_hat = x_p
            elif self.generated == 'Gaussian':
                x_mu, x_std = self.decoderNet(z)
                x_hat = x_mu
            return x_hat
        

In [6]:
# Function to train VAE with MNIST dataset


def MNIST_main(dist='Bernoulli'):
    train_loader, test_loader = MNIST_loader()
    model = VariationalAE(generated=dist, inp_dim=784, hid_dim=500).to(device)
    
    #for name, param in model.named_parameters():
    #    if param.requires_grad:
    #        print(name)
    
    train_list, test_list = [], []
    
    for epoch in range(n_epoch):
        #if (epoch+1) % 200 == 0:
        #    lr = lr * 0.1
        #    print('current learning rate is ', lr)
        optimizer = optim.Adagrad(model.parameters(), lr=lr)
        
        train_loss, recon_err, kl_div, test_loss = 0.0, 0.0, 0.0, 0.0
            
        cnt = 0
        for x, y in train_loader:
            cnt += 1
            model.train()
            inputs = x.to(device)
            new_batch_size = x.size()[0]
            inputs = inputs.view(new_batch_size, -1)
            
            recon, kl = model(inputs)
            loss = recon + kl          
 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss * new_batch_size / batch_size
            recon_err += recon * new_batch_size / batch_size
            kl_div += kl * new_batch_size / batch_size
        
        train_loss = train_loss / cnt
        recon_err = recon_err / cnt
        kl_div = kl_div / cnt
        train_list.append(-train_loss)
        
        with torch.no_grad():
            cnt = 0
            for x, y in test_loader:
                cnt += 1
                model.eval()
                inputs = x.to(device)
                new_batch_size = x.size()[0]
                inputs = inputs.view(new_batch_size, -1)
            
                recon, kl = model(inputs)
                loss = recon + kl
           
                test_loss += loss * new_batch_size / batch_size

            test_loss = test_loss / cnt
            test_list.append(-test_loss)
        
        if (epoch+1) % 1 == 0:
            print('[Epoch %d] train_loss: %.3f, recon: %.3f, kl_div: %.3f, test_loss : %3f'
                  % (epoch+1, train_loss, recon_err, kl_div, test_loss))
             
    torch.save(model.state_dict(), 'models/'+dist+'/VariationalAE_%d.pt' % (lat_dim))
    
    #with open('loss/'+dist+'/train_ELBO_%d.txt' % (lat_dim), 'wb') as f:
    #    pickle.dump(train_list, f)
    #with open('loss/'+dist+'/test_ELBO_%d.txt' % (lat_dim), 'wb') as f:
    #    pickle.dump(test_list, f)
    
    plot_ELBO_curve(model.generated, lat_dim, train_list, test_list, (-150, -90))

    if lat_dim == 2:
        model.mode = 'generate'
        plot_latent_space(model, epoch+1, model.generated, height=28, width=28, intNumX1=20, intNumX2=20)


In [7]:
# function to train VAE with FreyFace dataset


def FreyFace_main(dist='Gaussian'):
    train_loader, test_loader = FreyFace_loader()
    model = VariationalAE(generated=dist, inp_dim=560, hid_dim=200).to(device)
    
    #for name, param in model.named_parameters():
    #    if param.requires_grad:
    #        print(name)
    
    train_list, test_list = [], []
    
    for epoch in range(n_epoch):
        #if (epoch+1) % 200 == 0:
        #    lr = lr * 0.1
        #    print('current learning rate is ', lr)
        optimizer = optim.Adagrad(model.parameters(), lr=lr)
        
        train_loss, recon_err, kl_div, test_loss = 0.0, 0.0, 0.0, 0.0
        cnt = 0
        for x in train_loader:
            cnt += 1
            inputs = x.to(device)
            new_batch_size = x.size()[0]
            inputs = inputs.view(new_batch_size, -1)
            
            recon, kl = model(inputs)
            loss = recon + kl
            
            train_loss += loss * new_batch_size / batch_size
            recon_err += recon * new_batch_size / batch_size
            kl_div += kl * new_batch_size / batch_size
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
      
        train_loss = train_loss / cnt
        recon_err = recon_err / cnt
        kl_div = kl_div / cnt
        
        train_list.append(-train_loss)

        with torch.no_grad():
            cnt = 0
            for x in test_loader:
                cnt += 1
                model.eval()
                inputs = x.to(device)
                new_batch_size = x.size()[0]
                inputs = inputs.view(new_batch_size, -1)
            
                recon, kl = model(inputs)
                loss = recon + kl
           
                test_loss += loss * new_batch_size / batch_size

            test_loss = test_loss / cnt
            test_list.append(-test_loss)
        
        if (epoch+1) % 1 == 0:
            print('[Epoch %d] train_loss: %.3f, recon_err: %.3f, kl_div: %.3f, test_loss: %.3f'
                  % (epoch+1, train_loss, recon_err, kl_div, test_loss))
                
    torch.save(model.state_dict(), 'models/'+dist+'/VariationalAE%d.pt' % (lat_dim))
    
    with open('loss/'+dist+'/train_ELBO_%d.txt' % (lat_dim), 'wb') as f:
        pickle.dump(train_list, f)
    with open('loss/'+dist+'/train_ELBO_%d.txt' % (lat_dim), 'wb') as f:
        pickle.dump(test_list, f)
    
    plot_ELBO_curve(model.generated, lat_dim, train_list, test_list, (0, 1600))
    
    if lat_dim == 2:
        model.mode = 'generate'
        plot_latent_space(model, epoch+1, model.generated, height=28, width=20, intNumX1=10, intNumX2=10)



In [8]:
#Training and visualization can be done with customized setting of learning/hyper parameters

# 1st experiment : plotting the ELBO curve of Variational Autooencoder in case of MNIST and FreyFace
#                  ( this is accomplished by plot_ELBO_curve in utils section used in main sections.
# 2nd experiment : drawing a latent space manifold 
#                  ( this is accomplished by plot_latent_space functions in utils section used in main sections.)

if __name__ == '__main__': 
    lr = 1e-3
    lat_dim = 2
    print(device)
    n_epoch = 10


    MNIST_main()
    FreyFace_main()


cuda:1
[Epoch 1] train_loss: 210.791, recon: 203.383, kl_div: 7.408, test_loss : 195.003250
[Epoch 2] train_loss: 187.474, recon: 182.567, kl_div: 4.907, test_loss : 184.667892
[Epoch 3] train_loss: 182.348, recon: 177.762, kl_div: 4.586, test_loss : 181.061768
[Epoch 4] train_loss: 180.256, recon: 175.720, kl_div: 4.536, test_loss : 179.367874
[Epoch 5] train_loss: 178.872, recon: 174.313, kl_div: 4.559, test_loss : 178.071106
[Epoch 6] train_loss: 177.651, recon: 173.027, kl_div: 4.623, test_loss : 176.859802
[Epoch 7] train_loss: 176.500, recon: 171.785, kl_div: 4.715, test_loss : 175.743958
[Epoch 8] train_loss: 175.391, recon: 170.639, kl_div: 4.753, test_loss : 174.720764
[Epoch 9] train_loss: 174.335, recon: 169.521, kl_div: 4.814, test_loss : 173.638458
[Epoch 10] train_loss: 173.373, recon: 168.491, kl_div: 4.882, test_loss : 172.718323
