# Module Import

In [1]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data

# Hyperparameters

In [None]:
# Settings shared by the model definition and the training loop below.
mb_size = 64 # mini-batch size: rows requested per training step and generated per sample grid
Z_dim = 100 # width of the latent code z
h_dim = 128 # width of the hidden layer used by both the encoder and the decoder
cnt = 0 # running index for saved sample images (a counter, not a hyperparameter; the training loop increments it)
lr = 1e-3 # learning rate handed to the Adam optimizer

# Dataset

In [None]:
# Load MNIST through the TensorFlow tutorial helper. one_hot=True is requested so the
# labels can be concatenated onto the inputs as the conditioning vector.
# NOTE(review): this loader comes from the TF 1.x tutorials package; confirm the
# pinned TensorFlow version still ships it.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) # relative path, resolved against the working directory
X_dim = mnist.train.images.shape[1] # per-example feature width (784; samples are reshaped to 28x28 when plotted)
y_dim = mnist.train.labels.shape[1] # label width (10 one-hot classes)

# Conditional VAE
![CVAE](complements/CVAE.png)

# Architecture

In [None]:
# Weight initializer.
def xavier_init(size):
    """Create a trainable weight of shape ``size`` filled with scaled Gaussian noise.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, so it depends only on
    the first dimension (the fan-in) of the requested shape.

    Args:
        size: sequence of ints giving the weight shape; ``size[0]`` is the fan-in.

    Returns:
        A ``Variable`` with ``requires_grad=True`` holding the sampled weights.
    """
    fan_in = size[0]
    stddev = 1. / np.sqrt(fan_in / 2.)
    noise = torch.randn(*size)
    return Variable(noise * stddev, requires_grad=True)

# Encoder parameters
# =============================== Q(z|X) ======================================
# First layer: maps the concatenated (X, y) input to the hidden representation h.
Wxh = xavier_init(size=[X_dim + y_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

# Two linear heads on top of h. The "var" head is consumed as a log-variance
# downstream: sample_z exponentiates half of it and the KL term uses exp(z_var).
Whz_mu = xavier_init(size=[h_dim, Z_dim]) 
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)
Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)


def Q(X, c):
    """Encoder Q(z|X, c): map a batch and its condition to latent Gaussian parameters.

    Args:
        X: batch of inputs, one example per row.
        c: batch of conditioning vectors with the same number of rows as ``X``.

    Returns:
        Tuple ``(z_mu, z_var)``: the two linear heads evaluated on the hidden
        layer. The second element is treated as a log-variance by ``sample_z``
        and by the KL term in the training loop.
    """
    xc = torch.cat([X, c], 1)  # condition by concatenating along the feature axis
    n_rows = xc.size(0)

    # Each bias is tiled to one row per example before being added.
    hidden = nn.relu(torch.matmul(xc, Wxh) + bxh.repeat(n_rows, 1))
    z_mu = torch.matmul(hidden, Whz_mu) + bhz_mu.repeat(n_rows, 1)
    z_var = torch.matmul(hidden, Whz_var) + bhz_var.repeat(n_rows, 1)
    return z_mu, z_var


def sample_z(mu, log_var):
    """Draw ``z = mu + exp(log_var / 2) * eps`` with ``eps`` standard normal.

    This is the reparameterization step: the randomness lives in ``eps`` so
    gradients can flow through ``mu`` and ``log_var``.

    The noise is shaped after ``mu`` instead of the module-level
    ``mb_size``/``Z_dim``, so the function works for any batch size or latent
    width (for example a smaller evaluation batch).

    Args:
        mu: mean of the latent Gaussian.
        log_var: log-variance of the latent Gaussian, same shape as ``mu``.

    Returns:
        A sample with the same shape as ``mu``.
    """
    eps = Variable(torch.randn(mu.size()))
    return mu + torch.exp(log_var / 2) * eps

# Decoder parameters
# =============================== P(X|z) ======================================
# First layer: maps the concatenated (z, y) input to the hidden representation h.
Wzh = xavier_init(size=[Z_dim + y_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

# Output layer: maps h back to input space (X_hat); P applies the sigmoid on top.
Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def P(z, c):
    """Decoder P(X|z, c): reconstruct inputs from latent codes and their condition.

    Args:
        z: batch of latent codes, one per row.
        c: batch of conditioning vectors with the same number of rows as ``z``.

    Returns:
        Sigmoid activations of the output layer, one reconstructed row per example.
    """
    zc = torch.cat([z, c], 1)  # condition by concatenating along the feature axis
    n_rows = zc.size(0)

    # Each bias is tiled to one row per example before being added.
    hidden = nn.relu(torch.matmul(zc, Wzh) + bzh.repeat(n_rows, 1))
    logits = torch.matmul(hidden, Whx) + bhx.repeat(n_rows, 1)
    return nn.sigmoid(logits)

# Total Loss
![total_loss](complements/total_loss.JPG)
# Reconstruction Loss
![reconstruction_loss](complements/reconstruction_loss.JPG)
# KLD Loss
![KLD2](complements/KLD_analytic2.JPG)
![KLD](complements/KLD_analytic.JPG)


# Training

In [2]:
# =============================== TRAINING ====================================
# Every trainable tensor of the encoder and decoder, optimized jointly.
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)

# Create the image directory once, up front. exist_ok avoids the
# check-then-create sequence and makes re-running the cell safe.
os.makedirs('out/', exist_ok=True)

for it in range(100000):
    # Get a mini-batch of images and one-hot labels (numpy arrays).
    X, c = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))
    c = Variable(torch.from_numpy(c.astype('float32')))

    # Forward: encode, sample a latent code, decode under the same condition.
    z_mu, z_var = Q(X, c)
    z = sample_z(z_mu, z_var)
    X_sample = P(z, c)

    # Loss: summed reconstruction cross-entropy averaged over the batch, plus
    # the KL term computed from the encoder outputs (z_var is a log-variance).
    recon_loss = nn.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    # Backward
    loss.backward()

    # Update
    solver.step()

    # Reset gradients so they do not accumulate into the next step.
    solver.zero_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        # .sum() collapses the loss to a scalar whether it is stored as a
        # 1-element tensor or a 0-dim tensor, so this works across PyTorch
        # versions (indexing a 0-dim tensor with [0] raises IndexError).
        print('Iter-{}; Loss: {:.4}'.format(it, float(loss.data.sum())))

        # Pick one random class and condition the whole sample batch on it.
        # Separate names keep the training batch variables c and z untouched.
        c_sample = np.zeros(shape=[mb_size, y_dim], dtype='float32')
        c_sample[:, np.random.randint(0, y_dim)] = 1.
        c_sample = Variable(torch.from_numpy(c_sample))
        z_sample = Variable(torch.randn(mb_size, Z_dim))
        samples = P(z_sample, c_sample).data.numpy()[:16]

        # Lay the first 16 samples out on a tight 4x4 grid.
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)  # release the figure; one is created per logging step

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz
Iter-0; Loss: 740.0
Iter-1000; Loss: 146.3
Iter-2000; Loss: 119.1
Iter-3000; Loss: 113.8
Iter-4000; Loss: 113.9
Iter-5000; Loss: 119.1
Iter-6000; Loss: 108.0
Iter-7000; Loss: 107.8
Iter-8000; Loss: 108.9
Iter-9000; Loss: 108.4
Iter-10000; Loss: 115.0
Iter-11000; Loss: 109.0


KeyboardInterrupt: 