In [1]:

from torchvision.datasets import utils
import torch.utils.data as data_utils
import torch
import os
import numpy as np
from torch import nn
from torch.nn.modules import upsampling
from torch.functional import F
from torch.optim import Adam
from torch.autograd import Variable
#import tensorflow as tf

def get_data_loader(dataset_location, batch_size):
    URL = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
    # start processing
    def lines_to_np_array(lines):
        return np.array([[int(i) for i in line.split()] for line in lines])
    splitdata = []
    for splitname in ["train", "valid", "test"]:
        filename = "binarized_mnist_%s.amat" % splitname
        filepath = os.path.join(dataset_location, filename)
        utils.download_url(URL + filename, dataset_location, filename = filename, md5=None)
        with open(filepath) as f:
            lines = f.readlines()
        x = lines_to_np_array(lines).astype('float32')
        x = x.reshape(x.shape[0], 1, 28, 28)
        # pytorch data loader
        dataset = data_utils.TensorDataset(torch.from_numpy(x))
        dataset_loader = data_utils.DataLoader(x, batch_size=batch_size, shuffle=splitname == "train", pin_memory=True)
        splitdata.append(dataset_loader)
    return splitdata

epochs = 20
bs = 128

train, valid, test = get_data_loader("binarized_mnist", bs)

cuda = True if torch.cuda.is_available() else False

import matplotlib
import matplotlib.pyplot as plt
# %matplotlib inline
for x in train:
    plt.imshow(x[0, 0])
    break

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class UnFlatten(nn.Module):
    def forward(self, input, size=256):
        return input.view(input.size(0), size, 1, 1)

class VAE(nn.Module):
  
    def __init__(self, image_channels=1, h_dim=256, z_dim=100):
        super(VAE, self).__init__()
        
        # Q(z|X) -- encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=3),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 256, kernel_size=5),
            nn.ELU(),
            Flatten()
        )
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        # P(X|z) -- decoder
        self.decoder = nn.Sequential(
            
            UnFlatten(),
            nn.ELU(),
            nn.Conv2d(256, 64, kernel_size= 5, padding= 4),
            nn.ELU(),
            nn.Upsample(scale_factor=2, mode = 'bilinear', align_corners=True), #, align_corners=True
            nn.Conv2d(64, 32, kernel_size=3, padding=2),
            nn.ELU(),
            nn.Upsample(scale_factor=2, mode = 'bilinear', align_corners=True), #, align_corners=True
            nn.Conv2d(32, 16, kernel_size=3, padding = 2),
            nn.ELU(),
            nn.Conv2d(16, image_channels, kernel_size = 3, padding = 2),
            nn.Sigmoid(),
        )
        
        
    def reparameterize(self, mu, logvar):
      
        """std = logvar.mul(0.5).exp_()
        esp = torch.randn(*mu.size())
        z = mu + std * esp"""
        eps = torch.randn(mu.size()).cuda()
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mu)
          
        logq_xz = torch.distributions.MultivariateNormal(mu,  torch.eye(100).cuda())
        q_xz = logq_xz.log_prob(z)
        
        log_p_z = torch.distributions.MultivariateNormal(torch.zeros(100).cuda(), torch.eye(100).cuda())
        p_z = log_p_z.log_prob(z)
                
        return z, q_xz, p_z

    
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z, q_xz, p_z = self.reparameterize(mu, logvar)
        return z, q_xz, p_z, mu, logvar

    def forward(self, x):
      
        h = self.encoder(x)
        z, q_xz, p_z, mu, logvar = self.bottleneck(h)
        z_decoder = self.fc3(z)
        decoder = self.decoder(z_decoder)
        
        return decoder, mu, logvar
      
    def imp_sample(self, x, h):
        
        z, q_xz, p_z, mu, logvar = self.bottleneck(h)
        
        z_decoder = self.fc3(z)
        decoder = self.decoder(z_decoder)
        
        log_p_xz = torch.distributions.Bernoulli(decoder.view(decoder.size(0),784))
        p_xz = log_p_xz.log_prob(x.view(x.size(0),784))
        
        return p_xz, q_xz, p_z

        


Using downloaded and verified file: binarized_mnist/binarized_mnist_train.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_valid.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_test.amat


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

model = VAE().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


