<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Initialization" data-toc-modified-id="Initialization-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Initialization</a></span><ul class="toc-item"><li><span><a href="#MNIST-Dataset-Download-and-normalization" data-toc-modified-id="MNIST-Dataset-Download-and-normalization-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>MNIST Dataset Download and normalization</a></span></li></ul></li><li><span><a href="#Networks" data-toc-modified-id="Networks-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Networks</a></span><ul class="toc-item"><li><span><a href="#Discriminator" data-toc-modified-id="Discriminator-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Discriminator</a></span></li><li><span><a href="#Generator" data-toc-modified-id="Generator-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Generator</a></span></li></ul></li><li><span><a href="#Training" data-toc-modified-id="Training-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Training</a></span></li></ul></div>

## Initialization

In [1]:
import numpy as np
import os
import sys
import torch
from torch import nn, optim
from torch.autograd.variable import Variable as V
from torchvision import datasets, transforms, models, transforms, utils
from matplotlib import pyplot as plt
from matplotlib import gridspec

In [2]:
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
print(sys.version) # Python version
print(torch.cuda.device(0))
torch.cuda.get_device_name(0)

3.6.7 |Anaconda custom (64-bit)| (default, Oct 23 2018, 19:16:44) 
[GCC 7.3.0]
<torch.cuda.device object at 0x7f09a21264a8>


'GeForce GTX 1060'

In [3]:
seed=25
torch.manual_seed(seed)
print("Random Seed: ", seed)

Random Seed:  25


### MNIST Dataset Download and normalization

In [4]:
def get_mnist():
    transform_normalize_tensor=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.50,0.50,0.50),(0.50,0.50,0.50))
        
    ])
    return datasets.MNIST(root='./mnist_dataset', train=True, transform=transform_normalize_tensor, download=True)

In [None]:
# Dataset Loader
mnist_dataset=get_mnist()
mnist_dataloader=torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=128, shuffle=True, drop_last=False)

## Networks

### Discriminator

In [None]:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        n_features= 784 # Flattened 28*28 image
        n_out=1 # Bool Real or Generated image
        
        self.hidden0=nn.Sequential(
            nn.Linear(n_features,1024),
            nn.SELU()
        )
        self.hidden1=nn.Sequential(
            nn.Linear(1024,512),
            nn.SELU(),
            nn.Dropout(0.25)
        )
        self.hidden2=nn.Sequential(
            nn.Linear(512,128),
            nn.SELU(),
            nn.Dropout(0.25)
        )
        self.out=nn.Sequential(
            nn.Linear(128,n_out),
            nn.Sigmoid()
        )
        
        self.optimizer=optim.Adam(self.parameters(), lr=0.00002)
    
    def forward(self, x):
        x=self.hidden0(x)
        x=self.hidden1(x)
        x=self.hidden2(x)
        x=self.out(x)
        return x
    
    def train(self, real_data, false_data):
        self.optimizer.zero_grad() # Zero/ Reset all gradients
        loss=nn.BCELoss() # Binary Cross-Entropy Loss
        
        # Real data training
        prediction_real=self(real_data) # Prediction
        real_ones_data=V(torch.ones(real_data.size(0),1)).cuda() # Label ones
        error_real=loss(prediction_real,real_ones_data)
        error_real.backward()
        
        # Fake data training
        prediction_false=self(false_data) # Prediction on generated data
        real_zeros_data=V(torch.zeros(false_data.size(0),1)).cuda() # Label zeros
        error_false=loss(prediction_false,real_zeros_data)
        error_false.backward() # Backprop error
        
        self.optimizer.step() # Update weights of Adam
        
        return error_real+error_false,prediction_real,prediction_false # Error, predictions

### Generator

In [None]:
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        n_features=128 # Latent variable vector size
        n_out=784 # 28*28 input image
                
        self.hidden0=nn.Sequential(
            nn.Linear(n_features,256),
            nn.SELU()
        )
        self.hidden1=nn.Sequential(
            nn.Linear(256,512),
            nn.SELU(),
            nn.Dropout(0.25)
        )
        self.hidden2=nn.Sequential(
            nn.Linear(512,1024),
            nn.SELU(),
            nn.Dropout(0.25)
        )
        self.out=nn.Sequential(
            nn.Linear(1024,n_out),
            nn.Tanh()
        )
        
        self.optimizer=optim.Adam(self.parameters(), lr=0.00002)
    
    def forward(self, x):
        x=self.hidden0(x)
        x=self.hidden1(x)
        x=self.hidden2(x)
        x=self.out(x)
        return x
    
    def train(self, discriminator, false_data):
        self.optimizer.zero_grad() # Zero/ Reset all gradients
        loss=nn.BCELoss() # Binary Cross-Entropy Loss
        
        prediction=discriminator(false_data)
        real_ones_data=V(torch.ones(false_data.size(0),1)).cuda()
        error=loss(prediction, real_ones_data)
        error.backward() # Backprop error
        
        self.optimizer.step() # Update weights of Adam
    
        return error       

## Training 

In [None]:
def sample_noise(sample_size, latent_vec_size):
    return V(torch.randn(sample_size, latent_vec_size)).cuda()

In [None]:
n_epochs=200
total_steps=len(mnist_dataloader)
generated_folder='generated_images/'

discriminator=Discriminator()
generator=Generator()

print("-----Discriminator-----")
print(discriminator,"\n")
print("-----Generator-----")
print(generator,"\n")

if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()

print("Training...")

