In [1]:
import torch
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import importlib
import argparse

from tensorflow.examples.tutorials.mnist import input_data

In [2]:
%pwd

'/Users/hoangnguyen/Documents/Github/generative-model/conditional_vae'

In [3]:
%cd ../datasets

/Users/hoangnguyen/Documents/Github/generative-model/datasets


In [4]:
import mnist; importlib.reload(mnist)

<module 'mnist' from '/Users/hoangnguyen/Documents/Github/generative-model/datasets/mnist.py'>

# Load data

- X: input
- c: label
- z: latent variable

In [5]:
#or link to mnist.py
#import mnist; importlib.reload(mnist)

In [6]:
parser = {
    'batch_size': 1,
    'no_cuda': True,
    'epochs': 10
}
args = argparse.Namespace(**parser)

In [7]:
mb_size = args.batch_size
Z_dim = 100
X_dim = 784
y_dim = 10
h_dim = 128
cnt = 0
lr = 1e-3

In [8]:
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [9]:
args.cuda

False

In [10]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

train_loader = torch.utils.data.DataLoader(mnist.MNIST('../data', train=True,
                transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(mnist.MNIST('../data', train=False, 
                transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True, **kwargs)

In [11]:
label = 0
#DataLoader returns a batch size of X and y
for X, c in train_loader:
    print(X)
    print(c)
    label = c
    break


(0 ,0 ,.,.) = 

Columns 0 to 8 
   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000 

In [12]:
#Initial weights randomly
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

# Q(z|X): Encoder

In [13]:
#X --> hidden layer: Initialize randomly
Wxh = xavier_init(size=[X_dim + y_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

In [14]:
#hidden --> z: Initialize randomly
Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

In [15]:
#hidden --> z: Initialize randomly
Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)

In [16]:
#@: support for python 3.5 - matmul
#Q: encoder
#Calculate z_mu and z_logvar
def Q(X, c):
    inputs = torch.cat([X, c], 1)
    h = F.relu(inputs @ Wxh + bxh.repeat(inputs.size(0), 1))
    #input.size(): 64x794 
    #Wxh.size(): 794x128
    #repeat(inputs.size(0), 1)): 64x1
    #bxh.repeat: add bxh (1x128) at each row
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_logvar = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_logvar

In [17]:
def sample_z(z_mu, z_logvar):
    #eps = Variable(torch.randn(mb_size, Z_dim)) #randomize according to normal distribution
    #The above can lead to bug because mb_size may be different from the sampling
    eps = Variable(torch.randn(z_mu.size(0), Z_dim))
    return z_mu + torch.exp(z_logvar / 2) * eps

# P(X|z): Decoder

In [18]:
Wzh = xavier_init(size=[Z_dim + y_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

In [19]:
Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [20]:
def P(z, c):
    inputs = torch.cat([z, c], 1)
    h = F.relu(inputs @ Wzh + bzh.repeat(inputs.size(0), 1))
    X = F.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

# Training

In [21]:
def onehot(y):
    y_array = np.zeros(shape=[y.size(0), 10], dtype='float32')
    for index, ele in enumerate(y):
        y_array[index][ele] = 1
    return torch.from_numpy(y_array)

In [22]:
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

In [23]:
optimizer = optim.Adam(params, lr=lr)

In [30]:
def train(epoch):
        #flatten X: tensor 28x28 --> tensor 1x784
        #one-hot c
    for it, (X, c) in enumerate(train_loader):
        #If X is a numpy array, need to transform X from array to tensor to pass to Variable()
        #X = Variable(torch.from_numpy(X))
        
        X = Variable(X.view(-1, 784))
        c = Variable(onehot(c))
        #c = Variable(torch.from_numpy(onehot(c).astype('float32')))
        
        #forward   
        z_mu, z_logvar = Q(X, c)
        z = sample_z(z_mu, z_logvar)
        X_sample = P(z, c)
        
        #Loss
        #recon_loss = F.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
        #kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1))
        
        recon_loss = F.binary_cross_entropy(X_sample, X, size_average=False)
        kl_loss = torch.sum(0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1))
        loss = recon_loss + kl_loss
       
    
        #Initialize zero buffers
        for p in params:
            p.grad.data.zero_()

    
        #Backward
        loss.backward()
        
        #Update
        optimizer.step()
        
        
        # Print and plot every now and then
        if it % 1000 == 0:
            print('Iter-{}; Loss: {:.4}'.format(it, loss.data[0]))

            c = np.zeros(shape=[mb_size, y_dim], dtype='float32')
            c[:, np.random.randint(0, 10)] = 1.
            c = Variable(torch.from_numpy(c))
            z = Variable(torch.randn(mb_size, Z_dim))
            samples = P(z, c).data.numpy()[:16]

            fig = plt.figure(figsize=(4, 4))
            gs = gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                ax = plt.subplot(gs[i])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

            #if not os.path.exists('../conditional_vae/out/'):
            #    os.makedirs('out/')

            plt.savefig('../conditional_vae/out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight') #to save
            #plt.show()
            #cnt += 1
            #plt.close(fig)
    #Print
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, loss / len(train_loader.dataset)))

In [31]:
%pwd

'/Users/hoangnguyen/Documents/Github/generative-model/datasets'

In [32]:
#for epoch in range(1, args.epochs + 1)
train(1)

Iter-0; Loss: 661.3
Iter-1000; Loss: 141.3
Iter-2000; Loss: 195.4
Iter-3000; Loss: 206.2
Iter-4000; Loss: 218.8
Iter-5000; Loss: 191.5


KeyboardInterrupt: 