criterion = nn.BCELoss(reduction = "sum").cuda()

def loss_fn(recon_x, x, mu, logvar):
    ## E[log P(X|z)]
    
    BCE = criterion(recon_x, x)
    KLD = - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp() )
    return BCE + KLD, BCE, KLD

for epoch in range(epochs):
    for idx, (images) in enumerate(train):
        #optimizer.zero_grad()
        images = images.to(device)
        optimizer.zero_grad()
        recon_images, mu, logvar = model(images)
        loss, bce, kld = loss_fn(recon_images, images, mu, logvar)
        
        loss.backward()
        optimizer.step()
    
    to_print = "Epoch[{}/{}] ELBO: {:.3f} ,BCE: {:.3f} ,KLD: {:.3f}".format(epoch+1, 
                            epochs, -loss/bs, bce/bs, kld/bs)
    print("Training Data : ", to_print)




Training Data :  Epoch[1/20] ELBO: -96.045 ,BCE: 82.138 ,KLD: 13.907
Training Data :  Epoch[2/20] ELBO: -80.444 ,BCE: 66.504 ,KLD: 13.940
Training Data :  Epoch[3/20] ELBO: -73.830 ,BCE: 59.252 ,KLD: 14.578
Training Data :  Epoch[4/20] ELBO: -68.884 ,BCE: 54.017 ,KLD: 14.867
Training Data :  Epoch[5/20] ELBO: -60.572 ,BCE: 46.040 ,KLD: 14.532
Training Data :  Epoch[6/20] ELBO: -66.402 ,BCE: 50.447 ,KLD: 15.955
Training Data :  Epoch[7/20] ELBO: -62.277 ,BCE: 46.363 ,KLD: 15.915
Training Data :  Epoch[8/20] ELBO: -59.959 ,BCE: 44.578 ,KLD: 15.381
Training Data :  Epoch[9/20] ELBO: -61.027 ,BCE: 45.452 ,KLD: 15.575
Training Data :  Epoch[10/20] ELBO: -63.458 ,BCE: 47.295 ,KLD: 16.163
Training Data :  Epoch[11/20] ELBO: -63.981 ,BCE: 47.240 ,KLD: 16.741
Training Data :  Epoch[12/20] ELBO: -58.973 ,BCE: 43.043 ,KLD: 15.931
Training Data :  Epoch[13/20] ELBO: -61.327 ,BCE: 45.105 ,KLD: 16.221
Training Data :  Epoch[14/20] ELBO: -59.745 ,BCE: 43.539 ,KLD: 16.206
Training Data :  Epoch[15/20]

In [8]:
#Importance Sampling Function:

def importance_sampling(model, data):
    #Input as our trained model and data
  
    ip_loss = []
    count = 0
    #Setting the model parameters to no_grad
    for param in model.parameters():
        param.requires_grad = False

    for idx, (images) in enumerate(valid):
        images = images.to(device)
        count += 1
        #Calling model to get the output from encoder
        h = model.encoder(images.view(images.size(0),1,28,28))
        h = h.view(images.size(0),256)

        ip_sample = []

        for k in range(200):
            #Calling model to get output of our densities :
            p_xz, q_xz, p_z = model.imp_sample(images, h)
            #Summing over 784 dimention and the dimention we will have (1XBatch_Size)
            p_xz = torch.sum(p_xz,dim=1)
            #Importance Sampling calculation with LogSumExp trick:
            out = p_xz - q_xz + p_z
            log_weight = out - torch.max(out, 0)[0]
            weight = torch.exp(log_weight)
            weight = weight / torch.sum(weight, 0)
            loss = torch.mean(torch.sum(weight * (p_z + p_xz - q_xz), 0))
            ip_sample.append(loss)

        ip_loss.append(torch.sum(torch.stack(ip_sample))/bs)

    print("Importance Sampling over mini batch -", count,' : ' , torch.stack(ip_loss))

    log_likelihood_estimate = torch.mean(torch.stack(ip_loss))
    print('log-likelihood  estimate :', log_likelihood_estimate)
    
        

In [10]:
#Importance Sampling on Validation Data:
importance_sampling(model, valid)

