# Module Import

In [1]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data

  return f(*args, **kwds)


# Hyperparameters

In [2]:
mb_size = 64   # mini-batch size
Z_dim = 100    # latent space dimension
h_dim = 128    # hidden layer width (used by both encoder and decoder)
cnt = 0        # counter that numbers the saved sample images (out/000.png, out/001.png, ...)
lr = 1e-3      # Adam learning rate

# Seed every RNG the notebook draws from so that Restart & Run All reproduces the
# weight initialisation, the latent noise and the randomly chosen digits (and,
# presumably, the TF reader's batch shuffling, which is numpy-based -- verify).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# Dataset

In [3]:
# Load MNIST through TensorFlow's tutorial reader; images come back flattened and
# labels one-hot encoded (one_hot=True).
# NOTE(review): tensorflow.examples.tutorials is not shipped with TensorFlow 2.x,
# so this cell needs a TF 1.x install.
mnist = input_data.read_data_sets(
    '../../MNIST_data',
    one_hot=True,
)
X_dim = mnist.train.images.shape[1]  # pixels per flattened image (784 = 28 * 28)
y_dim = mnist.train.labels.shape[1]  # width of the one-hot label vector (10 digits)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


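# Conditional VAE
![CVAE](complements/CVAE.png)

A conditional VAE feeds the one-hot label $c$ to both the encoder $q(z \mid x, c)$ and the decoder $p(x \mid z, c)$. After training, a chosen digit can therefore be generated by fixing $c$ and sampling $z \sim \mathcal{N}(0, I)$.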

# Architecture

In [4]:
# Make Encoder
# =============================== Q(z|X,c) ====================================
# (X, y) -> hidden -> 2 * Z_dim outputs. Q() splits them into the mean and the
# log-variance of the approximate posterior.
encoder = nn.Sequential(
    nn.Linear(X_dim + y_dim, h_dim),
    nn.ReLU(inplace=True),
    nn.Linear(h_dim, 2 * Z_dim),
).cuda()  # Module.cuda() moves parameters in place and returns the module itself

def Q(X, c):
    """Encoder q(z | X, c): map an image batch and its labels to Gaussian parameters.

    Args:
        X: flattened image batch; concatenated with ``c`` along dim 1.
        c: label batch (one-hot in this notebook).

    Returns:
        (z_mu, z_logvar): the first ``Z_dim`` encoder outputs are the mean and the
        remaining ones are the log-variance (sample_z and the KL term apply
        ``exp`` to this value).
    """
    hidden = encoder(torch.cat([X, c], 1))  # joint (X, y) input
    z_logvar = hidden[:, Z_dim:]
    z_mu = hidden[:, :Z_dim]
    return z_mu, z_logvar


def sample_z(mu, log_var):
    """Reparameterisation trick: z = mu + exp(log_var / 2) * eps, eps ~ N(0, I).

    Args:
        mu: mean of q(z | X, c).
        log_var: log-variance of q(z | X, c), same shape as ``mu``.

    Returns:
        A sample with the same shape as ``mu``; gradients flow to ``mu`` and ``log_var``.
    """
    # Draw eps with mu's own shape, dtype and device. The previous version used the
    # global mb_size (breaking any batch of a different size) and always forced
    # the noise onto the GPU.
    eps = Variable(mu.data.new(mu.size()).normal_())
    return mu + torch.exp(log_var / 2) * eps

# Make Decoder
# =============================== P(X|z,c) ====================================
# (z, y) -> hidden -> X. The final Sigmoid keeps every output pixel in (0, 1),
# the range binary_cross_entropy expects for its input.
decoder = nn.Sequential(
    nn.Linear(Z_dim + y_dim, h_dim),
    nn.ReLU(inplace=True),
    nn.Linear(h_dim, X_dim),
    nn.Sigmoid(),
).cuda()  # Module.cuda() moves parameters in place and returns the module itself

def P(z, c):
    """Decoder p(X | z, c): reconstruct an image batch from latent codes and labels.

    Args:
        z: latent batch; concatenated with ``c`` along dim 1.
        c: label batch.

    Returns:
        The decoder output for the joint (z, y) input. Values lie in (0, 1)
        because the decoder ends in a Sigmoid.
    """
    return decoder(torch.cat((z, c), dim=1))



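# Total Loss
![total_loss](complements/total_loss.JPG)

The training objective is the sum of a reconstruction term and a KL-divergence term. Both are averaged over the mini-batch: $\mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{KL}}$.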
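# Reconstruction Loss
![reconstruction_loss](complements/reconstruction_loss.JPG)

This term is the binary cross-entropy between the decoder output $\hat{x}$ and the input $x$. It is summed over the $D$ pixels and averaged over the $B$ samples of the mini-batch:
$$\mathcal{L}_{\text{recon}} = -\frac{1}{B}\sum_{b=1}^{B}\sum_{d=1}^{D}\left[x_{bd}\log\hat{x}_{bd} + (1 - x_{bd})\log(1 - \hat{x}_{bd})\right]$$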
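# KL Divergence (KLD) Loss
![KLD2](complements/KLD_analytic2.JPG)
![KLD](complements/KLD_analytic.JPG)

With a diagonal Gaussian posterior $q(z \mid x, c) = \mathcal{N}(\mu, \sigma^2)$ and a standard-normal prior, the KL divergence has a closed form. The encoder outputs $\log\sigma^2$, so over the $J$ latent dimensions the code computes
$$\mathcal{L}_{\text{KL}} = \frac{1}{B}\sum_{b=1}^{B}\frac{1}{2}\sum_{j=1}^{J}\left(\sigma_{bj}^2 + \mu_{bj}^2 - 1 - \log\sigma_{bj}^2\right)$$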


# Training

In [None]:
# =============================== TRAINING ====================================
n_iters = 100000   # number of mini-batch updates
log_every = 1000   # print the loss and save a sample grid every this many iterations
n_show = 16        # samples per saved grid (drawn as a 4 x 4 figure)
out_dir = 'out/'


def _scalar(t):
    """Return a one-element loss tensor/Variable as a Python number.

    ``.item()`` exists from PyTorch 0.4 on. Older releases only support
    ``.data[0]``, which newer releases reject for 0-dim tensors.
    """
    return t.item() if hasattr(t, 'item') else t.data[0]


def _save_sample_grid(samples, path, title=None):
    """Save up to 16 flattened 28x28 images as a 4x4 greyscale grid at ``path``."""
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    if title is not None:
        fig.suptitle(title)
    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across the loop


solver = optim.Adam([{'params': encoder.parameters()},
                     {'params': decoder.parameters()}],
                    lr=lr)

if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for it in range(n_iters):
    # get a mini-batch
    X, c = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X)).cuda()
    c = Variable(torch.from_numpy(c.astype('float32'))).cuda()

    # Forward: encode, reparameterise, decode
    z_mu, z_logvar = Q(X, c)
    z = sample_z(z_mu, z_logvar)
    X_sample = P(z, c)

    # Loss. size_average=False sums the BCE over all pixels and samples, and
    # dividing by mb_size gives the per-sample reconstruction loss. Newer PyTorch
    # spells this reduction='sum' and warns about size_average.
    recon_loss = nn.functional.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with z_logvar = log(sigma^2),
    # summed over the latent dimensions and averaged over the batch.
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1))
    loss = recon_loss + kl_loss

    # Backward and update, then clear the accumulated gradients for the next iteration
    loss.backward()
    solver.step()
    solver.zero_grad()

    # Print and plot every now and then
    if it % log_every == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, _scalar(loss)))

        # Conditional generation: one random digit for the whole grid, z ~ N(0, I).
        # Distinct names keep the training batch's c and z untouched.
        digit = np.random.randint(0, 10)
        c_gen = np.zeros(shape=[n_show, y_dim], dtype='float32')
        c_gen[:, digit] = 1.
        c_gen = Variable(torch.from_numpy(c_gen)).cuda()
        z_gen = Variable(torch.randn(n_show, Z_dim)).cuda()
        samples = P(z_gen, c_gen).data.cpu().numpy()

        _save_sample_grid(samples,
                          os.path.join(out_dir, '{}.png'.format(str(cnt).zfill(3))),
                          title='digit {}'.format(digit))
        cnt += 1

Iter-0; Loss: 531.5
Iter-1000; Loss: 144.7
Iter-2000; Loss: 118.3
Iter-3000; Loss: 124.9
Iter-4000; Loss: 117.9
Iter-5000; Loss: 104.4
Iter-6000; Loss: 111.2
Iter-7000; Loss: 107.6
Iter-8000; Loss: 113.2
Iter-9000; Loss: 105.1
Iter-10000; Loss: 104.8
Iter-11000; Loss: 109.1
Iter-12000; Loss: 114.4
Iter-13000; Loss: 105.7
Iter-14000; Loss: 107.3
Iter-15000; Loss: 105.7
Iter-16000; Loss: 104.1
Iter-17000; Loss: 102.2
Iter-18000; Loss: 99.47
Iter-19000; Loss: 107.3
Iter-20000; Loss: 104.2
Iter-21000; Loss: 100.8
Iter-22000; Loss: 98.74
Iter-23000; Loss: 106.2
Iter-24000; Loss: 105.1
Iter-25000; Loss: 96.41
Iter-26000; Loss: 105.4
Iter-27000; Loss: 102.9
Iter-28000; Loss: 105.3
Iter-29000; Loss: 98.21
Iter-30000; Loss: 103.1
Iter-31000; Loss: 97.53
Iter-32000; Loss: 101.2
Iter-33000; Loss: 102.2
Iter-34000; Loss: 105.6
Iter-35000; Loss: 102.8
Iter-36000; Loss: 110.6
Iter-37000; Loss: 101.8
