In [None]:
# Types of GAN
# 1. Basic GAN (Vanilla GAN) -- the original GAN introduced by Goodfellow et al. (2014)

# Goal: build a GAN that generates realistic handwritten digits (MNIST).



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Preprocessing pipeline for MNIST:
#   ToTensor:  PIL image -> float tensor with values in [0, 1]
#   Normalize: x' = (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1], the same range
#              the Generator's Tanh output produces.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)


In [None]:
# MNIST training split (downloaded into the current directory on first run),
# preprocessed with `transform` and served in shuffled mini-batches of 128.
dataloader = DataLoader(
    dataset=datasets.MNIST(root='.', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
)


Generator Network

In [None]:
class Generator(nn.Module):
    """MLP generator: maps a noise vector of size `noise_dim` to a 1x28x28 image.

    Layer widths: noise_dim -> 256 -> 512 -> 1024 -> 784, with ReLU between the
    hidden layers and Tanh on the output. Tanh keeps pixels in [-1, 1], matching the
    Normalize([0.5], [0.5]) preprocessing applied to the real images.
    """

    def __init__(self, noise_dim):
        super().__init__()
        widths = [noise_dim, 256, 512, 1024]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(in_features, out_features), nn.ReLU(inplace=True)]
        # Final projection to one flattened 28x28 image, squashed into [-1, 1].
        layers += [nn.Linear(widths[-1], 28 * 28), nn.Tanh()]
        # The attribute name `net` and this exact layer order determine the
        # state_dict keys (net.0, net.2, net.4, net.6) stored in saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return images of shape (N, 1, 28, 28) for noise z of shape (N, noise_dim)."""
        flat = self.net(z)
        return flat.view(-1, 1, 28, 28)


Discriminator Network

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator: scores a 1x28x28 image with the probability that it is real.

    Layer widths: 784 -> 512 -> 256 -> 1, with LeakyReLU(0.2) between layers and a
    final Sigmoid, so outputs lie in [0, 1] as nn.BCELoss expects.
    """

    def __init__(self):
        super().__init__()
        layers = [nn.Flatten()]  # (N, 1, 28, 28) -> (N, 784)
        in_features = 28 * 28
        for out_features in (512, 256):
            # LeakyReLU keeps a small gradient for negative inputs, avoiding "dead" units.
            layers += [nn.Linear(in_features, out_features), nn.LeakyReLU(0.2)]
            in_features = out_features
        # One real/fake score per image, converted to a probability.
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return an (N, 1) tensor of real-image probabilities for the batch x."""
        probabilities = self.net(x)
        return probabilities


Model Initialization

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
noise_dim = 100   # Size of the noise vector input to Generator
seed = 42         # Fixed RNG seed so a Restart & Run All reproduces the same run

# torch.manual_seed seeds the RNG on all devices (CPU and CUDA), so weight
# initialization below and later draws from torch's global RNG (e.g. noise z) are
# repeatable. Some CUDA kernels can still be nondeterministic.
torch.manual_seed(seed)

G = Generator(noise_dim).to(device)
D = Discriminator().to(device)


In [None]:
# Binary cross-entropy on probabilities: D ends in a Sigmoid, so it outputs values in [0, 1].
criterion = nn.BCELoss()

lr = 0.0002
# beta1 = 0.5 (instead of Adam's default 0.9) is the setting recommended by
# Radford et al. (DCGAN) for stable adversarial training. beta2 keeps its default, 0.999.
adam_betas = (0.5, 0.999)

optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=adam_betas)
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=adam_betas)


In [None]:
epochs = 500        # Number of full passes over the MNIST training set
sample_every = 100  # Show a grid of generated digits every N epochs

# Loop over each epoch and batch.
# `_` discards the class labels: the GAN is trained without them.