Importance Sampling over mini batch - 79  :  tensor([-53.0373, -39.6375, -42.0156, -45.4993, -43.2120, -39.7358, -50.2935,
        -52.7791, -49.9681, -47.0958, -52.0168, -43.4983, -44.4474, -45.6429,
        -42.7070, -42.4412, -45.6604, -48.0044, -42.9745, -47.7442, -46.1955,
        -48.5450, -47.9398, -49.5970, -46.7224, -43.6381, -46.3858, -45.7736,
        -46.1605, -44.1194, -42.1705, -48.9098, -45.9725, -43.3371, -42.3633,
        -40.1183, -46.3247, -43.2202, -44.9085, -47.2767, -48.7591, -41.9616,
        -47.2108, -42.1973, -41.6795, -41.4401, -40.9675, -45.4725, -44.3390,
        -49.2932, -48.3470, -45.8876, -35.1705, -44.6851, -48.6321, -47.3355,
        -44.9034, -43.3303, -45.2567, -47.7691, -42.3350, -50.9472, -43.4320,
        -44.1024, -45.1459, -49.2735, -46.8572, -40.4301, -46.7594, -45.4955,
        -44.2671, -43.6105, -42.2646, -46.1439, -49.0317, -44.5678, -43.1247,
        -49.4768, -92.1708], device='cuda:0')
log-likelihood  estimate : tensor(-45.9768, device=

In [11]:
#Importance Sampling on Test Data:
importance_sampling(model, test)

Importance Sampling over mini batch - 79  :  tensor([-53.2666, -39.6112, -42.0173, -45.0696, -42.7885, -39.3289, -50.2789,
        -52.5768, -50.1204, -47.3126, -51.4464, -43.7340, -44.2706, -46.3034,
        -42.7663, -42.2453, -45.8034, -47.9994, -43.1152, -47.3775, -46.3774,
        -48.9510, -47.6278, -50.0038, -46.5382, -43.5363, -46.5026, -45.6286,
        -45.8510, -44.3316, -42.0889, -48.5721, -45.6009, -43.2220, -41.8239,
        -39.9657, -46.6021, -43.0379, -44.8822, -47.2276, -49.2708, -41.3384,
        -48.0132, -42.6075, -41.1648, -41.6543, -41.3515, -45.6582, -44.9995,
        -48.8655, -48.8445, -45.2723, -34.1606, -45.0158, -48.3008, -47.6436,
        -45.0193, -43.4447, -45.0337, -47.7312, -42.3400, -50.9683, -43.3457,
        -43.8593, -44.5975, -49.5735, -46.5705, -40.6795, -47.0243, -45.8740,
        -44.4704, -43.6746, -42.3586, -45.8282, -49.6287, -45.2142, -42.7317,
        -49.6860, -91.7219], device='cuda:0')
log-likelihood  estimate : tensor(-45.9663, device=

In [12]:
#Running Validation :
epochs = 1
bs = 10000

train, valid, test = get_data_loader("binarized_mnist", bs)

len(valid)



Using downloaded and verified file: binarized_mnist/binarized_mnist_train.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_valid.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_test.amat


1

In [13]:
for epoch in range(epochs):
    for idx, (images) in enumerate(valid):
        images = images.to(device)
        
        recon_images, mu, logvar = model(images)
        loss, bce, kld = loss_fn(recon_images, images, mu, logvar)
    
    to_print = "Epoch[{}/{}] ELBO: {:.3f} ,BCE: {:.3f} ,KLD: {:.3f}".format(epoch+1, 
                            epochs, -loss/bs, -bce/bs, kld/bs)
    print("Validation Data : ", to_print)
    

Validation Data :  Epoch[1/1] ELBO: -95.835 ,BCE: -69.644 ,KLD: 26.191


In [14]:
for epoch in range(epochs):
    for idx, (images) in enumerate(test):
        images = images.to(device)
        
        recon_images, mu, logvar = model(images)
        loss, bce, kld = loss_fn(recon_images, images, mu, logvar)
    
    to_print = "Epoch[{}/{}] ELBO: {:.3f} ,BCE: {:.3f} ,KLD: {:.3f}".format(epoch+1, 
                            epochs, -loss/bs, -bce/bs, kld/bs)
    print("Test Data : ", to_print)
    

Test Data :  Epoch[1/1] ELBO: -95.108 ,BCE: -69.049 ,KLD: 26.060
