<a href="https://colab.research.google.com/github/mchivuku/csb659-project/blob/master/VAE_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%%capture
!pip install tqdm six


!pip install bokeh
!pip install tensorboard
!pip install livelossplot

!pip install tensorboard

# VAE-GAN -

Autoencoding beyond pixels using a learned similarity metric

https://arxiv.org/pdf/1512.09300.pdf

Code is adapted: 
https://github.com/pravn


In [0]:
from google.colab import drive

drive.mount("/content/drive")


Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
%cd /content/drive/My\ Drive/Masters-DS/CSCI-B659/project/examples/vae-gan

/content/drive/My Drive/Masters-DS/CSCI-B659/project/examples/vae-gan


In [0]:
## Imports

from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torchvision.utils as utils

import os
import matplotlib.pyplot as plt
## Plotting library

from bokeh.plotting import figure
from bokeh.io import show
from bokeh.models import LinearAxis, Range1d
from livelossplot import PlotLosses


plt.style.use('ggplot')

print('Torch', torch.__version__, 'CUDA', torch.version.cuda)
print('Device:', torch.device('cuda:0'))
print(torch.cuda.is_available())

is_cuda = torch.cuda.is_available()
device = torch.device ( "cuda:0" if torch.cuda.is_available () else "cpu" )


Torch 1.0.1.post2 CUDA 10.0.130
Device: cuda:0
True


In [0]:
## Data loader
batch_size =  100

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../vae/MNIST/data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../vae/MNIST/data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

In [0]:
criterion = nn.BCELoss()

