In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import importlib
import argparse

In [2]:
import torch
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torchnet.meter import AverageValueMeter

In [3]:
from tensorflow.examples.tutorials.mnist import input_data

In [4]:
%cd ../datasets
import mnist; importlib.reload(mnist)

/Users/hoangnguyen/Documents/Github/generative-model/datasets


<module 'mnist' from '/Users/hoangnguyen/Documents/Github/generative-model/datasets/mnist.py'>

# Load data

In [5]:
parser = {
    'batch_size': 64,
    'no_cuda': True,
    'epochs': 20
}
args = argparse.Namespace(**parser)

In [6]:
mb_size = args.batch_size
epochs = args.epochs
Z_dim = 100    #Choose 
X_dim = 784
y_dim = 10
h_dim = 128    #Choose
cnt = 0
lr = 1e-3
noise_factor = .25

In [7]:
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [8]:
train_loader = torch.utils.data.DataLoader(mnist.MNIST('../data', train=True,
                transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(mnist.MNIST('../data', train=False, 
                transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True)

In [9]:
#test DataLoader
for x, y in train_loader:
    print(x)
    break


(0 ,0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
     ⋮ 

(1 ,0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
     ⋮ 

(2 ,0 ,.,.) = 
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
 

In [10]:
def xavier_init(size):
    in_dim = size[0]
    out_dim = size[1]
    xavier_stddev = 1. / np.sqrt((in_dim + out_dim) / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

Instead of using function ```xavier_init```, we can use built-in [functions](http://pytorch.org/docs/nn.html#torch-nn-init) in pytorch
- ```torch.nn.init.xavier_uniform(tensor, gain=1)```
- ```torch.nn.init.xavier_normal(tensor, gain=1)```

We will use ```torch.nn.init.xavier_normal(tensor, gain=1)``` when pytorch release the module

# Generator

z --> G --> X

In [11]:
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

In [12]:
Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [13]:
def G(z):
    h = F.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = F.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

# Discriminator

x --> D --> y

In [14]:
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

In [15]:
Why = xavier_init(size=[h_dim, 1])
bhy = Variable(torch.zeros(1), requires_grad=True)

In [16]:
def D(X):
    h = F.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    y = F.sigmoid(h @ Why + bhy.repeat(h.size(0), 1))
    return y

# Training

In [17]:
def onehot(y):
    y_array = np.zeros(shape=[y.size(0), 10], dtype='float32')
    for index, ele in enumerate(y):
        y_array[index][ele] = 1
    return torch.from_numpy(y_array)

```python
#Initialize zero buffers
   for p in params:
       p.grad.data.zero_()
```

In [18]:
# Use ```reset_grad``` to initialize zero buffers for params.
def reset_grad():
    for p in params:
        p.grad.data.zero_()

In [19]:
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params

In [20]:
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

In [21]:
def train(epoch):
    # Sample data
    #X.size(0) = batch_size
    D_losses = AverageValueMeter()
    G_losses = AverageValueMeter()
    for X, _ in train_loader:
        # Create ones_label and zeros_label
        ones_label = Variable(torch.ones(X.size(0)))
        zeros_label = Variable(torch.zeros(X.size(0)))
        
        # Input: z - latent variables, x - input
        z = Variable(torch.randn(X.size(0), Z_dim))
        X = Variable(X.view(-1, 784))

        # Dicriminator forward-loss-backward-update
        G_sample = G(z) # X_fake: generate from Generator
        D_real = D(X)
        D_fake = D(G_sample)
        
        # Calculate loss
        D_loss_real = F.binary_cross_entropy(D_real, ones_label) # compare D_real with 1
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label) # compare D_fake with 0
        D_loss = D_loss_real + D_loss_fake

        # Housekeeping - reset gradient
        reset_grad()
        
        # Tinh dao ham cua D_loss vs cac Variable require_grad = true
        D_loss.backward()
        
        # update params
        D_solver.step()

        #---------------------------------------------------#
        
        # Generator forward-loss-backward-update
        z = Variable(torch.randn(X.size(0), Z_dim))
        G_sample = G(z)
        D_fake = D(G_sample)

        G_loss = F.binary_cross_entropy(D_fake, ones_label) # Compare D_fake with 1

        # Housekeeping - reset gradient
        reset_grad()
        
        # Back-ward
        G_loss.backward()
        
        # Update
        G_solver.step()
        
        #D_losses.add(D_loss.data[0], X.size(0))
        #G_losses.add(G_loss.data[0], X.size(0))
        
        # Test A. Du's loss
        D_losses.add(D_loss.data[0]*X.size(0), X.size(0))
        G_losses.add(G_loss.data[0]*X.size(0), X.size(0))

    print('Epoch-{}; D_loss: {}; G_loss: {}'.format(epoch, D_losses.value()[0], G_losses.value()[0]))

In [22]:
def plot(samples, epoch):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    if not os.path.exists('../gan/out/'):
        os.makedirs('../gan/out/')
    
    fileName = '../gan/out/' + str(epoch).zfill(3)
    
    plt.savefig(fileName, bbox_inches='tight')
    #plt.show()
    plt.close(fig)

In [23]:
for epoch in range(1, args.epochs+1):
    train(epoch)
    z = Variable(torch.randn(mb_size, Z_dim))
    samples = G(z).data.numpy()[:16]    #only plot 16 pictures in each epoch
    plot(samples, epoch)

Epoch-1; D_loss: 0.05942605794072151; G_loss: 5.846674840037028
Epoch-2; D_loss: 0.04273620020250479; G_loss: 5.993272939300537
Epoch-3; D_loss: 0.1093657682498296; G_loss: 4.990968609619141
Epoch-4; D_loss: 0.1447583189924558; G_loss: 5.314013683573405
Epoch-5; D_loss: 0.19785221904317538; G_loss: 4.950421976979573
Epoch-6; D_loss: 0.3372966682434082; G_loss: 4.064016493860881
Epoch-7; D_loss: 0.40511250575383506; G_loss: 3.8313508449554443
Epoch-8; D_loss: 0.4513482684135437; G_loss: 3.6127767100016275
Epoch-9; D_loss: 0.5479822627385458; G_loss: 3.344304447555542
Epoch-10; D_loss: 0.6177612986405691; G_loss: 3.2032399148305255
Epoch-11; D_loss: 0.650471140384674; G_loss: 2.9522206104278563
Epoch-12; D_loss: 0.6978705864270528; G_loss: 2.733676774851481
Epoch-13; D_loss: 0.7035380304336548; G_loss: 2.5021265343983967
Epoch-14; D_loss: 0.7354028270721436; G_loss: 2.3772855337778727
Epoch-15; D_loss: 0.7628123646100362; G_loss: 2.273324891535441
Epoch-16; D_loss: 0.7700538780530294; G_