for epoch in range(epochs):
    loss_D_sum, loss_G_sum, n_batches = 0.0, 0.0, 0

    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # BCE targets: 1 = real, 0 = fake (allocated directly on the target device).
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        optimizer_D.zero_grad()
        outputs_real = D(real_images)
        loss_real = criterion(outputs_real, real_labels)

        z = torch.randn(batch_size, noise_dim, device=device)
        fake_images = G(z)
        outputs_fake = D(fake_images.detach())  # detach: this step must not backprop into G
        loss_fake = criterion(outputs_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator step: G wants D to label its fakes as real (1) ----
        optimizer_G.zero_grad()
        outputs = D(fake_images)
        loss_G = criterion(outputs, real_labels)
        # This backward also writes grads into D's parameters. They are never stepped
        # here and are cleared by optimizer_D.zero_grad() at the start of the next batch.
        loss_G.backward()
        optimizer_G.step()

        loss_D_sum += loss_D.item()
        loss_G_sum += loss_G.item()
        n_batches += 1

    # Report once per epoch (mean over the epoch's batches), not once per batch.
    n = max(n_batches, 1)
    print(f"Epoch [{epoch+1}/{epochs}], Loss D: {loss_D_sum / n:.4f}, Loss G: {loss_G_sum / n:.4f}")

    # Sample once per `sample_every` epochs. This sits outside the batch loop,
    # so exactly one figure is drawn per sampled epoch.
    if (epoch + 1) % sample_every == 0:
        with torch.no_grad():  # inference only: no autograd graph
            z = torch.randn(64, noise_dim, device=device)
            grid = G(z).cpu().numpy()  # (64, 1, 28, 28)

        # 8x8 grid of samples, filled row by row
        fig, axs = plt.subplots(8, 8, figsize=(8, 8))
        for ax, img in zip(axs.flat, grid):
            ax.imshow(img[0], cmap='gray')
            ax.axis('off')
        fig.suptitle(f"Generated samples after epoch {epoch+1}")
        plt.show()








In [None]:
# Save the trained Generator's parameters.
# Only the state_dict (weights and biases) is written to generator.pth, not the
# module object itself, so loading requires re-creating the Generator class first.
torch.save(obj=G.state_dict(), f='generator.pth')

In [None]:
# Load the trained Generator: rebuild the Generator architecture, then load the weights.

# NOTE(review): this re-defines the Generator from the "Generator Network" cell so the
# load step can run without the training cells. It must stay structurally identical to
# that definition: load_state_dict() matches parameters by key (net.0.weight,
# net.2.weight, ...), and those keys come from the attribute name `net` and each
# layer's position in the Sequential. Consider keeping a single definition (e.g. in a
# shared .py module) so the two copies cannot drift apart.
class Generator(nn.Module):
    """MLP generator: noise vector of size `noise_dim` -> image of shape (N, 1, 28, 28)."""
    def __init__(self, noise_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),  # net.0
            nn.ReLU(True),
            nn.Linear(256, 512),        # net.2
            nn.ReLU(True),
            nn.Linear(512, 1024),       # net.4
            nn.ReLU(True),
            nn.Linear(1024, 28*28),     # net.6: one flattened 28x28 image
            nn.Tanh()                   # output values in [-1, 1]
        )

    def forward(self, z):
        out = self.net(z)
        return out.view(-1, 1, 28, 28)  # reshape flat 784-vectors to 1x28x28 images

# Initialize the Generator on the CPU and load the saved weights into it.
noise_dim = 100  # must match the noise_dim used when the checkpoint was trained
G = Generator(noise_dim)

# map_location='cpu': the checkpoint may have been saved from a CUDA model. Without
# this, torch.load tries to restore CUDA tensors and fails on a machine with no GPU.
# torch.load unpickles the file, so only load checkpoints from trusted sources
# (newer PyTorch versions also accept weights_only=True to restrict this).
state_dict = torch.load('generator.pth', map_location='cpu')
G.load_state_dict(state_dict)
G.eval()  # Set to evaluation mode



In [None]:
# Generate a single fake image.
# (torch and matplotlib.pyplot are already imported in the imports cell at the top.)

# Draw the noise on whatever device G's weights live on. This lets the cell work
# both with the CPU model from the load cell and with the trained model still on the GPU.
gen_device = next(G.parameters()).device
z = torch.randn(1, noise_dim, device=gen_device)  # One image

with torch.no_grad():
    fake_image = G(z)  # (1, 1, 28, 28)

# Drop the batch & channel dims and move to host memory -> (28, 28) array in [-1, 1].
image_np = fake_image.squeeze().cpu().numpy()

plt.imshow(image_np, cmap='gray')
plt.title("Generated Fake Image")
plt.axis('off')
plt.show()
