## Import Dependencies

In [1]:
import torch
from simple_gan import Generator,Discriminator
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Simple GAN on MNIST Data

## Set up hyperparameters

In [2]:
# Hyperparameters. Early (vanilla) GANs are sensitive to these values, so
# everything a reader might want to tune is collected in this one cell.

# Seed the global RNG before any model, noise tensor or data loader is
# created, so that a Restart & Run All can reproduce the same run.
seed = 42
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4  # Adam learning rate; a popular default attributed to Andrej Karpathy
z_dimension = 64  # length of the latent noise vector; 128 and 256 are worth trying
image_dim = 28 * 28 * 1  # flattened image size: height * width * channels
batch_size = 32
num_epochs = 500

## Initialize GAN and set up Dataset

In [3]:
# --- Networks -------------------------------------------------------------
# Both models are moved to the selected device right after construction.
discriminator = Discriminator(image_dimension=image_dim).to(device)
generator = Generator(
    z_dimension=z_dimension,
    image_dimension=image_dim,
).to(device)

# A single batch of latent vectors that stays the same for the whole run, so
# the image grids logged at different epochs come from identical inputs.
fixed_noise = torch.randn((batch_size, z_dimension)).to(device)

# --- Data -----------------------------------------------------------------
# ToTensor followed by Normalize(0.5, 0.5) rescales pixel values to [-1, 1].
# (The true MNIST statistics would be mean 0.1307 and std 0.3081.)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
dataset = datasets.MNIST(root='./data/', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Optimisation ---------------------------------------------------------
# One Adam optimizer per network, sharing the same learning rate, and binary
# cross-entropy as the loss for both training steps.
optim_disc = optim.Adam(discriminator.parameters(), lr=lr)
optim_gen = optim.Adam(generator.parameters(), lr=lr)
criterion = nn.BCELoss()


In [4]:
# TensorBoard logging: generated and real image grids are written to two
# separate run directories under ./runs/GAN_MNIST.
step = 0  # passed as global_step for each logged image grid
writer_fake = SummaryWriter(log_dir="./runs/GAN_MNIST/fake")  # fake images
writer_real = SummaryWriter(log_dir="./runs/GAN_MNIST/real")  # real images

## Training

In [5]:
# Alternating GAN training: for every batch, update the discriminator once and
# then the generator once. On the first batch of each epoch, print the losses
# and log a grid of generated and real images to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image into a vector of length image_dim.
        real = real.view(-1, image_dim).to(device)
        # Size of this particular batch. Kept in its own local name so the
        # `batch_size` hyperparameter is not overwritten by the loop.
        cur_batch_size = real.shape[0]

        # ---- Discriminator: maximise log(D(real)) + log(1 - D(G(z))) ----
        noise = torch.randn(cur_batch_size, z_dimension).to(device)
        fake = generator(noise)

        disc_real = discriminator(real).view(-1)
        # BCE: l_n = -[y_n * log(x_n) + (1 - y_n) * log(1 - x_n)].
        # With targets y_n = 1 the second term vanishes, leaving -log(D(real)).
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach() cuts the graph at the generator output: the discriminator
        # update needs no generator gradients, and leaving the generator graph
        # untouched here means no retain_graph=True is required for it to be
        # reused in the generator step below.
        disc_fake = discriminator(fake.detach()).view(-1)
        # With targets y_n = 0 the first BCE term vanishes, leaving
        # -log(1 - D(G(z))).
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_fake + lossD_real) / 2
        discriminator.zero_grad()
        lossD.backward()
        optim_disc.step()

        # ---- Generator: min log(1 - D(G(z)))  ->  max log(D(G(z))) ----
        # Score the same fakes again with the just-updated discriminator and
        # label them as real, so the gradient pushes G towards fooling D.
        output = discriminator(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        generator.zero_grad()
        lossG.backward()
        optim_gen.step()

        if batch_idx == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the source indentation in the printed line.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                # Always generate from the same fixed noise so grids are
                # comparable across epochs.
                fake_imgs = generator(fixed_noise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                # Push the events to disk now so the images survive an
                # interrupted run.
                writer_fake.flush()
                writer_real.flush()
                step += 1


Epoch [0/500] Batch 0/1875                       Loss D: 0.6765, loss G: 0.6590
Epoch [1/500] Batch 0/1875                       Loss D: 0.3945, loss G: 1.2289
Epoch [2/500] Batch 0/1875                       Loss D: 0.2485, loss G: 1.9642
Epoch [3/500] Batch 0/1875                       Loss D: 0.2269, loss G: 1.9172
Epoch [4/500] Batch 0/1875                       Loss D: 0.5206, loss G: 1.1317
Epoch [5/500] Batch 0/1875                       Loss D: 0.2643, loss G: 1.7039
Epoch [6/500] Batch 0/1875                       Loss D: 0.4710, loss G: 1.1259
Epoch [7/500] Batch 0/1875                       Loss D: 0.5203, loss G: 1.1240
Epoch [8/500] Batch 0/1875                       Loss D: 0.4776, loss G: 1.3763
Epoch [9/500] Batch 0/1875                       Loss D: 0.3830, loss G: 1.3702
Epoch [10/500] Batch 0/1875                       Loss D: 0.5507, loss G: 1.6203
Epoch [11/500] Batch 0/1875                       Loss D: 0.6904, loss G: 1.0627
Epoch [12/500] Batch 0/1875           

KeyboardInterrupt: 

## Changes that may improve results

Because simple GANs are sensitive to hyperparameters, a few changes may significantly affect model performance:

1) Using a larger network
2) Adding better normalization with BatchNorm
3) Trying a different learning rate
4) Changing the architecture to a DCGAN