In [1]:
import torch
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torchnet.meter import AverageValueMeter

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import importlib
import argparse

from tensorflow.examples.tutorials.mnist import input_data

In [2]:
%cd ../datasets

/Users/hoangnguyen/Documents/Github/generative-model/datasets


In [3]:
import mnist; importlib.reload(mnist)

<module 'mnist' from '/Users/hoangnguyen/Documents/Github/generative-model/datasets/mnist.py'>

# Load data

In [4]:
parser = {
    'batch_size': 64,
    'no_cuda': True,
    'epochs': 20
}
args = argparse.Namespace(**parser)

In [5]:
mb_size = args.batch_size
epochs = args.epochs
Z_dim = 100
X_dim = 784
y_dim = 10
h_dim = 128
cnt = 0
lr = 1e-3
noise_factor = .25

In [6]:
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [7]:
train_loader = torch.utils.data.DataLoader(mnist.MNIST('../data', train=True,
                transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(mnist.MNIST('../data', train=False, 
                transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True)

In [8]:
#test DataLoader
for x, y in train_loader:
    print(x)
    break


(0 ,0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
     ⋮ 

(1 ,0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
     ⋮ 

(2 ,0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
 

In [9]:
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

# Q(z|x): Encode 

In [10]:
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

In [11]:
Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

In [12]:
Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)

In [13]:
def Q(X):
    h = F.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_logvar = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_logvar

In [14]:
def sample_z(z_mu, z_logvar):
    #eps = Variable(torch.randn(mb_size, Z_dim)) #randomize according to normal distribution
    #The above can lead to bug because mb_size may be different from the sampling
    eps = Variable(torch.randn(z_mu.size(0), Z_dim))
    return z_mu + torch.exp(z_logvar / 2) * eps

# P(X|z): Decode

In [15]:
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

In [16]:
Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [17]:
def P(z):
    h = F.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = F.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

# Training

In [18]:
def onehot(y):
    y_array = np.zeros(shape=[y.size(0), 10], dtype='float32')
    for index, ele in enumerate(y):
        y_array[index][ele] = 1
    return torch.from_numpy(y_array)

In [19]:
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

In [20]:
optimizer = optim.Adam(params, lr=lr)

In [21]:
def train(epoch):
    losses = AverageValueMeter()
    
    for it, (X, _) in enumerate(train_loader):
        X = Variable(X.view(-1, 784))
        #c = Variable(onehot(c))
        
        # Add noise
        X_noise = X + noise_factor * Variable(torch.randn(X.size()))
        X_noise.data.clamp_(0., 1.)
        
        #forward   
        z_mu, z_logvar = Q(X_noise)    #not X
        z = sample_z(z_mu, z_logvar)
        X_sample = P(z)
        
        #Loss
        recon_loss = F.binary_cross_entropy(X_sample, X, size_average=False)
        kl_loss = torch.sum(0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1))
        loss = recon_loss + kl_loss
        
        #Initialize zero buffers
        for p in params:
            p.grad.data.zero_()
        
        #Backward
        loss.backward()
        
        #Update
        optimizer.step()
        
        losses.add(loss.data[0], X_noise.size(0))
        
        
    #Print loss
    print('====> Epoch: {0}/{1}\tLoss: {2:.4f}'.format(epoch, epochs, losses.value()[0]))

In [22]:
def plot(samples, epoch):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    if not os.path.exists('../conditional_vae/out/'):
        os.makedirs('out/')
    
    fileName = '../denoising_vae/out/' + str(epoch).zfill(3)
    plt.savefig(fileName, bbox_inches='tight')
    #plt.show()
    plt.close(fig)

In [23]:
for epoch in range(1, args.epochs + 1):
    train(epoch)
    z = torch.randn(mb_size, Z_dim)
    samples = P(Variable(z)).data.numpy()[:16]
    plot(samples, epoch)

====> Epoch: 1/20	Loss: 192.5622
====> Epoch: 2/20	Loss: 146.1743
====> Epoch: 3/20	Loss: 134.9121
====> Epoch: 4/20	Loss: 127.2705
====> Epoch: 5/20	Loss: 122.8648
====> Epoch: 6/20	Loss: 120.5542
====> Epoch: 7/20	Loss: 119.0980
====> Epoch: 8/20	Loss: 118.1528
====> Epoch: 9/20	Loss: 117.5100
====> Epoch: 10/20	Loss: 117.0692
====> Epoch: 11/20	Loss: 116.6568
====> Epoch: 12/20	Loss: 116.4453
====> Epoch: 13/20	Loss: 116.2598
====> Epoch: 14/20	Loss: 116.0251
====> Epoch: 15/20	Loss: 115.8891
====> Epoch: 16/20	Loss: 115.6934
====> Epoch: 17/20	Loss: 115.5492
====> Epoch: 18/20	Loss: 115.4122
====> Epoch: 19/20	Loss: 115.3169
====> Epoch: 20/20	Loss: 115.1619
