In [1]:
import os
import time

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchnet.meter import AverageValueMeter

In [4]:
import mnist

In [5]:
# Root directory for intermediate artifacts produced by this notebook.
intermediate_path = os.path.join("..", "intermediate", "gan")
# exist_ok=True creates the directory only when it is missing and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(intermediate_path, exist_ok=True)

In [6]:
# Training schedule: mini-batch size and number of passes over the data.
mb_size, epochs = 64, 100
# Layer widths: latent noise vector, flattened 28x28 image, hidden layer.
Z_dim, X_dim, h_dim = 100, 784, 128
# Not referenced by the cells visible here; presumably the number of digit
# classes -- TODO confirm before relying on it.
y_dim = 10
# Learning rate; matches the value given to the Adam optimizers.
lr = 1e-3

In [7]:
def xavier_init(size):
    """Create a trainable weight matrix using Xavier (Glorot) normal scaling.

    Args:
        size: Two-element sequence ``[in_dim, out_dim]`` giving the shape.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True`` whose
        entries are standard-normal draws multiplied by
        ``1 / sqrt((in_dim + out_dim) / 2)``.
    """
    fan_in, fan_out = size[0], size[1]
    scale = 1. / np.sqrt((fan_in + fan_out) / 2.)
    weights = torch.randn(*size) * scale
    return Variable(weights, requires_grad=True)

In [8]:
# Generator weights: latent noise -> hidden layer, hidden layer -> flat image.
Wzh = xavier_init((Z_dim, h_dim))
Whx = xavier_init((h_dim, X_dim))

In [9]:
# Generator biases (hidden layer, output layer): zero-initialised, trainable.
bzh, bhx = (Variable(torch.zeros(width), requires_grad=True)
            for width in (h_dim, X_dim))

In [10]:
def G(z):
    """Generator: map a batch of latent vectors to a batch of flat images.

    Args:
        z: 2-D batch of noise vectors, one row per sample; the column count
            must match the row count of ``Wzh``.

    Returns:
        One row per input row with ``Whx.size(1)`` columns; every value has
        passed through a sigmoid and therefore lies in (0, 1).
    """
    n_rows = z.size(0)
    hidden = F.relu(z @ Wzh + bzh.repeat(n_rows, 1))
    # `hidden` has the same number of rows as `z`, so the bias is tiled
    # with the same count.
    return F.sigmoid(hidden @ Whx + bhx.repeat(n_rows, 1))

In [11]:
# Discriminator weights: flat image -> hidden layer -> single output unit.
Wxh = xavier_init((X_dim, h_dim))
Why = xavier_init((h_dim, 1))

# Discriminator biases: zero-initialised, trainable.
bxh = Variable(torch.zeros(h_dim), requires_grad=True)
bhy = Variable(torch.zeros(1), requires_grad=True)


def D(X):
    """Discriminator: score each row of ``X`` with a value in (0, 1).

    Args:
        X: 2-D batch of flat images, one row per sample; the column count
            must match the row count of ``Wxh``.

    Returns:
        Sigmoid outputs of shape ``(X.size(0), 1)``.
    """
    n_rows = X.size(0)
    hidden = F.relu(X @ Wxh + bxh.repeat(n_rows, 1))
    # `hidden` keeps one row per input row, so the bias is tiled with the
    # same count.
    return F.sigmoid(hidden @ Why + bhy.repeat(n_rows, 1))

In [12]:
# Trainable parameters grouped per network, plus the combined list that
# reset_grad() iterates over.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = [*G_params, *D_params]

In [13]:
def reset_grad():
    """Zero the accumulated gradient of every entry in ``params``.

    A parameter whose ``grad`` is still ``None`` (no backward pass has
    reached it yet) is skipped instead of raising ``AttributeError``.
    """
    for p in params:
        if p.grad is not None:
            p.grad.data.zero_()

In [14]:
# One Adam optimizer per network, so each step() only updates that network's
# parameters. Both read the shared `lr` hyperparameter instead of repeating
# the literal, so changing `lr` in the config cell actually takes effect.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

In [15]:
# Dataset built by the project-local `mnist` module with train=True
# (presumably the MNIST training split -- confirm against that module);
# ToTensor() is applied to every image.
mnist_train = mnist.MNIST('../data', train=True, download=True,
                          transform=transforms.ToTensor())
# Shuffled mini-batches of `mb_size` samples.
train_loader = DataLoader(mnist_train, batch_size=mb_size, shuffle=True)

