In [1]:
#importing modules


import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision.utils import make_grid, save_image

In [2]:
# setting hyperparameters


img_size = 64
batch_size = 100
num_threads = 8
lat_dim = 10

beta = 4

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [10]:
# util functions mostly for visualization purpose


def param_fix(layer):
    for param in layer.parameters():
        param.requires_grad = False

def plot_ELBO_curve(generated, train_list, test_list):
    fig = plt.figure()       
    plt.plot(train_list, 'r-', label='train ELBO')
    plt.plot(test_list, 'r--', label='test ELBO')
    plt.legend()
    plt.savefig('plots/'+generated+'/ELBO_curves.png')
    plt.close()

def plot_test_images(sample_hat, generated, n, height=img_size, width=img_size):
    sample_hat = sample_hat.view(-1, 1, height, width)
    grid = save_image(sample_hat, filename="plots/"+generated+"/test_images.png", nrow=n, padding=0)


In [4]:
# Constructing custom dataset for 2D shapes dataset also called 'Dsprites'


class ShapesDataset(Dataset):
    def __init__(self, mode, file='data/dsprites/dsprites.npz'):
        dataset = np.load(file, allow_pickle=True, encoding='latin1')
        x = torch.Tensor(dataset['imgs'])
        y = torch.Tensor(dataset['latents_classes'])
        indices = torch.randperm(len(x))
        
        self.mode = mode
        if self.mode == 'train':
            self.x = x[indices][:-5000]
            self.y = y[indices][:-5000]
        elif self.mode == 'test':
            self.x = x[indices][-5000:]
            self.y = y[indices][-5000:]
                    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    
    def __len__(self):
        return len(self.x)

In [5]:
# data loader for 2D shapes dataset


def shapes_loader():
    train_data = ShapesDataset(mode='train')
    test_data = ShapesDataset(mode='test')
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
    return train_loader, test_loader

In [6]:
#contructing model class for BetaVAE


class BetaVAE(nn.Module):
    def __init__(self, supervised, generated='Bernoulli', mode='learn'):
        super(BetaVAE, self).__init__()
        self.supervised = supervised
        self.generated = generated
        self.mode = mode
        self.sigmoid = nn.Sigmoid()
        self.encoder = nn.Sequential(nn.Linear(4096, 1200),
                                     nn.ReLU(),
                                     nn.Linear(1200, 1200),
                                     nn.ReLU(),
                                     nn.Linear(1200, lat_dim*2))
                                    
        self.decoder = nn.Sequential(nn.Linear(lat_dim, 1200),
                                     nn.Tanh(),
                                     nn.Linear(1200, 1200),
                                     nn.Tanh(),
                                     nn.Linear(1200, 1200),
                                     nn.Tanh(),
                                     nn.Linear(1200, 4096))
        
        self.layer = nn.Sequential(nn.Linear(lat_dim, 5),
                                   nn.Softmax())
        
    def reparametrize(self, z_mu, z_log_var):
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn(std.size()).to(device)
        return z_mu + std * eps

    def encoderNet(self, x):
        code = self.encoder(x)
        z_mu = code[:, :lat_dim]
        z_log_var = code[:, lat_dim:]
        z = self.reparametrize(z_mu, z_log_var)
        self.kl = -0.5 * ((1 + z_log_var) - z_mu * z_mu - torch.exp(z_log_var)).mean(dim=0).sum()
        return z

    def decoderNet(self, z):
        h = self.decoder(z)
        x_ = self.sigmoid(h)
        return x_
        
    def forward(self, x):
        if self.mode == 'learn':
            self.z = self.encoderNet(x)
            if not self.supervised:
                self.x_ = self.decoderNet(self.z)
                self.recon = -(x * torch.log(self.x_ + 1e-10) + (1 - x) * torch.log(1 - self.x_ + 1e-10)).mean(dim=0).sum()
                return self.x_, self.recon, self.kl
            else:
                self.factors = self.layer(self.z)
                return self.factors
        elif self.mode == 'generate':
            self.x_ = self.decoderNet(x)
            return self.x_

