<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Initialization" data-toc-modified-id="Initialization-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Initialization</a></span><ul class="toc-item"><li><span><a href="#MNIST-Dataset-Download-and-normalization" data-toc-modified-id="MNIST-Dataset-Download-and-normalization-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>MNIST Dataset Download and normalization</a></span></li></ul></li><li><span><a href="#Networks" data-toc-modified-id="Networks-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Networks</a></span><ul class="toc-item"><li><span><a href="#Discriminator" data-toc-modified-id="Discriminator-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Discriminator</a></span></li><li><span><a href="#Generator" data-toc-modified-id="Generator-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Generator</a></span></li></ul></li></ul></div>

## Initialization

In [1]:
import numpy as np
import os
import sys
import torch
from torch import nn, optim
from torch.autograd.variable import Variable as V
from torchvision import datasets, transforms, models, transforms, utils

In [2]:
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
print(sys.version) # Python version
# Only query GPU 0 when CUDA is usable: get_device_name() raises on a
# CPU-only machine, which would stop "Restart & Run All" at this cell.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    print(torch.cuda.device(0))
# Last expression of the cell -> shown as the cell's output.
torch.cuda.get_device_name(0) if cuda_ready else 'CUDA not available'

3.6.7 |Anaconda custom (64-bit)| (default, Oct 23 2018, 19:16:44) 
[GCC 7.3.0]
<torch.cuda.device object at 0x7f68ae3714e0>


'GeForce GTX 1060'

### MNIST Dataset Download and normalization

In [3]:
def get_mnist():
    """Return the MNIST training split with pixel values scaled to [-1, 1].

    The data is stored under ``./mnist_dataset`` and downloaded if it is not
    already there.

    Returns:
        torchvision.datasets.MNIST: the training split, with every image
        converted to a tensor and then normalized.
    """
    transform_tensor_normalize = transforms.Compose([
        # ToTensor has to come first: Normalize works on tensors, while the
        # dataset hands the transform a PIL image.
        transforms.ToTensor(),
        # MNIST is single-channel, so mean and std have one entry each.
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range onto [-1, 1].
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return datasets.MNIST(root='./mnist_dataset', train=True, transform=transform_tensor_normalize, download=True)

In [4]:
# Load the MNIST training split and serve it in shuffled mini-batches of 128
# (the final, smaller batch is kept).
mnist_dataset = get_mnist()
mnist_dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=False,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


## Networks

### Discriminator

In [None]:
class Discriminator(torch.nn.Module):
    """Fully connected binary classifier: is a sample real or generated?

    Three SELU hidden layers (1024 -> 512 -> 128 units, dropout 0.25 after the
    last two) feed a single sigmoid unit. The module owns its Adam optimizer so
    that ``train(real_data, false_data)`` performs one complete update step.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        n_features = 784  # Flattened 28*28 image
        n_out = 1  # Single score: real vs. generated

        # nn.SELU is the activation *module*; ``nn.selu`` does not exist.
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.SELU()
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.SELU(),
            nn.Dropout(0.25)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 128),
            nn.SELU(),
            nn.Dropout(0.25)
        )
        self.out = nn.Sequential(
            nn.Linear(128, n_out),
            nn.Sigmoid()
        )

        # Created last so that self.parameters() already covers every layer.
        self.optimizer = optim.Adam(self.parameters(), lr=0.00002)

    def forward(self, x):
        """Score a batch of samples.

        Args:
            x: tensor whose trailing dimension(s) hold 784 values per sample.
                Inputs with more than two dimensions (e.g. image batches) are
                flattened to ``(batch, -1)``; 1-D and 2-D inputs pass through
                unchanged.

        Returns:
            Tensor of sigmoid outputs in [0, 1], one value per sample.
        """
        if x.dim() > 2:
            x = x.reshape(x.size(0), -1)
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

    def train(self, real_data=True, false_data=None):
        """Run one discriminator update, or switch training mode.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls internally. To keep both usages working:

        * ``train(real_data, false_data)`` -- one optimizer step on a batch of
          real samples (target 1) and a batch of generated samples (target 0).
        * ``train()`` / ``train(True)`` / ``train(False)`` -- standard
          ``nn.Module`` mode switch; returns ``self``.

        Args:
            real_data: batch of real samples, or a bool when switching mode.
            false_data: batch of generated samples. Gradients are
                back-propagated through whatever graph this tensor is attached
                to, so callers normally pass a detached tensor -- TODO confirm
                against the training loop.

        Returns:
            ``(error_real + error_false, prediction_real, prediction_false)``
            for an update step; ``self`` for a mode switch.

        Raises:
            TypeError: if ``false_data`` is omitted and ``real_data`` is not a
                bool.
        """
        if false_data is None:
            if not isinstance(real_data, bool):
                raise TypeError(
                    "train() expects (real_data, false_data) for an update "
                    "step or a single bool to switch training mode"
                )
            return super(Discriminator, self).train(real_data)

        self.optimizer.zero_grad()  # Reset gradients accumulated so far
        loss = nn.BCELoss()  # Binary Cross-Entropy Loss

        # Real data: target is 1. ones_like keeps the labels on the same
        # device and dtype as the predictions.
        prediction_real = self(real_data)
        error_real = loss(prediction_real, torch.ones_like(prediction_real))
        error_real.backward()

        # Generated data: target is 0.
        prediction_false = self(false_data)
        error_false = loss(prediction_false, torch.zeros_like(prediction_false))
        error_false.backward()

        # Gradients of both passes have accumulated; apply them in one step.
        self.optimizer.step()

        return error_real + error_false, prediction_real, prediction_false  # Error, predictions

### Generator

In [None]:
class Generator(torch.nn.Module):