In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

from torch import nn
from torch import optim
from torchvision.datasets import MNIST
from torchvision import transforms

%matplotlib inline

In [None]:
# select the compute device: the first CUDA GPU when one is visible to
# PyTorch, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Generator and Discriminator Architecture

In [None]:
# define constants and hyper-parameters
# for Discriminator
IMG_SIZE = 784  # length of one flattened 28x28 image
DL1 = 500       # width of the hidden layer
DLO = 1         # single output unit

# Discriminator: flattened image -> value in (0, 1) produced by the final Sigmoid
d_model = nn.Sequential(
    nn.Linear(IMG_SIZE, DL1),
    nn.ReLU(),
    nn.Linear(DL1, DLO),
    nn.Sigmoid()
)

# .to(device) works for both CPU and CUDA devices, whereas .cuda(device=...)
# fails when `device` is the CPU fallback
d_model = d_model.to(device)

# define constants and hyper-parameters
# for Generator
NOISE_SIZE = 64  # length of the latent noise vector
GL1 = 500        # width of the first hidden layer
GL2 = 500        # width of the second hidden layer
GLO = 784        # output length, one value per pixel of a flattened image

# Generator: noise vector -> flattened image with values in [-1, 1] (Tanh)
g_model = nn.Sequential(
    nn.Linear(NOISE_SIZE, GL1),
    nn.LeakyReLU(0.2),
    nn.Linear(GL1, GL2),
    nn.LeakyReLU(0.2),
    nn.Linear(GL2, GLO),
    nn.Tanh()
)

g_model = g_model.to(device)

## Loss Functions and Optimizers

In [None]:
# loss functions and optimizers
# both networks use the same learning rate
d_lr = g_lr = 0.0002

# binary cross-entropy is the criterion for the discriminator and the generator
d_loss_fn, g_loss_fn = nn.BCELoss(), nn.BCELoss()

# discriminator is updated with momentum SGD, generator with Adam
d_opt = optim.SGD(params=d_model.parameters(), lr=d_lr, momentum=0.9)
g_opt = optim.Adam(params=g_model.parameters(), lr=g_lr)

## MNIST Dataloaders 

In [None]:
# more hyper-parameters
# number of images per mini-batch; passed to the training DataLoader as batch_size
BATCH_SIZE = 10

In [None]:
# train data iterator
# Pixels are scaled to [-1, 1] via Normalize((0.5,), (0.5,)) so real images
# share the value range of the generator's Tanh output (and of `denorm`).
# The dataset-statistics normalisation (0.1307, 0.3081) would put real images
# in roughly [-0.42, 2.82], a range the generator can never produce.
train_loader = torch.utils.data.DataLoader(
    MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),                 # uint8 [0, 255] -> float [0, 1]
            transforms.Normalize((0.5,), (0.5,))   # [0, 1] -> [-1, 1]
        ])
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    # keep every batch exactly BATCH_SIZE long; the training loop reshapes
    # with that fixed size
    drop_last=True
)

## Utility Functions

In [None]:
# utility functions
def extract(v):
    """Return the contents of the storage backing tensor ``v`` as a flat Python list."""
    backing_storage = v.data.storage()
    return backing_storage.tolist()

def denorm(x):
    """Map values from [-1, 1] to [0, 1], clipping anything that falls outside."""
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)
    
def sample_images():
    """Plot a 4x5 grid of images produced by the current generator.

    One batch of 20 latent vectors is drawn with ``sample_z`` (the same noise
    distribution the training loop uses), passed through ``g_model`` once,
    rescaled to [0, 1] once and shown as 28x28 images. The figure is displayed
    for half a second and then closed.
    """
    n = 20

    # generate the whole grid in a single forward pass; no gradients are needed
    # for visualisation, so skip building the autograd graph
    with torch.no_grad():
        fake_images = g_model(sample_z(n, NOISE_SIZE, device))

    # rescale exactly once, then move to host memory for matplotlib
    fake_images = denorm(fake_images).reshape(n, 28, 28).to('cpu').numpy()

    for i in range(n):
        plt.subplot(4, 5, i+1)
        plt.imshow(fake_images[i])

    plt.pause(0.5)
    plt.close()
    
def sample_z(batch_size, noise_size, device):
    """Draw a ``[batch_size, noise_size]`` tensor of Normal(0, 2) noise on ``device``."""
    noise_dist = torch.distributions.Normal(0, 2)
    z = noise_dist.sample([batch_size, noise_size])
    return torch.autograd.Variable(z).to(device)

## Training Parameters

In [None]:
# training parameters
epochs = 300            # number of passes over the training loader
d_steps = g_steps = 1

# the training loop prints losses every `print_iter` batches and plots
# generator samples every `sample_iter` batches
print_iter, sample_iter = 200, 500

# running values, all starting at zero; dlr, dlf and gl are overwritten with
# the latest loss values inside the training loop
dlr = dlf = gl = dx = dgz = 0

## Training Loop

In [None]:
# training loop
# Every iteration performs one discriminator update (a real batch plus a
# generated batch) followed by one generator update.
for epoch in range(epochs):
    for i, (real_images, real_labels) in enumerate(train_loader):
        # take the batch length from the data so a short final batch still works
        batch_len = real_images.size(0)

        # ---- discriminator update ----
        # clear gradients
        d_model.zero_grad()
        g_model.zero_grad()

        # train on real data
        real_images = real_images.reshape(batch_len, -1).to(device) # X

        d_x = d_model(real_images) # D(X)
        d_loss_real = d_loss_fn(d_x, torch.ones_like(d_x)) # -log(D(X))
        d_loss_real.backward()

        # train on fake data; detach so this backward pass stops at the
        # discriminator instead of also running through the generator
        fake_images = g_model(sample_z(batch_len, NOISE_SIZE, device)).detach() # G(Z)

        d_g_z = d_model(fake_images) # D(G(Z))
        d_loss_fake = d_loss_fn(d_g_z, torch.zeros_like(d_g_z)) # -log(1 - D(G(Z)))
        d_loss_fake.backward()

        d_opt.step()

        dlr, dlf = d_loss_real.item(), d_loss_fake.item()

        # ---- generator update ----
        # clear gradients
        d_model.zero_grad()
        g_model.zero_grad()

        # train on fake data
        fake_images = g_model(sample_z(batch_len, NOISE_SIZE, device)) # G(Z)

        d_g_z = d_model(fake_images) # D(G(Z))
        # targets are all ones: the generator minimises -log(D(G(Z)))
        g_loss = g_loss_fn(d_g_z, torch.ones_like(d_g_z))
        g_loss.backward()
        g_opt.step()

        gl = g_loss.item()

        if i % print_iter == 0:
            # .item() prints plain floats rather than tensor reprs
            print("Epoch %s/%s: DReal: %s, DFake: %s, G: %s;    Dx: %s, Dg: %s" % (epoch, i, dlr, dlf, gl, d_x.mean().item(), d_g_z.mean().item()))

        if i % sample_iter == 0:
            sample_images()
        