In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Discriminator(nn.Module):
  """Fully connected discriminator that scores flattened images.

  Maps an input of size ``img_dim`` through one hidden layer to a single
  sigmoid output, which the training code compares against 1 for real
  images and 0 for generated ones.

  Args:
    img_dim: Number of features in each flattened input image.
    hidden_dim: Width of the hidden layer. Defaults to 128.
  """

  def __init__(self, img_dim, hidden_dim=128):
    super().__init__()
    # Attribute name and layer order are kept so state_dict keys
    # ("disc.0.weight", "disc.2.weight", ...) stay stable.
    self.disc = nn.Sequential(
        nn.Linear(img_dim, hidden_dim),
        nn.LeakyReLU(0.1),
        nn.Linear(hidden_dim, 1),
        nn.Sigmoid(),  # squash the score into [0, 1] as required by BCELoss
    )

  def forward(self, x):
    """Return a tensor of shape (N, 1) with one score per input row."""
    return self.disc(x)

In [None]:
class Generator(nn.Module):
  """Fully connected generator that maps latent noise to flattened images.

  Args:
    z_dim: Size of the latent noise vector fed to the network.
    img_dim: Number of features in each generated (flattened) image.
    hidden_dim: Width of the hidden layer. Defaults to 256.
  """

  def __init__(self, z_dim, img_dim, hidden_dim=256):
    super().__init__()
    # Attribute name and layer order are kept so state_dict keys
    # ("gen.0.weight", "gen.2.weight", ...) stay stable.
    self.gen = nn.Sequential(
        nn.Linear(z_dim, hidden_dim),
        nn.LeakyReLU(0.1),
        nn.Linear(hidden_dim, img_dim),
        nn.Tanh(),  # outputs lie in [-1, 1]; real images should use the same range
    )

  def forward(self, x):
    """Return a tensor of shape (N, img_dim) generated from noise ``x``."""
    return self.gen(x)

In [None]:
# Hyperparameters and run configuration.
# Seed the global RNGs so weight initialisation, noise sampling and data
# shuffling repeat across runs (some GPU kernels may still be nondeterministic).
seed = 42
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4  # learning rate used for both Adam optimizers
z_dim = 64  # size of the latent noise vector fed to the generator
img_dim = 28 * 28 * 1  # flattened image size: height * width * channels
batch_size = 32
num_epochs = 50

In [None]:
# Instantiate both networks on the selected device.
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

# Fixed noise batch, reused at every logging step so the generated samples
# can be compared across epochs.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) rescales them to
# [-1, 1] so real images share the range of the generator's Tanh output.
# (Dataset statistics such as mean 0.1307 / std 0.3081 would put real pixels
# well outside [-1, 1], a range the generator can never produce.)
# The pipeline is built through `torchvision.transforms` rather than the
# `transforms` alias so this cell can be re-run after the alias has been
# rebound; the variable name is kept because the dataset cell reads it.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)


In [None]:
# MNIST training split, downloaded into dataset/ on first use; every image
# goes through the `transforms` pipeline defined above.
dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transforms, download=True
)

# Mini-batch loader; reshuffles the data every epoch.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, sharing the same learning rate.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Separate TensorBoard writers so fake and real image grids land in
# different run directories.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")

# Global step used as the x-axis for the logged image grids.
step = 0

In [None]:
# Training loop: for every mini-batch do one discriminator update followed by
# one generator update; on the first batch of each epoch print the losses and
# log image grids to TensorBoard.

for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    # Flatten each image into a vector for the fully connected networks.
    real = real.view(-1, img_dim).to(device)
    # Local name on purpose: the global `batch_size` hyperparameter must not
    # be overwritten by the size of whatever batch was processed last.
    cur_batch_size = real.shape[0]

    ### Train Discriminator: max log(D(real)) + log(1-D(G(z)))
    noise = torch.randn(cur_batch_size, z_dim, device=device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() keeps the discriminator's backward pass out of the generator.
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    # retain_graph is not needed: the fake batch was detached above and the
    # generator step below runs a fresh forward pass through the discriminator.
    lossD.backward()
    opt_disc.step()

    ### Train Generator min log(1 - D(G(z))) <--> max log(D(G(z)))
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if batch_idx == 0:
      # "\\ " prints a literal backslash followed by a space.
      print(
          f"Epoch [{epoch}/{num_epochs}] \\ "
          f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
      )

      with torch.no_grad():
        # Samples from the fixed noise show how the generator evolves.
        fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
        real_images = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
        img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

        writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
            "Mnist real Images", img_grid_real, global_step=step
        )

      step += 1

# Make sure every logged image reaches disk once training is done.
writer_fake.flush()
writer_real.flush()

Epoch [0/50] \ Loss D: 0.6573, Loss G: 0.7609
Epoch [1/50] \ Loss D: 0.1260, Loss G: 2.1976
Epoch [2/50] \ Loss D: 0.1183, Loss G: 2.7656
Epoch [3/50] \ Loss D: 0.0713, Loss G: 3.7644
Epoch [4/50] \ Loss D: 0.0389, Loss G: 3.6778
Epoch [5/50] \ Loss D: 0.0424, Loss G: 4.7287
Epoch [6/50] \ Loss D: 0.0810, Loss G: 3.3799
Epoch [7/50] \ Loss D: 0.0121, Loss G: 4.7322
Epoch [8/50] \ Loss D: 0.0166, Loss G: 4.4796
Epoch [9/50] \ Loss D: 0.0142, Loss G: 6.2963
Epoch [10/50] \ Loss D: 0.0290, Loss G: 5.6075
Epoch [11/50] \ Loss D: 0.0300, Loss G: 5.4060
Epoch [12/50] \ Loss D: 0.0628, Loss G: 6.2847
Epoch [13/50] \ Loss D: 0.0056, Loss G: 6.6035
Epoch [14/50] \ Loss D: 0.0102, Loss G: 6.6843


In [None]:
# Proposed improvements
# 1. Try a larger network (more and/or wider hidden layers).
# 2. Add BatchNorm layers for better normalization.
# 3. Tune the learning rate.
# 4. Change the architecture to a CNN.