In [None]:
import itertools
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from IPython import display
from torch import nn, optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
# Resolve paths relative to the directory the notebook kernel was started in.
project_root = os.path.realpath('.')
print(project_root)
os.chdir(project_root)

no_cuda = False
cuda_available = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
BATCH_SIZE = 64
EPOCH = 100
SEED = 8

# Apply the seed so weight init, noise sampling and shuffling are reproducible.
# torch.manual_seed seeds the CPU generator and all CUDA devices.
torch.manual_seed(SEED)
np.random.seed(SEED)

# The Generator ends in Tanh (outputs in [-1, 1]). Scale real pixels from
# ToTensor's [0, 1] into the same range, otherwise the discriminator could
# separate real from fake by the sign of the pixel values alone.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda_available else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Evaluation data does not need to be shuffled.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)

In [None]:
# Model params
img_side = 28                          # MNIST images are 28x28, single channel
g_input_size = 100                     # latent noise dimension fed to the generator, per sample
g_output_size = img_side * img_side    # flattened generated image (784 values)
d_input_size = img_side * img_side     # flattened image fed to the discriminator (784 values)
d_output_size = 1                      # single output: 'real' vs. 'fake'

d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)             # Adam (beta1, beta2)

print_interval = 200                   # NOTE(review): not referenced by any visible cell

d_steps = 1  # 'k' steps in the original GAN paper: D updates per batch; may exceed g_steps
g_steps = 1  # G updates per batch

