# Adversarial (Variational) AutoEncoder

In [6]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data

## Hyperparameters

In [7]:
# Network sizes
z_dim = 5       # width of the latent code (encoder output, decoder/discriminator input)
h_dim = 128     # width of the single hidden layer in every network

# Optimisation
mb_size = 32    # examples per mini-batch, also the number of prior samples drawn
lr = 0.001      # learning rate handed to each Adam optimizer

# Bookkeeping
cnt = 0         # index used to name the next saved sample image

## Dataset

In [8]:
# Read MNIST from the local raw-data directory, requesting one-hot labels.
mnist = input_data.read_data_sets('../dataset/raw', one_hot=True)

# Second-axis sizes of the training arrays: per-example input width and
# per-example label width (the input width must be 784 for the 28x28 reshape
# used when plotting samples).
X_dim, y_dim = mnist.train.images.shape[1], mnist.train.labels.shape[1]

Extracting ../dataset/raw/train-images-idx3-ubyte.gz
Extracting ../dataset/raw/train-labels-idx1-ubyte.gz
Extracting ../dataset/raw/t10k-images-idx3-ubyte.gz
Extracting ../dataset/raw/t10k-labels-idx1-ubyte.gz


## Architecture

In [9]:
def _mlp(n_in, n_out, squash):
    """Build a one-hidden-layer network and move it to the GPU.

    The stack is Linear(n_in, h_dim) -> ReLU -> Linear(h_dim, n_out), followed
    by a Sigmoid when ``squash`` is true.
    """
    stack = [
        torch.nn.Linear(n_in, h_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_dim, n_out),
    ]
    if squash:
        stack.append(torch.nn.Sigmoid())
    return torch.nn.Sequential(*stack).cuda()


# Encoder: input -> latent code, no output activation.
Q = _mlp(X_dim, z_dim, squash=False)

# Decoder: latent code -> input-sized output squashed into (0, 1).
P = _mlp(z_dim, X_dim, squash=True)

# Discriminator: latent code -> single sigmoid score.
D = _mlp(z_dim, 1, squash=True)

def reset_grad():
    """Zero the accumulated gradients of the encoder, decoder and discriminator."""
    for net in (Q, P, D):
        net.zero_grad()

## Training

In [None]:
def sample_X(size, include_y=False):
    """Fetch the next MNIST training mini-batch and move it to the GPU.

    Args:
        size: Number of examples to request from ``mnist.train.next_batch``.
        include_y: When true, also return the labels, converted from one-hot
            rows to integer class indices via ``argmax``.

    Returns:
        ``X`` alone, or the pair ``(X, y)`` when ``include_y`` is true. Each
        is a CUDA tensor wrapped in ``Variable``.
    """
    X, y = mnist.train.next_batch(size)
    X = Variable(torch.from_numpy(X).cuda())

    if include_y:
        # The ``np.int`` alias was removed in NumPy 1.24. An explicit 64-bit
        # integer type keeps this working and yields 64-bit labels regardless
        # of the platform's default integer width.
        y = np.argmax(y, axis=1).astype(np.int64)
        y = Variable(torch.from_numpy(y).cuda())
        return X, y
    return X

# One Adam optimizer per network (encoder, decoder, discriminator), all using
# the same learning rate.
Q_solver, P_solver, D_solver = (
    optim.Adam(net.parameters(), lr=lr) for net in (Q, P, D)
)

# Main training loop. Each iteration performs three separate updates on the
# same mini-batch: reconstruction (P and Q), discriminator (D), generator (Q).
for it in range(1000000):
    X = sample_X(mb_size)

    # ---- Reconstruction phase ----
    # Encode and decode the batch, then update decoder and encoder on the
    # binary cross-entropy between the reconstruction and the input.
    z_sample = Q(X)
    X_sample = P(z_sample)

    recon_loss = nn.binary_cross_entropy(X_sample, X)

    recon_loss.backward()
    P_solver.step()
    Q_solver.step()
    reset_grad()

    # ---- Regularization phase ----
    # Discriminator: standard-normal draws are the "real" codes, encoder
    # outputs are the "fake" ones.
    z_real = Variable(torch.randn(mb_size, z_dim).cuda())
    z_fake = Q(X)

    D_real = D(z_real)
    D_fake = D(z_fake)

    D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

    # backward() also fills Q's gradients here, but only D is stepped and
    # reset_grad() discards everything afterwards.
    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Generator: re-encode with the freshly updated discriminator in place and
    # update only the encoder so that D scores its codes as real.
    z_fake = Q(X)
    D_fake = D(z_fake)

    G_loss = -torch.mean(torch.log(D_fake))

    G_loss.backward()
    Q_solver.step()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        # .item() extracts the Python scalar; indexing a 0-dim loss tensor
        # with [0] raises IndexError on current PyTorch releases.
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
              .format(it, D_loss.item(), G_loss.item(), recon_loss.item()))

        # Decode this iteration's prior samples and keep the first 16 for a
        # 4x4 grid.
        samples = P(z_real).data.cpu().numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # Create the output directory on demand; exist_ok avoids the
        # check-then-create race of a separate os.path.exists() test.
        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        # Close the figure so figures do not accumulate over the run.
        plt.close(fig)

Iter-0; D_loss: 1.418; G_loss: 0.6553; recon_loss: 0.6951
Iter-1000; D_loss: 1.432; G_loss: 0.6595; recon_loss: 0.2643
Iter-2000; D_loss: 1.492; G_loss: 0.617; recon_loss: 0.2402
Iter-3000; D_loss: 1.406; G_loss: 0.7088; recon_loss: 0.2351
Iter-4000; D_loss: 1.391; G_loss: 0.7011; recon_loss: 0.2006
Iter-5000; D_loss: 1.395; G_loss: 0.6994; recon_loss: 0.1826
Iter-6000; D_loss: 1.39; G_loss: 0.6938; recon_loss: 0.1626
Iter-7000; D_loss: 1.39; G_loss: 0.6988; recon_loss: 0.1623
Iter-8000; D_loss: 1.373; G_loss: 0.7082; recon_loss: 0.1502
Iter-9000; D_loss: 1.403; G_loss: 0.6876; recon_loss: 0.1793
Iter-10000; D_loss: 1.386; G_loss: 0.6903; recon_loss: 0.1541
Iter-11000; D_loss: 1.397; G_loss: 0.6744; recon_loss: 0.1693
Iter-12000; D_loss: 1.394; G_loss: 0.6821; recon_loss: 0.1707
Iter-13000; D_loss: 1.39; G_loss: 0.7085; recon_loss: 0.17
Iter-14000; D_loss: 1.377; G_loss: 0.7116; recon_loss: 0.155
Iter-15000; D_loss: 1.384; G_loss: 0.6989; recon_loss: 0.1768
Iter-16000; D_loss: 1.391; G