# Module Import

In [1]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data

  return f(*args, **kwds)


# Hyperparameters

In [2]:
# Minibatch size, latent-code dimensionality and hidden-layer width.
mb_size, z_dim, h_dim = 32, 5, 128
# Running index used to name the saved sample grids (out/000.png, out/001.png, ...).
cnt = 0
# Learning rate shared by the Adam optimizers of Q, P and D.
lr = 1e-3

# Dataset

In [3]:
# Load MNIST with one-hot encoded labels.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
# Width of one input row (later reshaped to 28x28 for plotting) and width of
# the one-hot label vector.
X_dim, y_dim = mnist.train.images.shape[1], mnist.train.labels.shape[1]

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


# Architecture

In [None]:
def _mlp(in_dim, out_dim, out_act=None):
    """Build a Linear -> ReLU -> Linear network (h_dim hidden units) on the GPU.

    Args:
        in_dim: input feature width.
        out_dim: output feature width.
        out_act: optional activation module appended after the last Linear.
    """
    layers = [
        torch.nn.Linear(in_dim, h_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_dim, out_dim),
    ]
    if out_act is not None:
        layers.append(out_act)
    return torch.nn.Sequential(*layers).cuda()


# Encoder: input row -> latent code (no output activation).
Q = _mlp(X_dim, z_dim)
# Decoder: latent code -> reconstruction squashed into (0, 1).
P = _mlp(z_dim, X_dim, torch.nn.Sigmoid())
# Discriminator: latent code -> probability that it came from the prior.
D = _mlp(z_dim, 1, torch.nn.Sigmoid())


def reset_grad():
    """Clear accumulated gradients on the encoder, decoder and discriminator."""
    for net in (Q, P, D):
        net.zero_grad()

# Training

In [None]:

def sample_X(size, include_y=False):
    """Fetch the next training minibatch as CUDA Variables.

    Args:
        size: number of examples to draw via ``mnist.train.next_batch``.
        include_y: when True, also return the class labels.

    Returns:
        ``X`` alone, or ``(X, y)`` when ``include_y`` is set.
        ``y`` holds class indices: the argmax of the one-hot labels, as an
        int64 CUDA tensor.
    """
    X, y = mnist.train.next_batch(size)
    X = Variable(torch.from_numpy(X).cuda())

    if include_y:
        # np.int was removed in NumPy 1.24. int64 is also the dtype torch
        # expects for class-index targets (LongTensor) on every platform.
        y = np.argmax(y, axis=1).astype(np.int64)
        y = Variable(torch.from_numpy(y).cuda())
        return X, y
    return X


Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)


for it in range(1000000):
    X = sample_X(mb_size)

    # Reconstruction phase: update encoder Q and decoder P on the BCE
    # reconstruction loss.
    z_sample = Q(X)
    X_sample = P(z_sample)

    recon_loss = nn.binary_cross_entropy(X_sample, X)

    recon_loss.backward()
    P_solver.step()
    Q_solver.step()
    reset_grad()

    # Regularization phase.
    # Discriminator step: "real" codes are standard-normal samples, "fake"
    # codes come from the encoder.
    z_real = Variable(torch.randn(mb_size, z_dim).cuda())
    # Only D is stepped here, and any gradient reaching Q would be wiped by
    # reset_grad(). So cut the graph at Q's output.
    z_fake = Q(X).detach()

    D_real = D(z_real)
    D_fake = D(z_fake)

    # Same value as -mean(log D_real + log(1 - D_fake)). binary_cross_entropy
    # clamps its log terms, so a saturated D cannot produce inf/NaN.
    D_loss = (nn.binary_cross_entropy(D_real, torch.ones_like(D_real))
              + nn.binary_cross_entropy(D_fake, torch.zeros_like(D_fake)))

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Generator step: update encoder Q so its codes are classified as real.
    z_fake = Q(X)
    D_fake = D(z_fake)

    # Same value as -mean(log D_fake), with the same log clamping.
    G_loss = nn.binary_cross_entropy(D_fake, torch.ones_like(D_fake))

    G_loss.backward()
    Q_solver.step()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        # .item() reads the Python float from a 0-dim loss tensor.
        # Indexing it with .data[0] raises on PyTorch >= 0.4.
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
              .format(it, D_loss.item(), G_loss.item(), recon_loss.item()))

        # Decode this iteration's prior samples to visualise generation.
        samples = P(z_real).data.cpu().numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)

Iter-0; D_loss: 1.377; G_loss: 0.6489; recon_loss: 0.6978
Iter-1000; D_loss: 1.402; G_loss: 0.6571; recon_loss: 0.2582
Iter-2000; D_loss: 1.481; G_loss: 0.6294; recon_loss: 0.274
Iter-3000; D_loss: 1.59; G_loss: 0.7361; recon_loss: 0.2617
Iter-4000; D_loss: 1.187; G_loss: 0.7721; recon_loss: 0.2523
Iter-5000; D_loss: 1.149; G_loss: 0.873; recon_loss: 0.2692
Iter-6000; D_loss: 1.214; G_loss: 0.7575; recon_loss: 0.2401
Iter-7000; D_loss: 1.439; G_loss: 0.6357; recon_loss: 0.2758
Iter-8000; D_loss: 0.9492; G_loss: 1.102; recon_loss: 0.2601
Iter-9000; D_loss: 0.7126; G_loss: 1.554; recon_loss: 0.2479
Iter-10000; D_loss: 1.044; G_loss: 0.9682; recon_loss: 0.2577
Iter-11000; D_loss: 1.516; G_loss: 1.162; recon_loss: 0.2495
Iter-12000; D_loss: 0.6631; G_loss: 2.372; recon_loss: 0.2643
Iter-13000; D_loss: 0.2539; G_loss: 2.658; recon_loss: 0.2733
Iter-14000; D_loss: 0.547; G_loss: 2.133; recon_loss: 0.2539
Iter-15000; D_loss: 0.9916; G_loss: 1.482; recon_loss: 0.2661
Iter-16000; D_loss: 0.5827