In [None]:
class Generator(nn.Module):
    """MLP generator: latent noise vector -> flattened image with values in [-1, 1].

    Args:
        input_size: latent dimension. Defaults to the notebook-level
            ``g_input_size`` (looked up when the model is constructed).
        output_size: length of the generated vector. Defaults to the
            notebook-level ``g_output_size``.

    ``forward`` reshapes its input to ``(batch, input_size)`` and returns a
    tensor of shape ``(batch, output_size)``.
    """

    def __init__(self, input_size=None, output_size=None):
        super().__init__()
        # Resolve defaults at construction time so edits to the params cell
        # are picked up without redefining the class.
        self.input_size = g_input_size if input_size is None else input_size
        self.output_size = g_output_size if output_size is None else output_size
        # Layer indices fix the state_dict keys (model.0/2/4/6); keep them
        # stable so saved checkpoints still load.
        self.model = nn.Sequential(
            nn.Linear(self.input_size, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, self.output_size),
            nn.Tanh(),  # outputs in [-1, 1]
        )

        # Xavier-normal weights for every Linear layer; biases keep PyTorch's
        # default init. The model has no normalisation layers to initialise.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x):
        # Flatten any (batch, ...) noise tensor to (batch, input_size).
        x = x.view(x.size(0), self.input_size)
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator: flattened image -> Sigmoid probability that the input is real.

    Args:
        input_size: flattened input length. Defaults to the notebook-level
            ``d_input_size`` (looked up when the model is constructed).
        output_size: number of output units. Defaults to the notebook-level
            ``d_output_size``; the training loop's ``(batch, 1)`` BCE targets
            assume this is 1.

    ``forward`` flattens its input to ``(batch, input_size)`` and returns a
    tensor of shape ``(batch, output_size)``.
    """

    def __init__(self, input_size=None, output_size=None):
        super().__init__()
        self.input_size = d_input_size if input_size is None else input_size
        self.output_size = d_output_size if output_size is None else output_size
        # Layer indices fix the state_dict keys (model.0/3/6/9); keep them
        # stable so saved checkpoints still load.
        self.model = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),  # active only in train() mode

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, self.output_size),
            nn.Sigmoid(),  # probabilities in [0, 1], as nn.BCELoss expects
        )

        # Xavier-normal weights for every Linear layer; biases keep PyTorch's
        # default init. The model has no normalisation layers to initialise.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x):
        out = self.model(x.view(x.size(0), self.input_size))
        return out.view(out.size(0), -1)

In [None]:

# Instantiate both networks on the selected device. D is built first; with a
# fixed seed, construction order determines which random draws each network's
# initial weights receive.
D = Discriminator().to(device)
G = Generator().to(device)
print(D, G, sep='\n')

In [None]:
# D ends in Sigmoid, so plain BCE on probabilities is the matching loss.
criterion = nn.BCELoss()
# Alias retained for any later cell that reads `lr`; the optimizers below use
# the per-network rates declared in the "Model params" cell.
lr = d_learning_rate
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [None]:
# Per-epoch loss histories, refilled by train() with one Python float per step.
D_losses = []
G_losses = []

def train(epoch):
    """Run one epoch of alternating discriminator/generator updates over ``train_loader``.

    Uses the notebook-level ``D``, ``G``, ``criterion``, ``d_optimizer``,
    ``g_optimizer``, ``d_steps`` and ``g_steps``. ``D_losses`` / ``G_losses`` are
    cleared, then receive one float per discriminator / generator step.
    Prints the epoch's mean losses.

    Args:
        epoch: 1-based epoch index, used only in the progress message.
    """
    D.train()
    G.train()

    D_losses.clear()
    G_losses.clear()
    for real_data, _ in train_loader:
        real_data = real_data.to(device)
        batch_size = real_data.size(0)  # the last batch may be smaller than BATCH_SIZE
        real_labels = torch.ones(batch_size, 1, device=device)   # ones = real
        fake_labels = torch.zeros(batch_size, 1, device=device)  # zeros = fake

        # 1. Train D on real + fake ('k' = d_steps updates per batch).
        for _ in range(d_steps):
            # Zero every step: the G update below also leaves gradients on D,
            # and with d_steps > 1 they would otherwise accumulate across steps.
            D.zero_grad()

            # 1A: real images should be classified as real.
            d_real_error = criterion(D(real_data), real_labels)
            d_real_error.backward()  # store gradients; params change at step()

            # 1B: generated images should be classified as fake. detach() stops
            # this loss from backpropagating into G.
            d_noise = torch.randn(batch_size, g_input_size, device=device)
            d_fake_data = G(d_noise).detach()
            d_fake_error = criterion(D(d_fake_data), fake_labels)
            d_fake_error.backward()
            d_optimizer.step()  # only updates D's parameters

            D_losses.append((d_real_error + d_fake_error).item())

        # 2. Train G so that D labels its samples as real. Gradients also reach
        # D here, but only g_optimizer steps, and D is re-zeroed before its next update.
        for _ in range(g_steps):
            G.zero_grad()
            g_noise = torch.randn(batch_size, g_input_size, device=device)
            g_error = criterion(D(G(g_noise)), real_labels)
            g_error.backward()
            g_optimizer.step()  # only updates G's parameters

            G_losses.append(g_error.item())

    # Guard against an empty loader so np.mean never warns on an empty list.
    mean_d = float(np.mean(D_losses)) if D_losses else float('nan')
    mean_g = float(np.mean(G_losses)) if G_losses else float('nan')
    print('[%d/%d]: loss_D: %.3f, loss_G: %.3f' % (epoch, EPOCH, mean_d, mean_g))

In [None]:
def test(epoch):
    """Placeholder evaluation hook: puts G into inference mode and does nothing else.

    ``epoch`` is accepted for symmetry with ``train`` but is currently unused.
    TODO: add an evaluation pass (e.g. discriminator loss on ``test_loader``).
    """
    G.train(False)  # same as G.eval()
    

In [None]:
# Output directory for the per-epoch sample grids.
des_path = os.path.join(project_root, 'results/')
os.makedirs(des_path, exist_ok=True)

# Preview layout: the largest square grid that fits in one batch (8x8 for BATCH_SIZE=64).
size_figure_grid = int(math.sqrt(BATCH_SIZE))
n_preview = size_figure_grid * size_figure_grid
# squeeze=False keeps `ax` a 2-D array even for a 1x1 grid.
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6), squeeze=False)

# Fixed latent vectors, so successive previews show the same samples evolving.
fixed_noise = torch.randn(n_preview, g_input_size, device=device)

for epoch in range(1, EPOCH + 1):
    display.clear_output(wait=True)
    train(epoch)
    # test(epoch)

    G.eval()  # train() switches G back to training mode at the start of each epoch
    with torch.no_grad():
        # G ends in Tanh; map [-1, 1] -> [0, 1] so save_image does not clamp
        # every negative pixel to black.
        sample = (G(fixed_noise).cpu() + 1) / 2
    images = sample.view(n_preview, 1, 28, 28)
    save_image(images, os.path.join(des_path, f'epoch_{epoch}.png'), nrow=size_figure_grid)

    for k in range(n_preview):
        i, j = divmod(k, size_figure_grid)
        ax[i, j].cla()
        ax[i, j].imshow(images[k, 0].numpy(), cmap='Greys')
        # Hide ticks after drawing, in case cla() restored them.
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    display.display(fig)

# The live preview was already shown via display(); close the figure so the
# inline backend does not render it again when the cell finishes.
plt.close(fig)