<a href="https://colab.research.google.com/github/ganeshbmc/GenAI_Math/blob/main/GAN_Vanilla.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Steps  
1. Import libraries  
2. Prepare data  
   ```Download  |  Transform  |  Dataloader```  
3. Define parameters  
   ```Model  |  Optimizer  |  Loss  |  Training  ```
4. Build Model  
   ```Components  ```
5. Training loop  
6. Visualize results  

## Import libraries  

In [5]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import os
import datetime

print(f"Imports completed at {datetime.datetime.now()}")

# Device: use the GPU whenever PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Versions (recorded so results can be tied to the library build that produced them)
print(f"Torch: {torch.__version__}, TorchVision: {torchvision.__version__}")

Imports completed at 2025-08-20 06:34:46.455147
Using device: cuda
Torch: 2.8.0+cu126, TorchVision: 0.23.0+cu126


## Define parameters  

In [60]:
# Reproducibility: torch.manual_seed seeds PyTorch's RNGs; numpy is seeded as well
# so any numpy-based sampling added later is repeatable too.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Data prep params
batch_size = 128

# Model params
noise_dim = 100      # z_dim: length of the generator's input noise vector
img_dim = 28 * 28    # MNIST size: one flattened 28x28 image

# Optimizer params
learning_rate = 0.0002

# Training params
num_epochs = 2
generator_rounds = 1       # intended: generator update steps per batch
discriminator_rounds = 1   # intended: discriminator update steps per batch

## Prepare data  

In [46]:
# Transform: ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
print(f"Transform to be applied:\n{transform}\n")

# Load the MNIST training split (stored under ./data; download=True fetches it if missing)
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
print(f"Train data:\n{train_dataset}\n")

# Dataloader: mini-batches of `batch_size`, reshuffled every epoch
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
print(f"Data loader:\n{train_loader}")

Transform to be applied:
Compose(
    ToTensor()
    Normalize(mean=(0.5,), std=(0.5,))
)

Train data:
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

Data loader:
<torch.utils.data.dataloader.DataLoader object at 0x7b1a7e5c8500>


## Build Model  
```Vanilla GAN```  
- MLP as neural network  

Generator
- noise_dim -> hidden layers (ReLU) -> img_dim with a Tanh output, matching images normalized to [-1, 1]  

Discriminator
- img_dim -> hidden layers (LeakyReLU, slope 0.2) -> 1 with a Sigmoid output (a probability between 0 and 1)  

In [47]:
# Generator
class Generator(nn.Module):
  """MLP generator mapping a noise vector to a flat image.

  Hidden widths are 256 -> 512 -> 1024, each followed by ReLU; the final Linear projects
  to ``img_dim`` and a Tanh squashes samples into [-1, 1], the range the training images
  were normalized to.

  Args:
    noise_dim: length of each input noise vector.
    img_dim: number of output features (pixels) per generated image.
  """

  def __init__(self, noise_dim, img_dim):
    super().__init__()
    layers = []
    in_features = noise_dim
    # MLP body: Linear + in-place ReLU for each hidden width
    for width in (256, 512, 1024):
      layers += [nn.Linear(in_features, width), nn.ReLU(inplace=True)]
      in_features = width
    # Output head: project to image size and squash to [-1, 1]
    layers += [nn.Linear(in_features, img_dim), nn.Tanh()]
    self.model = nn.Sequential(*layers)

  def forward(self, z):
    """Generate images: for ``z`` of shape (batch, noise_dim) returns (batch, img_dim)."""
    return self.model(z)

In [48]:
# Build the generator and move it to `device`, so its weights live on the same device as
# the noise tensors created for it (e.g. `torch.randn(...).to(device)` for visualization).
generator = Generator(noise_dim=noise_dim, img_dim=img_dim).to(device)
generator

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=1024, out_features=784, bias=True)
    (7): Tanh()
  )
)

