In [28]:
import torch
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import importlib
import argparse

from tensorflow.examples.tutorials.mnist import input_data

# Load data

In [29]:
mnist = input_data.read_data_sets('../variational-auto-encoder/data', one_hot=True)

Extracting ../variational-auto-encoder/data/train-images-idx3-ubyte.gz
Extracting ../variational-auto-encoder/data/train-labels-idx1-ubyte.gz
Extracting ../variational-auto-encoder/data/t10k-images-idx3-ubyte.gz
Extracting ../variational-auto-encoder/data/t10k-labels-idx1-ubyte.gz


In [30]:
#or link to mnist.py
#import mnist; importlib.reload(mnist)

In [32]:
parser = {
    'batch_size': 1,
    'no_cuda': True,
    'epochs': 10
}
args = argparse.Namespace(**parser)

In [33]:
mb_size = args.batch_size
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
cnt = 0
lr = 1e-3

In [34]:
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [35]:
args.cuda

False

In [36]:
mnist

Datasets(train=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x10401b668>, validation=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x1213d86a0>, test=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x1213d80b8>)

In [37]:
#Z_dim = 100
Z_dim = 50
X_dim = mnist.train.images.shape[1] # 28x28
y_dim = mnist.train.labels.shape[1] # 10
h_dim = 128
cnt = 0
lr = 1e-3

In [38]:
X_dim

784

In [39]:
y_dim

10

In [40]:
#Initial weights randomly
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

# Q(z|X): Encoder

In [41]:
Wxh = xavier_init(size=[X_dim + y_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

In [42]:
Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

In [43]:
Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)

In [44]:
#@: support for python 3.5 - matmul
#Q: encoder
def Q(X, c):
    inputs = torch.cat([X, c], 1)
    h = F.relu(inputs @ Wxh + bxh.repeat(inputs.size(0), 1))
    #input.size(): 64x794 
    #Wxh.size(): 794x128
    #repeat(inputs.size(0), 1)): 64x1
    #bxh.repeat: add bxh (1x128) at each row
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_logvar = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_logvar

In [45]:
def sample_z(mu, log_var):
    eps = Variable(torch.randn(mb_size, Z_dim)) #randomize according to normal distribution
    return mu + torch.exp(log_var / 2) * eps

# P(X|z): Decoder

In [46]:
Wzh = xavier_init(size=[Z_dim + y_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

In [47]:
Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [48]:
def P(z, c):
    inputs = torch.cat([z, c], 1)
    h = F.relu(inputs @ Wzh + bzh.repeat(inputs.size(0), 1))
    X = F.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

# Training

In [49]:
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

In [50]:
solver = optim.Adam(params, lr=lr)

In [52]:
for it in range(100000):
    X, c = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))
    c = Variable(torch.from_numpy(c.astype('float32')))

    # Forward
    z_mu, z_logvar = Q(X, c)
    z = sample_z(z_mu, z_logvar)
    X_sample = P(z, c)

    # Loss
    #recon_loss = F.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
    #kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1))
    recon_loss = F.binary_cross_entropy(X_sample, X, size_average=False)
    kl_loss = torch.sum(0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1))
    loss = recon_loss + kl_loss

    # Backward
    loss.backward()

    # Update
    solver.step()

    # Housekeeping
    for p in params:
        p.grad.data.zero_()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, loss.data[0]))

        c = np.zeros(shape=[mb_size, y_dim], dtype='float32')
        c[:, np.random.randint(0, 10)] = 1.
        c = Variable(torch.from_numpy(c))
        z = Variable(torch.randn(mb_size, Z_dim))
        samples = P(z, c).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)

Iter-0; Loss: 668.1
Iter-1000; Loss: 214.1
Iter-2000; Loss: 152.3
Iter-3000; Loss: 183.5
Iter-4000; Loss: 105.9
Iter-5000; Loss: 139.9
Iter-6000; Loss: 116.3
Iter-7000; Loss: 133.1
Iter-8000; Loss: 88.0
Iter-9000; Loss: 107.7
Iter-10000; Loss: 59.21
Iter-11000; Loss: 160.9
Iter-12000; Loss: 136.4
Iter-13000; Loss: 187.7
Iter-14000; Loss: 145.4
Iter-15000; Loss: 144.6
Iter-16000; Loss: 190.2
Iter-17000; Loss: 121.3
Iter-18000; Loss: 191.2
Iter-19000; Loss: 150.3
Iter-20000; Loss: 118.8
Iter-21000; Loss: 93.31
Iter-22000; Loss: 107.5
Iter-23000; Loss: 115.2
Iter-24000; Loss: 109.3
Iter-25000; Loss: 111.1
Iter-26000; Loss: 55.96
Iter-27000; Loss: 130.2
Iter-28000; Loss: 121.1
Iter-29000; Loss: 109.6
Iter-30000; Loss: 51.78
Iter-31000; Loss: 149.2
Iter-32000; Loss: 81.74
Iter-33000; Loss: 123.4
Iter-34000; Loss: 107.9
Iter-35000; Loss: 88.48
Iter-36000; Loss: 151.7
Iter-37000; Loss: 135.0
Iter-38000; Loss: 95.25
Iter-39000; Loss: 101.4
Iter-40000; Loss: 72.87
Iter-41000; Loss: 122.1
Iter-4

KeyboardInterrupt: 