<a href="https://colab.research.google.com/github/keshsri/generative-adversarial-networks/blob/main/Generative_Adversarial_Networks_MNIST_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

<torch._C.Generator at 0x7cf2503b9d50>

In [None]:
def get_generator_block(input_dim, output_dim):
  """Build one generator stage: linear map, batch normalisation, then ReLU.

  Args:
    input_dim: number of features entering the stage.
    output_dim: number of features leaving the stage.

  Returns:
    An nn.Sequential holding Linear -> BatchNorm1d -> ReLU, in that order.
  """
  # Layers are created in the same order they are applied.
  affine = nn.Linear(input_dim, output_dim)
  normalise = nn.BatchNorm1d(output_dim)
  activation = nn.ReLU()
  return nn.Sequential(affine, normalise, activation)

In [None]:
class Generator(nn.Module):
  """Fully connected generator that maps noise vectors to flattened images.

  The network is four generator blocks (built by get_generator_block) whose
  widths grow from hidden_dim to hidden_dim * 8, followed by a linear
  projection to im_dim and a sigmoid, so every output value lies in (0, 1).

  Args:
    z_dim: length of each input noise vector.
    im_dim: length of each generated image vector (784 by default,
      presumably a flattened 28x28 image -- confirm against the dataset).
    hidden_dim: base width of the hidden stages.
  """

  def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
    # The default was previously z_dim=0, which builds nn.Linear(0, hidden_dim):
    # a first layer with no input features, so Generator() ignored its noise.
    # Callers that pass z_dim explicitly are unaffected by the new default.
    super().__init__()
    self.gen = nn.Sequential(
        get_generator_block(z_dim, hidden_dim),
        get_generator_block(hidden_dim, hidden_dim * 2),
        get_generator_block(hidden_dim * 2, hidden_dim * 4),
        get_generator_block(hidden_dim * 4, hidden_dim * 8),
        nn.Linear(hidden_dim * 8, im_dim),
        nn.Sigmoid()
    )

  def forward(self, noise):
    """Return generated images for a batch of noise vectors.

    Args:
      noise: tensor whose last dimension is z_dim; assumed to be shaped
        (n_samples, z_dim) as produced by get_noise.

    Returns:
      Tensor of generated image vectors with values in (0, 1).
    """
    return self.gen(noise)

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
  """Draw a batch of standard-normal noise vectors.

  Args:
    n_samples: number of noise vectors to draw.
    z_dim: length of each noise vector.
    device: device on which the tensor is created.

  Returns:
    Tensor of shape (n_samples, z_dim) sampled with torch.randn.
  """
  shape = (n_samples, z_dim)
  return torch.randn(shape, device=device)

In [None]:
def get_discriminator_block(input_dim, output_dim):
  """Build one discriminator stage: linear map followed by LeakyReLU.

  Args:
    input_dim: number of features entering the stage.
    output_dim: number of features leaving the stage.

  Returns:
    An nn.Sequential holding Linear -> LeakyReLU with negative slope 0.2.
  """
  affine = nn.Linear(in_features=input_dim, out_features=output_dim)
  activation = nn.LeakyReLU(negative_slope=0.2)
  return nn.Sequential(affine, activation)

In [None]:
class Discriminator(nn.Module):
  """Fully connected discriminator that scores flattened images.

  Three discriminator blocks (built by get_discriminator_block) shrink the
  width from hidden_dim * 4 down to hidden_dim; a final linear layer emits a
  single raw score per image. No sigmoid is applied here.

  Args:
    im_dim: length of each input image vector.
    hidden_dim: base width of the hidden stages.
  """

  def __init__(self, im_dim=784, hidden_dim=128):
    super().__init__()
    # Consecutive pairs of these widths define the hidden stages.
    widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
    stages = [
        get_discriminator_block(n_in, n_out)
        for n_in, n_out in zip(widths, widths[1:])
    ]
    scorer = nn.Linear(hidden_dim, 1)
    self.disc = nn.Sequential(*stages, scorer)

  def forward(self, image):
    """Return one raw score per row of `image`."""
    scores = self.disc(image)
    return scores

In [None]:
# Training configuration shared by the cells below.
criterion = nn.BCEWithLogitsLoss()  # operates on raw (pre-sigmoid) scores
n_epochs = 10         # full passes over the dataset
z_dim = 64            # length of each generator noise vector
display_step = 500    # report averaged losses every this many steps
batch_size = 128
lr = 0.00001          # learning rate used for both optimisers