In [55]:
# Discriminator
class Discriminator(nn.Module):
  """MLP discriminator scoring each input image with a single probability.

  Hidden widths are 512 -> 256 with LeakyReLU(0.2); a final Sigmoid maps the single output
  unit to a value between 0 and 1, as expected by ``nn.BCELoss``.

  Args:
    img_dim: number of input features (pixels) per image, e.g. 28 * 28 for MNIST.
  """

  def __init__(self, img_dim):
    super(Discriminator, self).__init__()
    self.model = nn.Sequential(
        nn.Linear(img_dim, 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, 1),
        nn.Sigmoid()    # Outputs probability between 0 and 1
    )

  def forward(self, img):
    """Return the discriminator's probability for each image.

    Accepts already-flat batches (batch, img_dim) as well as image batches with trailing
    spatial dims, e.g. (batch, 1, 28, 28) straight from the DataLoader. Those are flattened
    here, because the first Linear layer expects img_dim features. Batched input yields
    a (batch, 1) tensor.
    """
    if img.dim() > 2:
      # Keep the batch dim, merge channel/height/width into one feature axis
      img = img.flatten(start_dim=1)
    return self.model(img)

In [56]:
# Build the discriminator on `device` so it sees real and generated batches on the same
# device as the generator.
discriminator = Discriminator(img_dim).to(device)
discriminator

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

## Set up Optimizers  

In [57]:
# Adam over the generator's parameters only (default betas / eps)
g_optim = optim.Adam(params=generator.parameters(), lr=learning_rate)
g_optim

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0002
    maximize: False
    weight_decay: 0
)

In [58]:
# Adam over the discriminator's parameters only, same learning rate as the generator
d_optim = optim.Adam(params=discriminator.parameters(), lr=learning_rate)
d_optim

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0002
    maximize: False
    weight_decay: 0
)

## Set up Loss functions  

In [59]:
# Binary cross-entropy on the discriminator's Sigmoid probabilities, averaged over the batch
criterion = nn.BCELoss(reduction='mean')
criterion

BCELoss()

## Code for visualizations  