In [16]:
def train(epoch):
    """Run one epoch of alternating discriminator / generator updates.

    For every mini-batch from ``train_loader`` the discriminator is updated
    once on real and generated samples, then the generator is updated once on
    freshly sampled noise. Batch-size-weighted average losses and the elapsed
    time are printed when the epoch finishes.

    Args:
        epoch: Epoch number; used only in the printed summary line.
    """
    D_losses = AverageValueMeter()
    G_losses = AverageValueMeter()
    start = time.time()

    for X, _ in train_loader:
        batch_size = X.size(0)
        # D() returns shape (batch, 1). The BCE targets must have the same
        # shape: a (batch,) target only triggers a warning on old PyTorch
        # releases but is rejected with a ValueError on current ones.
        ones_label = Variable(torch.ones(batch_size, 1))
        zeros_label = Variable(torch.zeros(batch_size, 1))
        # Flatten each image to a row of X_dim values.
        X = Variable(X.view(-1, X_dim))

        # Discriminator forward-loss-backward-update
        z = Variable(torch.randn(batch_size, Z_dim))
        G_sample = G(z)
        D_real = D(X)
        D_fake = D(G_sample)

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake

        # backward() also fills the generator's gradients (D_fake depends on
        # G), but D_solver only steps the discriminator parameters and
        # reset_grad() then clears everything.
        D_loss.backward()
        D_solver.step()
        reset_grad()

        # Generator forward-loss-backward-update
        z = Variable(torch.randn(batch_size, Z_dim))
        G_sample = G(z)
        D_fake = D(G_sample)

        # The generator is trained to make D label its samples as real.
        G_loss = F.binary_cross_entropy(D_fake, ones_label)

        G_loss.backward()
        G_solver.step()
        reset_grad()

        # view(-1)[0] reads the scalar whether the loss is a 1-element tensor
        # (old PyTorch) or a 0-dim tensor (newer releases, where [0] fails).
        D_loss_value = float(D_loss.data.cpu().view(-1)[0])
        G_loss_value = float(G_loss.data.cpu().view(-1)[0])
        # Weight each batch mean by its size so a smaller final batch does
        # not skew the epoch average.
        D_losses.add(D_loss_value * batch_size, batch_size)
        G_losses.add(G_loss_value * batch_size, batch_size)

    print("   * EPOCH {} | Time: {}s | D_loss: {:.4f} | G_loss: {:.4f}"
          .format(epoch, round(time.time()-start),
                  D_losses.value()[0],
                  G_losses.value()[0]))

In [None]:
def plot(samples, epoch):
    """Save a 4x4 image grid as ``<intermediate_path>/out/<epoch>.png``.

    Args:
        samples: Iterable of arrays, each reshapeable to ``(28, 28)``. The
            grid has 16 cells, so at most 16 samples fit.
        epoch: Epoch number; zero-padded to three digits for the file name.
    """
    fig = plt.figure(figsize=(4, 4))
    try:
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            # Work on explicit Figure/Axes objects rather than pyplot's
            # implicit "current" figure and axes.
            ax = fig.add_subplot(gs[i])
            ax.axis("off")
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect("equal")
            ax.imshow(sample.reshape(28, 28), cmap="Greys_r")

        out_path = os.path.join(intermediate_path, "out")
        # Creates the directory only when missing, without a separate
        # exists() check that could race with another process.
        os.makedirs(out_path, exist_ok=True)

        out_filepath = os.path.join(out_path,
                                    "{}.png".format(str(epoch).zfill(3)))
        fig.savefig(out_filepath, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails; this function
        # is called once per epoch, so leaked figures would accumulate.
        plt.close(fig)

In [None]:
# Train for `epochs` epochs (numbered from 1); after each one, save a grid
# of generator samples.
for epoch in range(1, epochs + 1):
    train(epoch)
    # Sample a full mini-batch of noise, then keep the first 16 generated
    # images to fill the 4x4 grid drawn by plot().
    z = torch.randn(mb_size, Z_dim)
    generated = G(Variable(z))
    samples = generated.data.numpy()[:16]
    plot(samples, epoch)

   * EPOCH 1 | Time: 25s | D_loss: 0.0885 | G_loss: 5.3372
   * EPOCH 2 | Time: 26s | D_loss: 0.0639 | G_loss: 5.1707
   * EPOCH 3 | Time: 25s | D_loss: 0.1065 | G_loss: 4.5663
   * EPOCH 4 | Time: 25s | D_loss: 0.2214 | G_loss: 4.6510
   * EPOCH 5 | Time: 26s | D_loss: 0.3088 | G_loss: 4.1979
   * EPOCH 6 | Time: 25s | D_loss: 0.4808 | G_loss: 3.7872
   * EPOCH 7 | Time: 25s | D_loss: 0.5365 | G_loss: 3.1798
   * EPOCH 8 | Time: 26s | D_loss: 0.5744 | G_loss: 3.1267
   * EPOCH 9 | Time: 25s | D_loss: 0.6274 | G_loss: 2.9110
   * EPOCH 10 | Time: 25s | D_loss: 0.6818 | G_loss: 2.7613
   * EPOCH 11 | Time: 25s | D_loss: 0.6946 | G_loss: 2.6511
   * EPOCH 12 | Time: 26s | D_loss: 0.7201 | G_loss: 2.7746
   * EPOCH 13 | Time: 25s | D_loss: 0.7558 | G_loss: 2.6204
   * EPOCH 14 | Time: 25s | D_loss: 0.7741 | G_loss: 2.3711
   * EPOCH 15 | Time: 25s | D_loss: 0.7586 | G_loss: 2.2646
   * EPOCH 16 | Time: 26s | D_loss: 0.7686 | G_loss: 2.1683
   * EPOCH 17 | Time: 26s | D_loss: 0.7824 | G_lo