In [7]:
# main function for training BetaVAE

def Shapes_train(dist='Bernoulli'):
    train_loader, test_loader = shapes_loader()
    model = BetaVAE(supervised=False, generated=dist).to(device)
    param_fix(model.layer)
    model.load_state_dict(torch.load('models/'+dist+'/BetaVAE.pt')) 
    model.mode = 'learn'
    #for name, param in model.named_parameters():
    #    if param.requires_grad:
    #        print(name)
    
    train_list, test_list = [], []
    
    for epoch in range(n_epoch):
        optimizer = optim.Adagrad(model.parameters(), lr=lr)
        
        train_loss, recon_err, kl_div, test_loss = 0.0, 0.0, 0.0, 0.0
        cnt = 0
        for x, y in train_loader:
            cnt += 1
            inputs = x.to(device)
            new_batch_size = x.size()[0]
            inputs = inputs.view(new_batch_size, img_size*img_size)
            
            _, recon, kl = model(inputs)
            loss = recon + beta * kl
            
            train_loss += loss * new_batch_size / batch_size
            recon_err += recon * new_batch_size / batch_size
            kl_div += kl * new_batch_size / batch_size
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
      
        train_loss = train_loss / cnt
        recon_err = recon_err / cnt
        kl_div = kl_div / cnt
        
        train_list.append(-train_loss)

        with torch.no_grad():
            cnt = 0
            for x, y in test_loader:
                cnt += 1
                model.eval()
                inputs = x.to(device)
                new_batch_size = x.size()[0]
                inputs = inputs.view(new_batch_size, img_size*img_size)
            
                _, recon, kl = model(inputs)
                loss = recon + beta * kl
           
                test_loss += loss * new_batch_size / batch_size

            test_loss = test_loss / cnt
            test_list.append(-test_loss)
        
        torch.save(model.state_dict(), 'models/'+dist+'/BetaVAE.pt')
      
        if (epoch+1) % 1 == 0:
            print('[Epoch %d] train_loss: %.3f, recon_err: %.3f, kl_div: %.3f, test_loss: %.3f'
                  % (epoch+1, train_loss, recon_err, kl_div, test_loss))
                
    #torch.save(inputs[:5], 'data/dsprites/samples.pt')
    
    sample = inputs
    sample_hat, _, _ = model(sample)
            
    plot_test_images(sample_hat.detach(), model.generated, n=10)
    
    plot_ELBO_curve(model.generated, train_list, test_list)
    

In [None]:
# running code for training BetaVAE with 2D shapes dataset

if __name__ == '__main__': 
    print(device)
    lr = 1e-3
    n_epoch = 100
    Shapes_train()

cuda:1
[Epoch 1] train_loss: 112.263, recon_err: 66.804, kl_div: 11.365, test_loss: 111.798
[Epoch 2] train_loss: 112.151, recon_err: 66.679, kl_div: 11.368, test_loss: 111.959
[Epoch 3] train_loss: 112.065, recon_err: 66.568, kl_div: 11.374, test_loss: 111.679
[Epoch 4] train_loss: 111.980, recon_err: 66.465, kl_div: 11.379, test_loss: 111.752
[Epoch 5] train_loss: 111.905, recon_err: 66.378, kl_div: 11.382, test_loss: 111.727
[Epoch 6] train_loss: 111.802, recon_err: 66.265, kl_div: 11.384, test_loss: 111.412
[Epoch 7] train_loss: 111.748, recon_err: 66.171, kl_div: 11.394, test_loss: 111.551
[Epoch 8] train_loss: 111.645, recon_err: 66.081, kl_div: 11.391, test_loss: 111.723
[Epoch 9] train_loss: 111.591, recon_err: 65.989, kl_div: 11.400, test_loss: 111.530