# MNIST is downloaded into the current directory on first use; ToTensor
# converts each image to a float tensor.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True
)

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Build the two networks on the target device. The generator is constructed
# first so that seeded parameter initialisation happens in the same order.
gen = Generator(z_dim=z_dim).to(device)
disc = Discriminator().to(device)

# One Adam optimiser per network, both using the shared learning rate.
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)

In [None]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
  """Compute the discriminator loss for one batch.

  The discriminator is scored on freshly generated images (target 0) and on
  the real batch (target 1); the result is the mean of the two losses.

  Args:
    gen: generator module; called on a noise batch from get_noise.
    disc: discriminator module; called on both fake and real images.
    criterion: loss taking (predictions, targets).
    real: batch of real images in the form `disc` accepts.
    num_images: number of fake images to generate.
    z_dim: length of each noise vector.
    device: device on which the noise is created.

  Returns:
    Scalar tensor: (fake loss + real loss) / 2.
  """
  fake_noise = get_noise(num_images, z_dim, device=device)
  # The generator must not receive gradients from this loss. Running it under
  # no_grad avoids recording an autograd graph that would be discarded anyway.
  with torch.no_grad():
    fake = gen(fake_noise)
  # detach() is a no-op on a no_grad result; kept to make the intent explicit.
  disc_fake_pred = disc(fake.detach())
  disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
  disc_real_pred = disc(real)
  disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
  disc_loss = (disc_fake_loss + disc_real_loss) / 2
  return disc_loss

In [None]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
  """Compute the generator loss for one batch.

  Fresh fake images are scored by the discriminator and compared against a
  target of all ones, so the loss is low when the fakes are scored as real.

  Args:
    gen: generator module; called on a noise batch from get_noise.
    disc: discriminator module; called on the generated images.
    criterion: loss taking (predictions, targets).
    num_images: number of fake images to generate.
    z_dim: length of each noise vector.
    device: device on which the noise is created.

  Returns:
    Scalar tensor with the generator loss.
  """
  # No detach here: gradients must flow back through disc into gen.
  generated = gen(get_noise(num_images, z_dim, device=device))
  verdict = disc(generated)
  wanted = torch.ones_like(verdict)
  return criterion(verdict, wanted)

In [None]:
# Training loop: alternate one discriminator update and one generator update
# per batch, and report losses averaged over each window of display_step steps.
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False

for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_batch_size = len(real)
    # Flatten each image to a vector and move the batch to the target device.
    real = real.view(cur_batch_size, -1).to(device)

    # --- Discriminator update ---
    disc_opt.zero_grad()
    disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
    # The graph is not reused after this call, so retain_graph is not needed.
    disc_loss.backward()
    disc_opt.step()

    # --- Generator update ---
    gen_opt.zero_grad()
    gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
    gen_loss.backward()
    gen_opt.step()

    # Accumulate running means over the current reporting window.
    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    # Count this step before checking the window so that every report,
    # including the first, averages exactly display_step losses.
    cur_step += 1

    if cur_step % display_step == 0:
      print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
      # Sample a preview batch; no gradients are needed for it.
      with torch.no_grad():
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
      # NOTE: visualisation is disabled; show_tensor_images is not defined
      # in the visible code. Once it exists, call it on `fake` and `real` here.
      mean_generator_loss = 0
      mean_discriminator_loss = 0

  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/469 [00:00<?, ?it/s]

Step 500: Generator loss: 1.3880774766206745, discriminator loss: 0.419259883701801


  0%|          | 0/469 [00:00<?, ?it/s]

Step 1000: Generator loss: 1.6689664766788488, discriminator loss: 0.3002841672599312


  0%|          | 0/469 [00:00<?, ?it/s]

Step 1500: Generator loss: 1.9593542075157158, discriminator loss: 0.1741582162976264


  0%|          | 0/469 [00:00<?, ?it/s]

Step 2000: Generator loss: 1.7620599417686456, discriminator loss: 0.19567712315917024


  0%|          | 0/469 [00:00<?, ?it/s]

Step 2500: Generator loss: 1.7258853509426118, discriminator loss: 0.1949902396798135


  0%|          | 0/469 [00:00<?, ?it/s]

Step 3000: Generator loss: 2.0241021432876587, discriminator loss: 0.15115172801911828


  0%|          | 0/469 [00:00<?, ?it/s]

Step 3500: Generator loss: 2.403905700206759, discriminator loss: 0.1256361185610294


  0%|          | 0/469 [00:00<?, ?it/s]

Step 4000: Generator loss: 2.708461202621462, discriminator loss: 0.12470809704065329


  0%|          | 0/469 [00:00<?, ?it/s]

Step 4500: Generator loss: 3.1915947027206446, discriminator loss: 0.09926906828582288
