In [None]:
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt

In [None]:
# Hyperparameters

epochs = 100            # full passes over the training set
batch_size = 32         # samples per mini-batch
learning_rate = 0.001   # step size given to both Adam optimizers
latent_size = 100       # length of the noise vector fed to the generator
image_size = 784        # length of a flattened image vector
seed = 0                # seed for the NumPy and PyTorch random generators

# Seed both random sources used by this notebook (NumPy draws the latent
# vectors, PyTorch initialises weights and shuffles the loader) so that a
# Restart & Run All repeats the same run.
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
class Generator(nn.Module):
    """Fully connected generator network.

    Passes its input through three ReLU-activated hidden layers and a final
    linear layer whose result is squashed by a sigmoid.

    Args:
        input_size: Number of features in each input vector.
        output_size: Number of features in each output vector.
        hidden_size: Width of each of the three hidden layers. Defaults to
            512, the value that was previously hard-coded.
    """

    def __init__(self, input_size, output_size, hidden_size=512):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # Layers (names and creation order are unchanged, so state_dict keys
        # and seeded weight initialisation stay the same)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, output_size)

        # Activations
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a batch of input vectors to sigmoid-activated output vectors.

        Args:
            x: Tensor whose last dimension has ``input_size`` features.

        Returns:
            Tensor whose last dimension has ``output_size`` features.
        """
        out = x
        for hidden in (self.fc1, self.fc2, self.fc3):
            out = self.relu(hidden(out))

        return self.sigmoid(self.output(out))

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator network.

    Passes its input through three ReLU-activated hidden layers and a final
    single-unit linear layer whose result is squashed by a sigmoid, giving
    one value per input vector.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of each of the three hidden layers. Defaults to
            512, the value that was previously hard-coded.
    """

    def __init__(self, input_size, hidden_size=512):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Layers (names and creation order are unchanged, so state_dict keys
        # and seeded weight initialisation stay the same)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)

        # Activations
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of input vectors.

        Args:
            x: Tensor whose last dimension has ``input_size`` features.

        Returns:
            Tensor with a last dimension of size 1 holding the
            sigmoid-activated score of each input vector.
        """
        out = x
        for hidden in (self.fc1, self.fc2, self.fc3):
            out = self.relu(hidden(out))

        return self.sigmoid(self.output(out))

In [None]:
def get_latent(shape):
    """Draw a standard-normal noise tensor of the requested shape.

    The samples come from NumPy's global random generator, are converted to
    a float32 torch tensor, and are moved to the GPU when CUDA is available.

    Args:
        shape: Shape of the tensor to sample, as accepted by
            ``np.random.normal(size=...)``.

    Returns:
        A float32 tensor of the given shape.
    """
    samples = np.random.normal(size=shape)
    latent = torch.from_numpy(samples).float()

    return latent.cuda() if torch.cuda.is_available() else latent

In [None]:
# MNIST training split, downloaded into ./data on first use; each image is
# converted to a tensor by ToTensor
dataset = torchvision.datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batch iterator over the dataset
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Target labels for one full batch: a column of ones and a column of zeros
one = torch.ones(batch_size, 1, dtype=torch.float32)
zero = torch.zeros(batch_size, 1, dtype=torch.float32)

In [None]:
# Instantiate the two networks: the generator maps latent vectors to
# flattened images, the discriminator scores flattened images
generator = Generator(input_size=latent_size, output_size=image_size)
discriminator = Discriminator(input_size=image_size)

In [None]:
# Binary cross-entropy criterion used for both networks
loss = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate
optimizer_g, optimizer_d = (
    torch.optim.Adam(network.parameters(), lr=learning_rate)
    for network in (generator, discriminator)
)

In [None]:
# When a GPU is present, move the networks, the criterion and the label
# tensors onto it
if torch.cuda.is_available():
    generator, discriminator = generator.cuda(), discriminator.cuda()
    loss = loss.cuda()
    one, zero = one.cuda(), zero.cuda()

In [None]:
# Adversarial training: every mini-batch performs one discriminator update
# followed by one generator update.
for epoch in tqdm(range(epochs), desc="Epoch"):
    # len(loader) is the exact number of batches, including a shorter final one
    for i, (image_real, label) in tqdm(enumerate(loader), desc="Example", total=len(loader)):

        # Flatten every image into a vector of image_size values
        image_real = image_real.view(-1, image_size)

        if torch.cuda.is_available():
            image_real = image_real.cuda()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        for p in discriminator.parameters():
            p.requires_grad = True

        optimizer_d.zero_grad()

        # Discriminator sees true image. The target is built from the
        # prediction itself so its shape and device always match, even when
        # the final batch holds fewer than batch_size samples.
        prob_real = discriminator(image_real)
        loss_real = loss(prob_real, torch.ones_like(prob_real))

        # Generator makes an image
        z = get_latent([batch_size, latent_size])

        # Detach: this step only updates the discriminator, so there is no
        # need to backpropagate through the generator.
        image_fake = generator(z).detach()

        # Discriminator sees generated image
        prob_fake = discriminator(image_fake)
        loss_fake = loss(prob_fake, torch.zeros_like(prob_fake))

        loss_d = 0.5*(loss_real + loss_fake)
        loss_d.backward()

        optimizer_d.step()

        # -----------------
        #  Train Generator
        # -----------------

        # Freeze the discriminator so the backward pass below does not
        # accumulate gradients in its parameters
        for p in discriminator.parameters():
            p.requires_grad = False

        optimizer_g.zero_grad()

        # Generator makes another image
        z = get_latent([batch_size, latent_size])

        image_fake = generator(z)

        # Discriminator determines the probability it is a true image; the
        # generator is rewarded when that probability is close to one
        prob_true = discriminator(image_fake)
        loss_g = loss(prob_true, torch.ones_like(prob_true))
        loss_g.backward()

        optimizer_g.step()