<a href="https://colab.research.google.com/github/mchivuku/csb659-project/blob/master/VAE_%2B_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%%capture
!pip install tqdm six

In [0]:
from google.colab import drive

drive.mount("/content/drive")


%cd /content/drive/My\ Drive/Masters-DS/CSCI-B659/project/examples/vae/
%ls

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive
/content/drive/My Drive/Masters-DS/CSCI-B659/project/examples/vae
[0m[01;34mMNIST[0m/  [01;34mresults[0m/


## Classification Accuracy of VAE

adopted - https://github.com/wohlert/semi-supervised-pytorch/blob/master/examples/notebooks/Deep%20Generative%20Model.ipynb

https://lirnli.wordpress.com/2017/09/14/latent-layers-beyond-the-variational-autoencoder-vae/

http://cs231n.stanford.edu/reports/2017/pdfs/3.pdf

In [0]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image


params = {
    "batch_size":128,
    "epochs" : 10,
    "log_interval":10
    
}

torch.manual_seed(5)

print('Torch', torch.__version__, 'CUDA', torch.version.cuda)
print('Device:', torch.device('cuda:0'))
print(torch.cuda.is_available())

is_cuda = torch.cuda.is_available()
device = torch.device ( "cuda:0" if torch.cuda.is_available () else "cpu" )




Torch 1.0.1.post2 CUDA 10.0.130
Device: cuda:0
True


In [0]:
class RunningAverage ():
    """A simple class that maintains the running average of a quantity
    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg() = 3
    ```
    """

    def __init__( self ):
        self.steps = 0
        self.total = 0

    def update( self, val ):
        self.total += val
        self.steps += 1
    
    def reset(self):
        self.steps = 0.
        self.total = 0.
        
        
        
    def __call__( self ):
        return self.total / float ( self.steps )

In [0]:
## Data Loaders
kwargs = {'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('MNIST/data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=params.get("batch_size"), shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('MNIST/data', train=False, transform=transforms.ToTensor()),
    batch_size=params.get("batch_size"), shuffle=True, **kwargs)

In [0]:
from torch.nn import init

"""
Flatten input
"""
def num_flat_features(x):
  size = x.size()[1:] # all dimensions except the batch dimension
  num_features = 1
  for s in size:
    num_features *=s
  return num_features
  
## Classifier
class Classifier(nn.Module):
  def __init__(self,dims):
    """
    Single hidden layer with softmax
    """
    super(Classifier,self).__init__()
    [x_dim, h_dim, y_dim] = dims
    
    self.fc1 = nn.Linear(x_dim, h_dim)
    self.fc2 = nn.Linear(h_dim, y_dim)
    
  def forward(self,x):
    ## flatten x input
    x = x.view(-1,num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.softmax(self.fc2(x),dim=-1)
    
    return x
  
  
  
  
"""
Gaussian Sample Layer
"""
class GaussianSample(nn.Module):
  def __init__(self,in_features, out_features):
    super(GaussianSample, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    
    self.mu = nn.Linear(in_features, out_features)
    self.log_var = nn.Linear(in_features, out_features)
    
  def forward(self,x):
    mu = self.mu(x)
    log_var = F.softplus(self.log_var(x))
    
    return self.reparametrize(mu,log_var), mu, log_var
    
  def reparametrize(self, mu, log_var):
    epsilon = Variable(torch.randn(mu.size()), requires_grad=False)

    if mu.is_cuda:
      epsilon = epsilon.cuda()
    # log_std = 0.5 * log_var
    # std = exp(log_std)
    std = log_var.mul(0.5).exp_()

    # z = std * epsilon + mu
    z = mu.addcmul(std, epsilon)
    return z
      

"""
VAE Encoder
"""
class Encoder(nn.Module):
  def __init__(self,dims):
    super(Encoder,self).__init__()
    [x_dim, h_dim, z_dim] = dims
    ##linear layers
    
    neurons = [x_dim, *h_dim]
    print("Neurons",neurons)
    linear_layers = [nn.Linear(neurons[i-1],neurons[i]) for i in range(1,len(neurons))]
    
    self.hidden=nn.ModuleList(linear_layers)
    
    self.sample = GaussianSample(h_dim[-1],z_dim)
    
    
  def forward(self,x):
    
    x = x.view(-1,784)
    
    for layer in self.hidden:
      x = F.relu(layer(x))
      
    
    return self.sample(x)
  
"""
VAE Decoder
"""
class Decoder(nn.Module):
  def __init__(self,dims):
    super(Decoder,self).__init__()
    [z_dim, h_dim, x_dim] = dims
    ##linear layers
    neurons = [z_dim, *h_dim]
    
    linear_layers = [nn.Linear(neurons[i-1],neurons[i]) for i in range(1,len(neurons))]
    
    self.hidden=nn.ModuleList(linear_layers)
    
    self.reconstruction = nn.Linear(h_dim[-1],x_dim)
    self.output_activation = torch.sigmoid
    
    
  def forward(self,x):
    #x = x.view(-1,num_flat_features(x))
    for layer in self.hidden:
      x = F.relu(layer(x))
    return self.output_activation(self.reconstruction(x))
  


"""
VAE object
"""
class VAE(nn.Module):
  def __init__(self,dims):
    super(VAE,self).__init__()
    
    [x_dim,z_dim,h_dim] = dims
    self.z_dim = z_dim
    self.h_dim = h_dim
    self.x_dim = x_dim
   
    ## xaview initialization for weights
    for m in self.modules():
      if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
          m.bias.data.zero_()
    
 
  """
  Forward
  """
  def forward(self,x):
    z, z_mu, z_log_var = self.encoder(x)
    return self.decoder(z), z_mu, z_log_var
  
  
  
"""
Deep Generative Model 
"""
class DeepGenerativeModel(VAE):
    def __init__(self, dims):
        """
        M2 code replication from the paper
        'Semi-Supervised Learning with Deep Generative Models'
        (Kingma 2014) in PyTorch.
        The "Generative semi-supervised model" is a probabilistic
        model that incorporates label information in both
        inference and generation.
        Initialise a new generative model
        :param dims: dimensions of x, y, z and hidden layers.
        """
        [x_dim, self.y_dim, z_dim, h_dim] = dims
        super(DeepGenerativeModel, self).__init__([784, z_dim, h_dim])
        
        

        self.encoder = Encoder([x_dim, h_dim, z_dim])

        self.decoder = Decoder([z_dim, list(reversed(h_dim)), x_dim])
        self.classifier = Classifier([x_dim, h_dim[0], self.y_dim])

        for m in self.modules():
            if isinstance(m, nn.Linear):
              init.xavier_normal_(m.weight.data)
              if m.bias is not None:
                m.bias.data.zero_()

    """
    Compute KLD on gaussian sample - inference
    """
    def _kld(self,x,q_param):
      (mu, log_var) = q_param
      log_pdf = - 0.5 * math.log(2 * math.pi) - log_var / 2 - (x - mu)**2 / (2 * torch.exp(log_var))
      return torch.sum(log_pdf, dim=-1)
      
    def forward(self, x, y):
        # Add label and data and generate latent variable
        
        z, z_mu, z_log_var = self.encoder(x)

        self.kl_divergence = self._kld(z, (z_mu, z_log_var))

        # Reconstruct data point from latent data and label
        x_mu = self.decoder(z)

        return x_mu

    def classify(self, x):
        logits = self.classifier(x)
        return logits

    def sample(self, z, y):
        """
        Samples from the Decoder to generate an x.
        :param z: latent normal variable
        :param y: label (one-hot encoded)
        :return: x
        """
        y = y.float()
        x = self.decoder(torch.cat([z, y], dim=1))
        return x


"""
binary cross entropy loss
"""
def binary_cross_entropy(r, x):
    return -torch.sum(x * torch.log(r + 1e-8) + (1 - x) * torch.log(1 - r + 1e-8), dim=-1)

  
y_dim = 10
z_dim = 32
h_dim = [256, 128]

model = DeepGenerativeModel([784, y_dim, z_dim, h_dim]).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4, betas=(0.9,0.999))
criterion = nn.CrossEntropyLoss()
model

Neurons [784, 256, 128]


DeepGenerativeModel(
  (encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=128, out_features=32, bias=True)
      (log_var): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (decoder): Decoder(
    (hidden): ModuleList(
      (0): Linear(in_features=32, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
    (reconstruction): Linear(in_features=256, out_features=784, bias=True)
  )
  (classifier): Classifier(
    (fc1): Linear(in_features=784, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [0]:
## Loss function
def log_standard_categorical(p):
  """
    Calculates the cross entropy between a (one-hot) categorical vector
    and a standard (uniform) categorical distribution.
    :param p: one-hot categorical distribution
    :return: H(p, u)
  """
  # Uniform prior over y
  prior = F.softmax(torch.ones_like(p, dtype=torch.int64))
  prior.requires_grad = False
  cross_entropy = -torch.sum(p* torch.log(prior + 1e-8), dim=1)
  
  return cross_entropy

"""
LogSumExp an approximation for the sum in log-domain
"""
def log_sum_exp(tensor, dim=-1, sum_op = torch.sum):
  max, _ = torch.max(tensor, dim = dim, keepdim = True)
  return torch.log(sum_op(torch.exp(tensor - max), dim = dim, keepdim = True) + 1e-8) + max


"""
Importance Weighted Sampler - [Burda 2015] to be used in conjunction to get better estimate of
Stochastic variational inference
"""  
class ImportanceWeightedSampler(object):
  def __init__(self,mc=1,iw=1):
    """
    sampler
    :param mc: number of monte carlo samples
    : param iw: number of importance weighted samples
    """
    self.mc = mc
    self.iw = iw
    
  def resample(self,x):
    return x.repeat(self.mc * self.iw, 1)
  
  def __call__(self, elbo):
    elbo = elbo.view(self.mc, self.iw, -1)
    elbo = torch.mean(log_sum_exp(elbo, dim=1, sum_op=torch.mean), dim=0)
    return elbo.view(-1)


In [0]:
"""
Train
"""
from itertools import repeat
from torch.autograd import Variable
import math

sampler = ImportanceWeightedSampler(mc=1, iw=1)
beta = repeat(1)
for epoch in range(10):
  model.train()
  train_loss = 0
  accuracy = 0
  for batch_idx, (data, labels) in enumerate(train_loader):
    x = data.to(device)
    y = labels.to(device)
    
    
    
    optimizer.zero_grad()
    batch_size = data.size(0)
    
    ## increase sampling dimension
    #xs = sampler.resample(x)
    #ys = sampler.resample(y)
    
    reconstruction= model(x, y)
    
   
    #p(x|y,z)
    likelihood = -F.binary_cross_entropy(reconstruction, x.view(-1,784))
    
    
    #p(y)
    #prior = -log_standard_categorical(y)
    
    #Equivalent to -L(x,y)
    elbo  = likelihood  - next(beta) *model.kl_divergence
    L = sampler(elbo)
    
    loss = torch.mean(L)
   
    
    logits = model.classify(x)
    # Regular cross entropy
    classification_loss = criterion(logits, y) 
    total_loss = loss +  classification_loss
    total_loss.backward()
    
    train_loss += total_loss.item()
    _, idx = logits.topk(1,dim=1)
    accuracy+=  torch.sum(idx.view(batch_size) == y).item()
    
    optimizer.step()
    
    
    
  if epoch % 1 == 0:
    m = len(train_loader)
    print("Epoch: {}".format(epoch))
    print("[Train]\t\t J_a: {:.2f}, accuracy: {:.2f}".format(train_loss / m, accuracy / m))

  

Epoch: 0
[Train]		 J_a: 23.80, accuracy: 14.38
Epoch: 1
[Train]		 J_a: 23.79, accuracy: 14.38
Epoch: 2
[Train]		 J_a: 23.79, accuracy: 14.38
Epoch: 3
[Train]		 J_a: 23.79, accuracy: 14.38
Epoch: 4
[Train]		 J_a: 23.82, accuracy: 14.38
Epoch: 5
[Train]		 J_a: 23.81, accuracy: 14.38
Epoch: 6
[Train]		 J_a: 23.78, accuracy: 14.38
Epoch: 7
[Train]		 J_a: 23.79, accuracy: 14.38
Epoch: 8
[Train]		 J_a: 23.80, accuracy: 14.38
Epoch: 9
[Train]		 J_a: 23.83, accuracy: 14.38


In [0]:
"""
Test
"""
def test(epoch):
    model.eval()
    test_loss = 0
    test_accuracy = RunningAverage()
    with torch.no_grad():
        for i, (data,labels) in enumerate(test_loader):
            data = data.to(device)
            labels = labels.to(device)
            recon_batch, mu, logvar = model(data)
            batch_size = data.size(0)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            ## accuracy
            _, idx = recon_batch.topk(1,dim=1)
            test_accuracy.update(torch.sum(idx.view(batch_size) == labels).item())
        
            
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(params.get("batch_size"), 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         './results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss,test_accuracy()))
    return (test_loss,test_accuracy)

In [0]:
train_losses = []
test_losses = []
train_acc_list = []
test_acc_list = []
for epoch in range(1, 2):
        trainloss,train_acc = train(epoch)
        testloss,test_acc = test(epoch)
        train_losses.append(trainloss)
        test_losses.append(testloss)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')
            
            
            
            


tensor([[407],
        [408],
        [408],
        [407],
        [408],
        [434],
        [407],
        [407],
        [408],
        [407],
        [407],
        [407],
        [407],
        [407],
        [408],
        [434],
        [407],
        [407],
        [408],
        [407],
        [408],
        [407],
        [407],
        [407],
        [407],
        [407],
        [408],
        [407],
        [407],
        [408],
        [408],
        [407],
        [408],
        [407],
        [408],
        [434],
        [407],
        [407],
        [408],
        [408],
        [407],
        [407],
        [407],
        [407],
        [407],
        [407],
        [408],
        [407],
        [407],
        [407],
        [407],
        [408],
        [407],
        [408],
        [407],
        [407],
        [408],
        [407],
        [407],
        [407],
        [407],
        [408],
        [407],
        [407],
        [407],
        [407],
        [4

In [0]:
F.softmax(torch.ones_like(y,dtype=torch.float))

  """Entry point for launching an IPython kernel.


tensor([0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 

In [0]:
import os
os.makedirs("./results")