# GAN on MNIST

In [1]:
'''
Loading necessary libraries
'''
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [19]:
'''
Setup parameters
'''
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Optimisation settings.
n_epochs = 50        # full passes over the training set
n_classes = 10       # number of digit classes in MNIST
batch_size = 100     # samples per mini-batch
lr = 0.0002          # learning rate handed to both optimizers

# Network dimensions.
latent_size = 64     # length of the noise vector fed to the generator
hidden_size = 256    # width of the hidden fully-connected layers
image_size = 28 * 28  # a flattened 28x28 image (784 values)

In [20]:
'''
Loading MNIST dataset
'''
# Convert each PIL image to a tensor, then shift/scale it with mean 0.5 and
# std 0.5. A colour image would need one mean/std entry per channel,
# i.e. [0.5, 0.5, 0.5].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Datasets: the training split first, then the test split. Both are
# downloaded into ./data/mnist when not already present.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='./data/mnist',
        train=is_train,
        download=True,
        transform=transform,
    )
    for is_train in (True, False)
)

# Loaders: only the training loader shuffles its samples.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=do_shuffle,
        num_workers=12,
    )
    for dataset, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [21]:
def denorm(x):
    """Rescale a tensor from the [-1, 1] range to [0, 1].

    Computes ``(x + 1) / 2`` element-wise and clips the result so that
    anything outside [0, 1] is pinned to the nearest bound.
    """
    return x.add(1).div(2).clamp(min=0, max=1)

In [22]:
'''
Define model
'''
# Discriminator: takes a flattened image of length image_size and ends in a
# sigmoid, so it emits one value in (0, 1) per sample.
Discriminator = nn.Sequential(*[
    nn.Linear(in_features=image_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.Sigmoid(),
])

# Generator: takes a noise vector of length latent_size and ends in a tanh,
# so it emits image_size values in (-1, 1) per sample.
Generator = nn.Sequential(*[
    nn.Linear(in_features=latent_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=image_size),
    nn.Tanh(),
])

# Move both networks to the selected device.
Discriminator, Generator = Discriminator.to(device), Generator.to(device)

In [23]:
'''
Optimizer and Loss function
'''
# Binary cross entropy between the discriminator output and the 0/1 target.
loss_fn = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate.
optimizer_d = torch.optim.Adam(params=Discriminator.parameters(), lr=lr)
optimizer_g = torch.optim.Adam(params=Generator.parameters(), lr=lr)

In [24]:
'''
Train the model
'''
# Every iteration performs one discriminator update followed by one
# generator update. After each epoch the generator output of the last
# batch is written to disk as an image grid.

# save_image does not create missing directories, so make sure the output
# folder exists before the first epoch finishes.
sample_dir = './fake_images'
os.makedirs(sample_dir, exist_ok=True)

total_steps = len(train_loader)
for epoch in range(n_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Use the size of this batch rather than the configured batch_size:
        # the final batch is smaller whenever the dataset size is not a
        # multiple of batch_size, and a hard-coded reshape would then fail.
        n_samples = images.size(0)
        images = images.reshape(n_samples, -1).to(device)

        # Targets for the BCE loss: 1 for real images, 0 for fake ones.
        # Created directly on the target device to avoid a host copy.
        real_labels = torch.ones(n_samples, 1, device=device)
        fake_labels = torch.zeros(n_samples, 1, device=device)

        # ---- Train Discriminator ----
        # Compute BCE loss using real images
        output = Discriminator(images)
        loss_d_real = loss_fn(output, real_labels)
        real_score = output

        # Compute BCE loss using fake images. The generator output is
        # detached because this step only updates the discriminator, so
        # backpropagating into the generator would be wasted work.
        z = torch.randn(n_samples, latent_size, device=device)
        fake_images = Generator(z).detach()
        output = Discriminator(fake_images)
        loss_d_fake = loss_fn(output, fake_labels)
        fake_score = output

        # backprop (zero_grad also clears the discriminator gradients left
        # behind by the previous generator step)
        loss_d = loss_d_real + loss_d_fake
        optimizer_d.zero_grad()
        loss_d.backward()
        optimizer_d.step()

        # ---- Train Generator ----
        # Compute loss for fake images
        z = torch.randn(n_samples, latent_size, device=device)
        fake_images = Generator(z)
        output = Discriminator(fake_images)

        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        loss_g = loss_fn(output, real_labels)

        # backprop
        optimizer_g.zero_grad()
        loss_g.backward()
        optimizer_g.step()

        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}]-[{}/{}], D-loss: {:.4f}, G-loss: {:.4f}, D(x):{:.2f}, D(G(z)):{:.4f}'
                  .format(epoch + 1, n_epochs, i + 1, total_steps, loss_d.item(), loss_g.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save the fake images produced for the last batch of this epoch,
    # reshaped to (N, 1, 28, 28) and mapped back to [0, 1] for writing.
    fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images),
               os.path.join(sample_dir, 'fake-image-' + str(epoch + 1) + '.png'))
        
        

Epoch [1/50]-[200/600], D-loss: 0.0402, G-loss: 4.2268, D(x):0.99, D(G(z)):0.0341
Epoch [1/50]-[400/600], D-loss: 0.0494, G-loss: 7.4043, D(x):0.98, D(G(z)):0.0161
Epoch [1/50]-[600/600], D-loss: 0.0756, G-loss: 4.3754, D(x):0.98, D(G(z)):0.0562
Epoch [2/50]-[200/600], D-loss: 0.0831, G-loss: 4.7943, D(x):0.97, D(G(z)):0.0471
Epoch [2/50]-[400/600], D-loss: 0.4232, G-loss: 3.5036, D(x):0.88, D(G(z)):0.1753
Epoch [2/50]-[600/600], D-loss: 0.7741, G-loss: 3.2409, D(x):0.85, D(G(z)):0.3315
Epoch [3/50]-[200/600], D-loss: 0.0668, G-loss: 4.7927, D(x):0.96, D(G(z)):0.0243
Epoch [3/50]-[400/600], D-loss: 0.8147, G-loss: 2.0243, D(x):0.69, D(G(z)):0.2021
Epoch [3/50]-[600/600], D-loss: 0.5705, G-loss: 2.7803, D(x):0.75, D(G(z)):0.1497
Epoch [4/50]-[200/600], D-loss: 0.2388, G-loss: 3.3218, D(x):0.91, D(G(z)):0.0768
Epoch [4/50]-[400/600], D-loss: 0.6129, G-loss: 2.0436, D(x):0.85, D(G(z)):0.2534
Epoch [4/50]-[600/600], D-loss: 0.8376, G-loss: 2.4268, D(x):0.76, D(G(z)):0.2542
Epoch [5/50]-[20