# Vanilla GAN on MNIST

In [1]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

## Data preparation

In [2]:
# Download the MNIST training split and rescale pixels from [0, 1] to [-1, 1].
# MNIST images have a single channel, so Normalize must receive exactly one
# mean and one std; three-element tuples do not broadcast against a
# (1, 28, 28) tensor and raise a shape error on recent torchvision releases.

compose = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((.5,), (.5,))])
data = datasets.MNIST(root = './dataset', train = True, 
                      transform = compose, download = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset\MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset\MNIST\raw\train-images-idx3-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset\MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset\MNIST\raw\train-labels-idx1-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset\MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



Extracting ./dataset\MNIST\raw\t10k-images-idx3-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./dataset\MNIST\raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [5]:
# Wrap the dataset in an iterator that yields shuffled mini-batches of 100.

data_loader = torch.utils.data.DataLoader(
    dataset = data,
    batch_size = 100,
    shuffle = True,
)
# Number of batches per epoch (600 when the dataset holds 60,000 samples).
num_batches = len(data_loader)

## Design Network

### Discriminator

In [11]:
class DiscriminatorNet(nn.Module):
    """Fully connected discriminator: scores each sample with a value in (0, 1).

    Architecture: n_features -> 1024 -> 512 -> 256 -> n_out. Every hidden
    layer is followed by LeakyReLU(0.2) and Dropout(0.3); the output layer
    ends in a sigmoid so the result can be fed directly to a BCE loss.

    Args:
        n_features: length of one flattened input sample (default 28 * 28).
        n_out: number of output units per sample (default 1).
    """
    def __init__(self, n_features = 28 * 28, n_out = 1):
        super(DiscriminatorNet, self).__init__()
        # Remembered so forward() can recognise inputs that still need flattening.
        self.n_features = n_features
        
        self.layer1 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            torch.nn.Sigmoid()
        )
    
    def forward(self, x):
        """Return the per-sample score for a batch.

        Args:
            x: either already-flattened samples whose last dimension equals
               n_features, or image-shaped batches (e.g. (N, 1, 28, 28)),
               which are flattened to (N, -1) before the first layer.

        Returns:
            Tensor of sigmoid outputs with n_out values per sample.
        """
        # Only reshape inputs that could not have been processed as-is, so
        # callers that already pass flattened vectors see no change.
        if x.dim() > 2 and x.size(-1) != self.n_features:
            x = x.reshape(x.size(0), -1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.out(x)
        return x

### Generator

In [12]:
class GeneratorNet(nn.Module):
    """Fully connected generator mapping a 100-d input vector to 784 outputs.

    Architecture: 100 -> 256 -> 512 -> 1024 -> 784. Hidden layers use
    LeakyReLU(0.2); the output layer uses Tanh, so values lie in [-1, 1].
    """
    def __init__(self):
        super(GeneratorNet, self).__init__()
        latent_dim, image_dim = 100, 784
        widths = (latent_dim, 256, 512, 1024)

        # Register the hidden stages as layer1, layer2, layer3 (in that order)
        # so parameter names and initialisation order stay stable.
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start = 1):
            stage = nn.Sequential(
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
            )
            setattr(self, 'layer%d' % idx, stage)

        # Tanh squashes the output into the (-1, 1) range.
        self.out = nn.Sequential(
            nn.Linear(widths[-1], image_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run x through the three hidden stages and the output stage."""
        for stage in (self.layer1, self.layer2, self.layer3, self.out):
            x = stage(x)
        return x

In [13]:
# Create one discriminator and one generator with freshly initialised weights.
# These module-level names are what the training functions below refer to.
discriminator = DiscriminatorNet()
generator = GeneratorNet()

## Training

### Optimizer

In [14]:
# One Adam optimizer per network (learning rate 0.0002), so the discriminator
# and the generator can be stepped independently of each other.
d_optimizer = optim.Adam(params = discriminator.parameters(), lr = 0.0002)
g_optimizer = optim.Adam(params = generator.parameters(), lr = 0.0002)

### Train function for discriminator

#### Loss function
Use Binary Cross Entropy loss

$$L = -\left[\, y \log(x) + (1 - y) \log(1 - x) \,\right]$$

To turn this into the discriminator loss, we only have to supply the right labels: 1 for real samples and 0 for generated ones.

In [16]:
loss = nn.BCELoss() # batch mean of -[y * log(x) + (1 - y) * log(1 - x)]

In [17]:
#label real and fake batch with ones and zeros
def ones_target(size, device = None):
    """Return a (size, 1) float tensor of ones, used as the "real" label.

    Args:
        size: number of samples in the batch.
        device: optional device for the tensor; None keeps the default.

    Returns:
        Tensor of shape (size, 1) filled with 1.0 that does not require grad.
    """
    # Variable is a deprecated no-op wrapper; a plain tensor behaves the same.
    return torch.ones(size, 1, device = device)

def zeros_target(size, device = None):
    """Return a (size, 1) float tensor of zeros, used as the "fake" label.

    Args:
        size: number of samples in the batch.
        device: optional device for the tensor; None keeps the default.

    Returns:
        Tensor of shape (size, 1) filled with 0.0 that does not require grad.
    """
    # Variable is a deprecated no-op wrapper; a plain tensor behaves the same.
    return torch.zeros(size, 1, device = device)

In [19]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a generated batch.

    Uses the module-level `discriminator`, `loss`, `ones_target` and
    `zeros_target`.

    Args:
        optimizer: optimizer holding the discriminator's parameters.
        real_data: batch of real samples, labelled 1.
        fake_data: batch of generated samples, labelled 0.

    Returns:
        (total_error, prediction_real, prediction_fake), where total_error is
        the sum of the real-batch and fake-batch losses.
    """
    # The discriminator step must not push gradients into the generator;
    # detaching is a no-op when the caller already passed a detached tensor.
    fake_data = fake_data.detach()

    optimizer.zero_grad()
    
    #train on real data
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, ones_target(real_data.size(0)))
    error_real.backward()
    
    #train on fake data
    # Labels are sized from the fake batch itself so the two batches are not
    # required to have the same length.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(fake_data.size(0)))
    error_fake.backward() #gradient accumulated on top of the real-batch gradient
    
    # Single step over the summed gradients of both batches.
    optimizer.step()
    
    return error_real + error_fake, prediction_real, prediction_fake

### Train function for Generator

#### Loss function
As mentioned in the paper, rather than minimizing log(1 - D(G(z))), we maximize log D(G(z)) to avoid the gradient saturation problem.


In [20]:
def train_generator(optimizer, fake_data):
    """Run one generator update using the non-saturating loss.

    Uses the module-level `discriminator`, `loss` and `ones_target`.

    Args:
        optimizer: optimizer holding the generator's parameters.
        fake_data: generator output for this step. It is expected to still be
            attached to the autograd graph; a detached tensor would leave the
            generator without gradients.

    Returns:
        The generator loss for this batch.
    """
    N = fake_data.size(0)

    # Clear gradients left over from the previous step; without this every
    # update would be applied to the running sum of all earlier gradients.
    optimizer.zero_grad()

    prediction = discriminator(fake_data)
    # Label the fakes as real: minimising this BCE maximises log D(G(z)).
    error = loss(prediction, ones_target(N))
    error.backward()
    optimizer.step()
    
    return error

### Additional functions

In [None]:
## Add noise to inputs, decay over time
def noise(size):
    n = Variable(torch.randn(size, 100))