In [None]:
def show_generated_images(epoch, generator, fixed_noise, nrow=8, img_shape=(1, 28, 28)):
    """Plot a grid of images generated from ``fixed_noise``.

    Reusing the same noise across calls makes samples from different epochs directly
    comparable.

    Args:
        epoch: epoch number; used only in the plot title.
        generator: model mapping noise rows to flat images; outputs are assumed to lie in
            [-1, 1] (the Tanh output of ``Generator``) and are mapped back to [0, 1].
        fixed_noise: noise batch; must be on the same device as the generator's weights.
        nrow: number of images per grid row (64 samples with nrow=8 -> an 8x8 grid).
        img_shape: (channels, height, width) used to un-flatten each sample.

    The generator's previous train/eval mode is restored afterwards, even if the forward
    pass raises.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_imgs = generator(fixed_noise).reshape(-1, *img_shape)
            # De-normalize [-1, 1] -> [0, 1]; clamp guards against generators without Tanh
            fake_imgs = (fake_imgs * 0.5 + 0.5).clamp(0, 1)
    finally:
        generator.train(was_training)

    # make_grid tiles the batch into one (C, H, W) image (1-channel input becomes 3-channel)
    grid = torchvision.utils.make_grid(fake_imgs.cpu(), nrow=nrow)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax.set_title(f'Generated Images at Epoch {epoch}')
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called every few epochs

## Training loop  
- Visualize results every 10 epochs  

In [None]:
def gan_trainer(train_loader, generator=None, discriminator=None, g_optim=None,
                d_optim=None, criterion=None, num_epochs=None, noise_dim=None,
                d_rounds=None, g_rounds=None, vis_every=10):
  """Train the vanilla GAN with alternating discriminator / generator updates.

  For every batch: ``d_rounds`` discriminator steps (real images labelled 1, detached fakes
  labelled 0), then ``g_rounds`` generator steps using the non-saturating loss (fakes
  labelled 1).

  Every argument except ``train_loader`` is optional. When left as ``None`` it falls back
  to the notebook-level variable of the same name: ``discriminator_rounds`` /
  ``generator_rounds`` for the round counts (1 if those are undefined). That way
  ``gan_trainer(train_loader)`` trains the globally defined models.

  Args:
    train_loader: iterable of ``(images, labels)`` batches; images may be (batch, C, H, W)
      or already flat (batch, img_dim). Labels are ignored.
    generator, discriminator: the two networks; the discriminator is assumed to be on the
      same device as the generator.
    g_optim, d_optim: optimizers over the generator's / discriminator's parameters.
    criterion: loss applied to discriminator probabilities (e.g. ``nn.BCELoss()``).
    num_epochs: number of passes over ``train_loader``.
    noise_dim: length of each noise vector fed to the generator.
    d_rounds, g_rounds: update steps per batch for D and G (both must be >= 1).
    vis_every: print mean losses and show samples every this many epochs and after the
      final epoch; ``None`` or 0 disables printing and plotting.

  Returns:
    dict with lists ``"d_loss"`` and ``"g_loss"``: per-epoch mean losses.

  Raises:
    ValueError: if a round count is < 1 or ``train_loader`` yields no batches.
  """
  ns = globals()
  generator = ns["generator"] if generator is None else generator
  discriminator = ns["discriminator"] if discriminator is None else discriminator
  g_optim = ns["g_optim"] if g_optim is None else g_optim
  d_optim = ns["d_optim"] if d_optim is None else d_optim
  criterion = ns["criterion"] if criterion is None else criterion
  num_epochs = ns["num_epochs"] if num_epochs is None else num_epochs
  noise_dim = ns["noise_dim"] if noise_dim is None else noise_dim
  d_rounds = ns.get("discriminator_rounds", 1) if d_rounds is None else d_rounds
  g_rounds = ns.get("generator_rounds", 1) if g_rounds is None else g_rounds
  if d_rounds < 1 or g_rounds < 1:
    raise ValueError("d_rounds and g_rounds must both be >= 1")

  # Create noise/data on whatever device the generator actually lives on, so the loop works
  # whether or not the models were moved to `device`.
  device = next(generator.parameters()).device
  fixed_noise = torch.randn(64, noise_dim, device=device)  # For consistent visualization

  history = {"d_loss": [], "g_loss": []}
  for epoch in range(num_epochs):
    generator.train()
    discriminator.train()
    d_total, g_total, n_batches = 0.0, 0.0, 0

    for real_imgs, _ in train_loader:
      real = real_imgs.reshape(real_imgs.size(0), -1).to(device)
      bsz = real.size(0)
      real_labels = torch.ones(bsz, 1, device=device)
      fake_labels = torch.zeros(bsz, 1, device=device)

      # Discriminator: push D(x) -> 1 and D(G(z)) -> 0. detach() keeps G out of this graph.
      for _ in range(d_rounds):
        fake = generator(torch.randn(bsz, noise_dim, device=device)).detach()
        d_loss = (criterion(discriminator(real), real_labels)
                  + criterion(discriminator(fake), fake_labels))
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

      # Generator: non-saturating loss, i.e. push D(G(z)) -> 1.
      for _ in range(g_rounds):
        fake = generator(torch.randn(bsz, noise_dim, device=device))
        g_loss = criterion(discriminator(fake), real_labels)
        g_optim.zero_grad()
        # Also writes D's .grad; d_optim.zero_grad() clears it before D's next step.
        g_loss.backward()
        g_optim.step()

      d_total += d_loss.item()
      g_total += g_loss.item()
      n_batches += 1

    if n_batches == 0:
      raise ValueError("train_loader yielded no batches")
    d_mean = d_total / n_batches
    g_mean = g_total / n_batches
    history["d_loss"].append(d_mean)
    history["g_loss"].append(g_mean)

    # Visualize every `vis_every` epochs and always after the last epoch,
    # so short runs (e.g. num_epochs=2) still show results.
    if vis_every and ((epoch + 1) % vis_every == 0 or epoch + 1 == num_epochs):
      print(f"Epoch: [{epoch+1}/{num_epochs}]")
      print(f"Discriminator loss (epoch mean) = {d_mean:.4f}")
      print(f"Generator loss (epoch mean) = {g_mean:.4f}")
      show_generated_images(epoch=epoch+1, generator=generator, fixed_noise=fixed_noise)

  return history