for epoch in range(n_epochs):
    for index, (real_data_batch, _) in enumerate(mnist_dataloader):
        data_batch_size=real_data_batch.size(0)
        
        real_data_batch=V(real_data_batch.view(data_batch_size,784)) # Vectorize+ Variable
        real_data_batch=real_data_batch.cuda() # Cuda compatible
        
        false_data=generator(sample_noise(data_batch_size,128)).detach() # Generate false data...
        # but avoid discriminator gradients being updated by this
        
        dis_error, dis_prediction_real, dis_prediction_false=discriminator.train(real_data_batch, false_data) # Discriminator training
        
        false_data=generator(sample_noise(data_batch_size,128)) # Generate more false data
        gen_error=generator.train(discriminator, false_data) # Generator train
        
        if index%128==0: # Every minibatch
            print("Epoch {}, Batch {}".format(epoch,index))
            print("\tDiscriminator Error: {}, Generator Error: {}".format(dis_error,gen_error))
    if epoch%10==0: # Every 10 epochs
        generated_noise=sample_noise(20, 128)
        generator_vec=generator(generated_noise)
        sample_images=generator_vec.view(generator_vec.size(0),1,28,28).data
        
        fig=plt.figure(figsize=(8,10))
        gs=gridspec.GridSpec(4,5)
        gs.update(wspace=0.05, hspace=0.05)
        
        
        for i, sample_image in enumerate(sample_images):
            ax=plt.subplot(gs[i])
            plt.axis('off')
            plt.imshow(sample_image.cpu().reshape(28,28), cmap='Greys_r')
            
        if not os.path.exists(generated_folder):
            os.makedirs(generated_folder)
          
        figure_name="{}Epoch_{}.png".format(generated_folder, epoch)
        plt.savefig(figure_name, bbox_inches='tight')
        print("Saved {}".format(figure_name))
plt.close(fig)

-----Discriminator-----
Discriminator(
  (hidden0): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): SELU()
  )
  (hidden1): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.25)
  )
  (hidden2): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): SELU()
    (2): Dropout(p=0.25)
  )
  (out): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
    (1): Sigmoid()
  )
) 

-----Generator-----
Generator(
  (hidden0): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): SELU()
  )
  (hidden1): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.25)
  )
  (hidden2): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): SELU()
    (2): Dropout(p=0.25)
  )
  (out): Sequential(
    (0): Linear(in_features=1024, out_features=784, bias

Epoch 18, Batch 128
	Discriminator Error: 0.1556888073682785, Generator Error: 6.477461814880371
Epoch 18, Batch 256
	Discriminator Error: 0.019376033917069435, Generator Error: 7.064450263977051
Epoch 18, Batch 384
	Discriminator Error: 0.01120186410844326, Generator Error: 6.952986240386963
Epoch 19, Batch 0
	Discriminator Error: 0.06769803166389465, Generator Error: 7.483989715576172
Epoch 19, Batch 128
	Discriminator Error: 0.02980879694223404, Generator Error: 8.880510330200195
Epoch 19, Batch 256
	Discriminator Error: 0.08242636173963547, Generator Error: 7.205350399017334
Epoch 19, Batch 384
	Discriminator Error: 0.050204865634441376, Generator Error: 7.526010990142822
Epoch 20, Batch 0
	Discriminator Error: 0.02674749679863453, Generator Error: 8.084227561950684
Epoch 20, Batch 128
	Discriminator Error: 0.04013760760426521, Generator Error: 8.266325950622559
Epoch 20, Batch 256
	Discriminator Error: 0.03606273978948593, Generator Error: 8.937455177307129
Epoch 20, Batch 384
	Di

Epoch 39, Batch 128
	Discriminator Error: 0.2432897835969925, Generator Error: 6.058947563171387
Epoch 39, Batch 256
	Discriminator Error: 0.07984325289726257, Generator Error: 5.395177841186523
Epoch 39, Batch 384
	Discriminator Error: 0.10333366692066193, Generator Error: 5.8905439376831055
Epoch 40, Batch 0
	Discriminator Error: 0.1780472993850708, Generator Error: 5.634584426879883
Epoch 40, Batch 128
	Discriminator Error: 0.09358067810535431, Generator Error: 5.019466400146484
Epoch 40, Batch 256
	Discriminator Error: 0.25910937786102295, Generator Error: 6.410423755645752
Epoch 40, Batch 384
	Discriminator Error: 0.15671634674072266, Generator Error: 5.729916572570801
Saved generated_images/Epoch_40.png
Epoch 41, Batch 0
	Discriminator Error: 0.13614889979362488, Generator Error: 5.095475196838379
Epoch 41, Batch 128
	Discriminator Error: 0.16038109362125397, Generator Error: 6.217350959777832
Epoch 41, Batch 256
	Discriminator Error: 0.05166436731815338, Generator Error: 7.51047

Epoch 60, Batch 128
	Discriminator Error: 0.08694284409284592, Generator Error: 7.600410461425781
Epoch 60, Batch 256
	Discriminator Error: 0.04600934311747551, Generator Error: 6.8558220863342285
Epoch 60, Batch 384
	Discriminator Error: 0.20489326119422913, Generator Error: 6.839653968811035
Saved generated_images/Epoch_60.png
Epoch 61, Batch 0
	Discriminator Error: 0.14296604692935944, Generator Error: 6.928905487060547
Epoch 61, Batch 128
	Discriminator Error: 0.02690882235765457, Generator Error: 8.138692855834961
Epoch 61, Batch 256
	Discriminator Error: 0.07676415145397186, Generator Error: 6.840035438537598
Epoch 61, Batch 384
	Discriminator Error: 0.07021038979291916, Generator Error: 6.54682731628418
Epoch 62, Batch 0
	Discriminator Error: 0.021960876882076263, Generator Error: 9.14721393585205
Epoch 62, Batch 128
	Discriminator Error: 0.06083548069000244, Generator Error: 7.178598880767822
Epoch 62, Batch 256
	Discriminator Error: 0.08570502698421478, Generator Error: 6.7072