In [0]:
## Define Models
# Generator (VAE), Discriminator (D), Aux
"""
VAE Model
"""
class VAE(nn.Module):
  def __init__(self):
    super(VAE, self).__init__()

    self.fc1 = nn.Linear(784, 400)
    self.fc21 = nn.Linear(400, 20)
    self.fc22 = nn.Linear(400, 20)
    self.fc3 = nn.Linear(20, 400)
    self.fc4 = nn.Linear(400, 784)

    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

    self.mu_ = nn.Sequential(
        #28x28->12x12
        nn.Conv2d(1,8,5,2,0,bias=False),
        nn.BatchNorm2d(8),
        nn.ReLU(True),
        #12x12->4x4
        nn.Conv2d(8,64,5,2,0,bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
        #4x4->1x1: 20,1,1
        nn.Conv2d(64,20,4,1,0,bias=False),
        nn.ReLU(True)
        )


    self.logsigma_ = nn.Sequential(
        #28x28->12x12
        nn.Conv2d(1,8,5,2,0,bias=False),
        nn.BatchNorm2d(8),
        nn.ReLU(True),
        #12x12->4x4
        nn.Conv2d(8,64,5,2,0,bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
        #4x4->1x1: 20,1,1
        nn.Conv2d(64,20,4,1,0,bias=False),
        nn.ReLU(True)
        )


    self.dec_ = nn.Sequential(
        #1x1->4x4
        nn.ConvTranspose2d(20,20*8,4,1,0,bias=False),  #(ic,oc,kernel,stride,padding)
        nn.BatchNorm2d(20*8), 
        nn.ReLU(True),
        nn.ConvTranspose2d(20*8,20*16,4,2,1,bias=False), #4x4->8x8
        nn.BatchNorm2d(20*16),
        nn.ReLU(True),
        nn.ConvTranspose2d(20*16,20*32,4,2,1,bias=False), #8x8->16x16
        nn.BatchNorm2d(20*32),
        nn.ReLU(True),
        nn.ConvTranspose2d(20*32,1,2,2,2,bias=False), #16x16->28x28
        nn.Sigmoid()
        )
  def encode(self, x):
    h1 = self.relu(self.fc1(x))
    return self.fc21(h1), self.fc22(h1)
  
  def encode_new(self,x):
    return self.mu_(x), self.logsigma_(x)

  def reparameterize(self, mu, logvar):
    if self.training:
      std = logvar.mul(0.5).exp_()
      eps = Variable(std.data.new(std.size()).normal_())
      return eps.mul(std).add_(mu)
    else:
      return mu


  def decode_new(self,z):
    z = z.view(-1,z.size(1),1,1)
    return(self.dec_(z))
  
  def decode(self, z):
    z = z.view(-1,20)
    h3 = self.relu(self.fc3(z))
    return self.sigmoid(self.fc4(h3))

  def dec_params(self):
    return self.fc3, self.fc4

  def return_weights(self):
    return self.fc3.weight, self.fc4.weight

  def forward(self, x):
    mu, logvar = self.encode_new(x.view(-1, 1,28,28))
    z = self.reparameterize(mu, logvar)
    return mu,logvar


class Aux(nn.Module):
    def __init__(self):
        super(Aux,self).__init__()

        self.fc3 = nn.Linear(20,400)
        self.fc4 = nn.Linear(400,784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def decode(self,z):
        z = z.view(-1,20)
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))
    
    def reparameterize(self, mu, logvar):
        if self.training:
          std = logvar.mul(0.5).exp_()
          eps = Variable(std.data.new(std.size()).normal_())
          return eps
        else:
          return mu

    def dec_params(self):
        return self.fc3,self.fc4

    def return_weights(self):
        return self.fc3.weight, self.fc4.weight

    
    def forward(self,z):
        #self.fc3.weight = fc3_weight
        #self.fc4.weight = fc4_weight
        
        #z = self.reparameterize(mu,logvar)
        #other.fc3,other.fc4 = self.dec_params()
        #return self.decode(z).view(-1,28,28)
        return self.decode(z)


    

class NetD(nn.Module):
    def __init__(self):
        super(NetD, self).__init__()

        self.D_l = nn.Sequential(
        #state size 1x28x28
            #28x28->16x16
            nn.Conv2d(1,8,2,2,2,bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #16x16->8x8
            nn.Conv2d(8,16,4,2,1,bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2,inplace=True),
            #8x8->4x4
            nn.Conv2d(16,32,4,2,1,bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2,inplace=True)
            #4x4->1x1
            #nn.Conv2d(32,1,4,1,0),
            #nn.Sigmoid()
            )

        self.main = nn.Sequential(
            #4x4->1x1
            nn.Conv2d(32,1,4,1,0),
            nn.Sigmoid()
            )


    def forward(self,x):
        d_l = self.D_l(x.view(-1,1,28,28))
        o = self.main(d_l)
        #o = self.main(x.view(-1,784))
        return d_l, o

def loss_function(recon_x, x, mu, logvar,bsz=100):
    #BCE = F.binary_cross_entropy(recon_x.view(-1,784), x.view(-1, 784))
    #MSE = F.mse_loss(recon_x.view(-1,784), x.view(-1,784))
    MSE = F.mse_loss(recon_x,x)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by same number of elements as in reconstruction
    KLD /= bsz * 784

    return MSE + KLD
    

In [0]:
netG = VAE().cuda()
netD = NetD().cuda()
aux = Aux().cuda()

In [0]:
input = torch.FloatTensor(batch_size,28,28)
label = torch.FloatTensor(batch_size)
real_label=1
fake_label=0
optimizerD = optim.Adam(netD.parameters(), lr=1e-4)
optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
optimizer_aux = optim.Adam(aux.parameters(), lr=1e-4)

## Training

In [0]:
num_epochs = 1000
for epoch in range(num_epochs):
  for i, (data,_) in enumerate(train_loader):
    gamma = 1.0
    real_cpu  = data
    
    
    input.resize_as_(real_cpu).copy_(real_cpu)
    label.resize_(batch_size).fill_(real_label)
    
    
    
    inputv = Variable(input)
    labelv = Variable(label.cuda())
    
    labelv = labelv.to(device)
    inputv = inputv.to(device)
    ## zero_grad
    netD.zero_grad()
    netG.zero_grad()
    aux.zero_grad()
    
    
    ## 
    mu, logvar = netG(inputv)
    
    ## Reparameterize Z
    std = logvar.mul(0.5).exp_()
    eps = Variable(std.data.new(std.size()).normal_())
    z=eps.mul(std).add_(mu)
    
    ## Auxiliary network
    fake = aux(z)
    
    ## NetD
    x_l_tilde, output_fake = netD(fake)
    x_l, output_real = netD(inputv)
    
    output_fake = output_fake.cuda()
    output_real = output_real.cuda()
    
    L_GAN_real = criterion(output_real,labelv)
    L_GAN_real.backward(retain_graph=True)
    
    
    labelv = Variable(label.fill_(fake_label).cuda())
    
    L_GAN_fake = criterion(output_fake,labelv)
    
    L_GAN_fake.backward(retain_graph = True)
    
    z_p = Variable(std.data.new(std.size()).normal_())
    fake_aux = aux(z_p)
    x_l_aux, output_aux = netD(fake_aux)
    L_GAN_aux = criterion(output_aux,labelv)
    L_GAN_aux.backward(retain_graph=True)
    optimizerD.step()
           
        
    #get weights of netG and use in aux
    aux.zero_grad()
    labelv=Variable(label.fill_(real_label))
    
    labelv = labelv.to(device)
    ## Loss functions
    L_dec_vae = gamma*loss_function(x_l_tilde,x_l,mu,logvar)
    L_dec_fake = criterion(output_fake,labelv)
    L_dec_aux  = criterion(output_aux,labelv)
    L_dec_vae.backward(retain_graph=True)
    L_dec_fake.backward(retain_graph=True)
    L_dec_aux.backward(retain_graph=True)
    optimizer_aux.step()
    
    #encoder loss 
    netG.zero_grad()
    L_enc = loss_function(x_l_tilde, x_l,mu,logvar)
    L_enc.backward()
    optimizerG.step()

    if i % 100 == 0:
      #print('real_cpu.size()', real_cpu.size())
      utils.save_image(real_cpu,
                            './results/real_samples.png',
                                normalize=True)
      utils.save_image(fake.data.view(-1,1,28,28),
                                './results/fake_samples.png',
                                normalize=True)

    
    
    
    
    
    

  "Please ensure they have the same size.".format(target.size(), input.size()))


KeyboardInterrupt: ignored

In [0]:
os.makedirs("./results")