In [28]:
import os
import argparse
import importlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torchnet.meter import AverageValueMeter
import numpy as np

In [30]:
import mnist; importlib.reload(mnist)

<module 'mnist' from '/home/hminle/Github/weekly-ml-projects/gan/mnist.py'>

In [31]:
# Run configuration. Kept as a plain mapping so it can be expanded into an
# argparse-style namespace without reading the command line.
parser = dict(
    batch_size=64,     # samples per mini-batch handed to the data loaders
    epochs=10,         # number of passes of the training loop
    no_cuda=False,     # True forbids CUDA even when it is available
    seed=1,            # intended RNG seed
    log_interval=10,   # not read by any visible cell
)

In [32]:
# Expose the configuration dict as attribute-style arguments (args.batch_size, ...).
args = argparse.Namespace(**parser)
# CUDA is enabled only when it is both allowed by the config and available.
# NOTE(review): no visible cell moves tensors or parameters to the GPU, so this
# flag currently has no effect on where the computation runs.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Apply the configured seed. The original notebook declared `seed` but never
# used it, so weight initialisation, batch shuffling and the sampled noise
# differed on every run.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

## Load Data

In [33]:
# Training split, served in shuffled mini-batches of args.batch_size images.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist.MNIST('../data', train=True, transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
)

# Held-out split, built the same way. It is not consumed by any visible cell.
test_loader = torch.utils.data.DataLoader(
    dataset=mnist.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
)

## Define Parameters

In [34]:
# Mini-batch size, mirrored from the run configuration.
mb_size = args.batch_size

# Layer widths: noise vector -> hidden layer -> flattened image.
Z_dim, h_dim, X_dim = 100, 128, 784

# Declared but not referenced by any visible cell.
y_dim, c = 10, 0

# Step size; the same value the Adam optimizers are given.
lr = 1e-3

## Generator

In [35]:
def xavier_init(size):
    """Create a trainable weight matrix with Xavier-style Gaussian initialisation.

    Args:
        size: Sequence whose first two entries are the input and output
            dimensions; the whole sequence is used as the tensor shape.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True``, drawn
        from a normal distribution with standard deviation
        ``1 / sqrt((in_dim + out_dim) / 2)``.
    """
    fan_in, fan_out = size[0], size[1]
    scale = 1. / np.sqrt((fan_in + fan_out) / 2.)
    weights = torch.randn(*size) * scale
    return Variable(weights, requires_grad=True)

In [36]:
# First generator layer: noise (Z_dim) -> hidden (h_dim). Weights are
# Xavier-initialised, biases start at zero; both are trainable.
Wzh, bzh = (
    xavier_init(size=[Z_dim, h_dim]),
    Variable(torch.zeros(h_dim), requires_grad=True),
)

# Second generator layer: hidden (h_dim) -> flattened image (X_dim).
Whx, bhx = (
    xavier_init(size=[h_dim, X_dim]),
    Variable(torch.zeros(X_dim), requires_grad=True),
)

In [37]:
def G(z):
    """Generator: map a batch of noise vectors to a batch of flattened images.

    Args:
        z: Noise batch of shape (batch, Z_dim).

    Returns:
        Batch of shape (batch, X_dim), squashed element-wise by a sigmoid.
    """
    # The 1-D biases broadcast over the batch dimension, so the explicit
    # per-row copy made by repeat() is unnecessary.
    h = F.relu(z @ Wzh + bzh)
    # torch.sigmoid replaces the deprecated F.sigmoid alias.
    X = torch.sigmoid(h @ Whx + bhx)
    return X

## Discriminator

In [38]:
# First discriminator layer: flattened image (X_dim) -> hidden (h_dim).
# Weights are Xavier-initialised, biases start at zero; both are trainable.
Wxh, bxh = (
    xavier_init(size=[X_dim, h_dim]),
    Variable(torch.zeros(h_dim), requires_grad=True),
)

# Second discriminator layer: hidden (h_dim) -> a single output unit.
Why, bhy = (
    xavier_init(size=[h_dim, 1]),
    Variable(torch.zeros(1), requires_grad=True),
)


def D(X):
    """Discriminator: score a batch of flattened images.

    Args:
        X: Image batch of shape (batch, X_dim).

    Returns:
        Batch of shape (batch, 1) holding one sigmoid output per image.
    """
    # The 1-D biases broadcast over the batch dimension, so the explicit
    # per-row copy made by repeat() is unnecessary.
    h = F.relu(X @ Wxh + bxh)
    # torch.sigmoid replaces the deprecated F.sigmoid alias.
    y = torch.sigmoid(h @ Why + bhy)
    return y

## Model Params

In [39]:
# Trainable tensors grouped per network, so each optimizer can be given only
# its own network's parameters.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
# Generator parameters first, then discriminator parameters.
params = [*G_params, *D_params]

## Training

In [40]:
def reset_grad():
    """Zero the accumulated gradient of every tensor in ``params``.

    A parameter that has not yet taken part in a backward pass has
    ``grad`` set to ``None``; such parameters are skipped. Without this
    guard the first call on a fresh kernel fails with AttributeError,
    because ``train`` resets gradients before its first ``backward()``.
    """
    for p in params:
        if p.grad is not None:
            p.grad.data.zero_()


