In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


## Loading the data
I load the MNIST dataset and normalize the pixel values to lie between -1 and 1, because I am going to use Tanh as the generator's output activation and its range is also [-1, 1]. Then I split the dataset into train and test sets.

In [2]:

# Preprocessing pipeline: turn each image into a tensor, then normalize the
# single grayscale channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split of MNIST, downloaded into ./data if it is not already there.
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Mini-batches of 100 images, reshuffled at the start of every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_train,
    batch_size=100,
    shuffle=True,
)


## Generator
I reimplement in PyTorch the architecture described in the paper (“We trained a multilayer perceptron with 100 input units, two hidden layers of 1200 ReLU units each, and an output layer with 784 sigmoid units.”). The forward pass takes a noise vector as its input. I use Tanh instead of Sigmoid on the output layer for stability (inspired by [this blog post on Towards Data Science](https://towardsdatascience.com/batch-norm-explained-visually-how-it-works-and-why-neural-networks-need-it-b18919692739/)).

In [3]:

class Generator(nn.Module):
    """Fully connected generator that maps a noise vector to a flat image.

    Layout: ``noise_dim -> hidden_dim -> hidden_dim -> output_dim`` with ReLU
    after both hidden layers and Tanh on the output layer, so every generated
    value lies in [-1, 1].

    Args:
        noise_dim: Size of the input noise vector.
        hidden_dim: Width of each of the two hidden layers.
        output_dim: Number of values produced per sample (784 = 28 * 28).
    """

    def __init__(self, noise_dim=100, hidden_dim=1200, output_dim=784):
        super().__init__()
        # Attribute names are kept stable so existing state_dict checkpoints
        # still load.
        self.hidden1 = nn.Linear(noise_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Generate samples from noise.

        Args:
            z: Tensor whose last dimension has size ``noise_dim``.

        Returns:
            Tensor with the same leading dimensions as ``z`` and a last
            dimension of size ``output_dim``; values are in [-1, 1].
        """
        x = self.relu(self.hidden1(z))
        x = self.relu(self.hidden2(x))
        x = self.tanh(self.output(x))
        return x

## Discriminator
For the discriminator I also reimplement the architecture from the paper, with one slight change: LeakyReLU instead of maxout, to avoid the vanishing gradient problem.

In [4]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores a batch of images.

    Layout: ``input_dim -> hidden_dim -> hidden_dim -> 1`` with LeakyReLU
    after both hidden layers and a Sigmoid on the output, so each sample gets
    a single score in [0, 1].

    Args:
        input_dim: Number of values per sample after flattening (784 = 28 * 28).
        hidden_dim: Width of each of the two hidden layers.
        negative_slope: Slope LeakyReLU applies to negative inputs.
    """

    def __init__(self, input_dim=784, hidden_dim=240, negative_slope=0.2):
        super().__init__()
        # Attribute names are kept stable so existing state_dict checkpoints
        # still load.
        self.hidden1 = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, 1)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of samples.

        Args:
            x: Tensor of shape ``(batch, ...)`` whose trailing dimensions
                multiply out to ``input_dim``, e.g. ``(batch, 1, 28, 28)`` or
                ``(batch, 784)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        # flatten() copies only when it has to, so unlike view() it also
        # accepts non-contiguous inputs and empty batches.
        x = x.flatten(start_dim=1)
        x = self.leaky_relu(self.hidden1(x))
        x = self.leaky_relu(self.hidden2(x))
        x = self.sigmoid(self.output(x))
        return x