# One Adam optimizer per network so each update step only touches that
# network's parameters. Both use the shared `lr` constant (1e-3) instead of
# repeating the literal, so the learning rate is tuned in one place.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)


In [45]:
# NOTE(review): `visdom` is not imported in any visible cell, so the bare
# `visdom.Visdom()` call raised NameError on a fresh kernel. `vis` is also not
# used by any visible cell. Import locally and fall back to None when the
# package is missing so that Restart & Run All still gets past this cell.
try:
    import visdom
except ImportError:
    vis = None
else:
    vis = visdom.Visdom()

In [59]:
def train(epoch):
    """Run one epoch of alternating discriminator / generator updates.

    For every mini-batch from ``train_loader`` the discriminator ``D`` is
    updated once on real versus generated images, then the generator ``G`` is
    updated once so that ``D`` scores its samples as real. The
    batch-size-weighted mean of both losses is printed when the epoch ends.

    Args:
        epoch: Epoch index; only used in the printed summary line.
    """
    D_losses = AverageValueMeter()
    G_losses = AverageValueMeter()
    for X, _ in train_loader:
        batch = X.size(0)  # actual size of this mini-batch

        # Targets are shaped (batch, 1) to match what D returns. A flat
        # (batch,) target is rejected by binary_cross_entropy in current
        # PyTorch releases.
        ones_label = Variable(torch.ones(batch, 1))
        zeros_label = Variable(torch.zeros(batch, 1))

        z = Variable(torch.randn(batch, Z_dim))
        # Flatten each image to a row of X_dim values.
        X = Variable(X.view(-1, X_dim))

        # Discriminator forward-loss-backward-update
        G_sample = G(z)  # fake images
        D_real = D(X)
        # detach(): this step only updates D, so there is no need to
        # backpropagate through the generator.
        D_fake = D(G_sample.detach())

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)   # real images vs. 1
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)  # fake images vs. 0
        D_loss = D_loss_real + D_loss_fake

        # Housekeeping - reset gradient
        reset_grad()
        # Compute the gradients of D_loss w.r.t. the tensors with requires_grad=True
        D_loss.backward()
        # Update the discriminator parameters
        D_solver.step()

        # Generator forward-loss-backward-update
        z = Variable(torch.randn(batch, Z_dim))
        G_sample = G(z)
        D_fake = D(G_sample)

        G_loss = F.binary_cross_entropy(D_fake, ones_label)  # fake images vs. 1

        # Housekeeping - reset gradient (also clears what the D step left behind)
        reset_grad()
        G_loss.backward()
        G_solver.step()

        # float(loss.data) works for both one-element and 0-dim losses, whereas
        # loss.data[0] raises on 0-dim tensors. The meter receives the summed
        # loss together with the sample count so its mean is weighted by batch size.
        D_losses.add(float(D_loss.data) * batch, batch)
        G_losses.add(float(G_loss.data) * batch, batch)
    print('Epoch-{}; D_loss: {}; G_loss: {}'
        .format(epoch, D_losses.value()[0], G_losses.value()[0]))

In [60]:
def plot(samples, epoch, out_dir='out_GAN/'):
    """Save a 4x4 grid of generated images as ``<out_dir>/<epoch>.png``.

    The figure is written to disk and then closed; it is not displayed inline.

    Args:
        samples: Iterable of arrays, each reshapeable to 28x28. The grid has
            16 cells, so pass at most 16 samples.
        epoch: Tag used for the file name, zero-padded to three characters.
        out_dir: Directory for the image; created if it does not exist.
            Defaults to the directory the notebook always used.
    """
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Explicit figure/axes calls instead of pyplot's "current axes" state.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, '{}.png'.format(str(epoch).zfill(3)))
    fig.savefig(target, bbox_inches='tight')
    # Close explicitly so figures do not pile up across epochs.
    plt.close(fig)

In [58]:
# Reference picture drawn before the training loop below runs. It is saved
# under the tag 999, which lies outside the epoch range used by the loop.
z = Variable(torch.randn(mb_size, Z_dim))
samples = G(z).data.numpy()[:16]
plot(samples, 999)

# Train for the configured number of epochs. After each epoch, sample fresh
# noise and save the first 16 generated images.
for epoch in range(args.epochs):
    train(epoch)
    z = Variable(torch.randn(mb_size, Z_dim))
    samples = G(z).data.numpy()[:16]
    plot(samples, epoch)

Epoch-0; D_loss: 0.7881988826433818; G_loss: 1.9799039277394612
Epoch-1; D_loss: 0.7781565422693888; G_loss: 1.9704048128763834
Epoch-2; D_loss: 0.7646393604596455; G_loss: 2.027835792605082
Epoch-3; D_loss: 0.7684109745661417; G_loss: 2.058323233795166
Epoch-4; D_loss: 0.7580886743863424; G_loss: 2.045337483215332
Epoch-5; D_loss: 0.7541149898846944; G_loss: 2.080440088526408
Epoch-6; D_loss: 0.7459175700505575; G_loss: 2.108483600997925
Epoch-7; D_loss: 0.7382791342099507; G_loss: 2.143351516977946
Epoch-8; D_loss: 0.7307715403556824; G_loss: 2.171455207443237
Epoch-9; D_loss: 0.7228711834271749; G_loss: 2.1784596199035646
