<a href="https://colab.research.google.com/github/goyalpramod/paper_implementations/blob/main/GANs_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

A from-scratch PyTorch implementation of the paper [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661).

This work has been heavily influenced by the following implementations:

* [bamos/dcgan-completion.tensorflow](https://github.com/bamos/dcgan-completion.tensorflow)
* [soumith/dcgan.torch](https://github.com/soumith/dcgan.torch/blob/master/main.lua) [The original Torch (Lua) implementation, not PyTorch; you can find Alec Radford's Theano version on his GitHub]
* [openai/improved-gan](https://github.com/openai/improved-gan)
* [lilianweng/unified-gan-tensorflow](https://github.com/lilianweng/unified-gan-tensorflow) [I picked this one to follow, for no particular reason]
* [carpedm20/DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) [Most of this implementation builds on it]

Also check out the following blogs for more clarity on the topic:
* [Brandon Amos's blog](https://bamos.github.io/2016/08/09/deep-completion/)
* [Lilian Weng's blog, Lil'Log](https://lilianweng.github.io/posts/2017-08-20-gan/)

# Let's understand some equations from the paper

$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]$

$\min_G \max_D V(D,G)$

This notation means:

* The discriminator D tries to maximize V(D,G)
* The generator G tries to minimize V(D,G)

This creates an adversarial, two-player minimax game between G and D.
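Both expectations in $V(D,G)$ are estimated in practice with sample means over a batch. A minimal sketch, assuming a toy 1-D setup that is not from the paper: real data drawn from $\mathcal{N}(3, 1)$, noise $z \sim \mathcal{N}(0, 1)$, and hypothetical stand-ins `toy_G` (affine) and `toy_D` (logistic) for the generator and discriminator.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup (not from the paper): 1-D real data ~ N(3, 1), noise z ~ N(0, 1)
def toy_G(z):
    return 0.5 * z + 1.0  # hypothetical affine generator

def toy_D(x):
    return torch.sigmoid(2.0 * (x - 2.0))  # hypothetical D: probability that x is real

def value_estimate(x_real, z, eps=1e-7):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]."""
    d_real = toy_D(x_real).clamp(eps, 1 - eps)  # clamp avoids log(0)
    d_fake = toy_D(toy_G(z)).clamp(eps, 1 - eps)
    return torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

x_real = 3.0 + torch.randn(10_000)  # samples from p_data
z = torch.randn(10_000)             # samples from p_z
v = value_estimate(x_real, z)
print(f"V(D, G) estimate: {v.item():.4f}")  # both log terms are <= 0, so V <= 0
```

D is updated to push this estimate up (toward 0), while G is updated to push it down.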

$\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)]$

This represents:
* Expected value ($\mathbb{E}$) over samples $x$ drawn from the real data distribution $p_{data}(x)$
* $D(x)$ outputs a probability between 0 and 1 (how "real" D thinks $x$ is)
* $\log D(x)$ rewards D for correctly identifying real samples
* Since $\log D(x) \le 0$, this term reaches its maximum of 0 as $D(x) \to 1$ on real data

$\mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]$

This represents:

* Expected value over random noise $z$ drawn from the noise distribution $p_z(z)$
* $G(z)$ generates fake samples from noise
* $D(G(z))$ is the discriminator's output on generated samples
* $\log(1-D(G(z)))$ rewards D for correctly identifying fake samples
* Since $\log(1-D(G(z))) \le 0$, this term reaches its maximum of 0 as $D(G(z)) \to 0$ on fake data
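These two terms map directly onto binary cross-entropy. `nn.BCELoss` computes $-\text{mean}\left(y \log p + (1-y)\log(1-p)\right)$, so BCE with label 1 on real outputs equals $-\mathbb{E}[\log D(x)]$ and BCE with label 0 on fake outputs equals $-\mathbb{E}[\log(1-D(G(z)))]$. Minimizing their sum, which is what `d_loss` in `train_step` does, is the same as maximizing the batch estimate of $V$. The sketch below checks this with random stand-in probabilities (an assumption; no trained model is involved).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.BCELoss()  # default reduction='mean'

# Stand-in discriminator outputs in (0.01, 0.99), not produced by a real model
d_real = torch.rand(64, 1) * 0.98 + 0.01  # plays the role of D(x)
d_fake = torch.rand(64, 1) * 0.98 + 0.01  # plays the role of D(G(z))

# Batch estimate of V(D, G)
v_hat = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# Discriminator loss built the same way as in train_step: real labels = 1, fake labels = 0
d_loss = criterion(d_real, torch.ones_like(d_real)) + criterion(d_fake, torch.zeros_like(d_fake))

print(f"d_loss = {d_loss.item():.6f}, -V estimate = {-v_hat.item():.6f}")
```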

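For a fixed generator G, the paper (Proposition 1) shows that the optimal discriminator is

$D^*_G(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$

At the global optimum of the game $p_g = p_{data}$, so $D^*_G(x) = \frac{1}{2}$ everywhere.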
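Why the optimal discriminator has this form: for a fixed $x$, the integrand of $V$ is $a \log D + b \log(1-D)$ with $a = p_{data}(x)$ and $b = p_g(x)$, and setting its derivative $a/D - b/(1-D)$ to zero gives $D = a/(a+b)$. A quick numerical check, assuming two arbitrary 1-D Gaussians as stand-ins for $p_{data}$ and $p_g$ (an illustrative choice, not from the paper):

```python
import math

def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def integrand(d, a, b):
    # Pointwise term of V for a fixed x: p_data(x) log D(x) + p_g(x) log(1 - D(x))
    return a * math.log(d) + b * math.log(1 - d)

x = 0.7
a = normal_pdf(x, mu=0.0, sigma=1.0)  # assumed p_data(x)
b = normal_pdf(x, mu=1.5, sigma=0.8)  # assumed p_g(x)

d_star = a / (a + b)  # closed-form optimal discriminator

# Brute-force search over a fine grid of D values in (0, 1)
grid = [i / 10_000 for i in range(1, 10_000)]
d_best = max(grid, key=lambda d: integrand(d, a, b))

print(f"closed form D* = {d_star:.4f}, grid search = {d_best:.4f}")
```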

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
# Append '?' to any function or class you don't understand and run the cell to see its signature and docstring
# nn.ConvTranspose2d?
# What happens with 'nn.ConvTranspose2d??'? Run it and find out!

In [3]:
class Generator(nn.Module):
    """
    Input: (batch_size, latent_dim)  # typically latent_dim = 100
    Output: (batch_size, channels, height, width)  # e.g., (batch_size, 3, 64, 64)

    Architecture Hints:
    1. Start with a Linear (dense) layer that projects the latent vector, then reshape it
    2. Use multiple ConvTranspose2d layers to upscale
    3. Each intermediate block should follow the pattern: ConvTranspose2d -> BatchNorm2d -> ReLU
    4. The final layer should use Tanh activation (outputs in [-1, 1])
    """
    def __init__(self, latent_dim):
        super(Generator, self).__init__()

        # Step 1: Dense layer to reshape
        # latent_dim -> (4*4*512) features
        # Hint: Use Linear layer here

        # Step 2: Reshape to (batch_size, 512, 4, 4)

        # Step 3: Multiple ConvTranspose2d blocks
        # Block 1: (512, 4, 4) -> (256, 8, 8)
        # kernel_size=4, stride=2, padding=1

        # Block 2: (256, 8, 8) -> (128, 16, 16)
        # Same parameters as above

        # Block 3: (128, 16, 16) -> (64, 32, 32)

        # Final Block: (64, 32, 32) -> (3, 64, 64)
        # Don't forget Tanh() at the end!

In [4]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()

        # Dense layer
        self.linear = nn.Linear(latent_dim, 4*4*512)

        # Blocks
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        # Changed output channels from 3 to 1 for MNIST (grayscale)
        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        # x shape: (batch_size, latent_dim)
        x = self.linear(x)
        x = x.view(-1, 512, 4, 4)  # reshape using view
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x
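Each ConvTranspose2d block doubles the spatial size. With `dilation=1` and `output_padding=0`, PyTorch's documented output size is $H_{out} = (H_{in} - 1)\cdot\text{stride} - 2\cdot\text{padding} + \text{kernel\_size}$, so `kernel_size=4, stride=2, padding=1` gives $2 H_{in}$. The sketch below checks that arithmetic against real layers, using the same channel progression as the Generator above; the helper `conv_transpose_out` is a hypothetical name.

```python
import torch
import torch.nn as nn

def conv_transpose_out(h_in, kernel_size=4, stride=2, padding=1):
    # Output size formula for ConvTranspose2d with dilation=1, output_padding=0
    return (h_in - 1) * stride - 2 * padding + kernel_size

channels = [512, 256, 128, 64, 1]  # same channel progression as the Generator above
x = torch.randn(2, 512, 4, 4)      # batch of 2, as it looks after the Linear + view step
sizes = [x.shape[-1]]
for c_in, c_out in zip(channels[:-1], channels[1:]):
    layer = nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
    expected = conv_transpose_out(x.shape[-1])
    x = layer(x)
    assert x.shape[-1] == expected
    sizes.append(x.shape[-1])

print(sizes)           # [4, 8, 16, 32, 64]
print(tuple(x.shape))  # (2, 1, 64, 64)
```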

In [5]:
class Discriminator(nn.Module):
    """
    Input: (batch_size, channels, height, width)  # e.g., (batch_size, 3, 64, 64)
    Output: (batch_size, 1)  # probability between 0-1

    Architecture Hints:
    1. Use Conv2d layers to downsample
    2. Each layer: Conv2d -> BatchNorm2d (except first) -> LeakyReLU
    3. Final layer should be Dense with Sigmoid
    """
    def __init__(self):
        super(Discriminator, self).__init__()

        # Block 1: (3, 64, 64) -> (64, 32, 32)
        # Conv2d: kernel_size=4, stride=2, padding=1
        # Don't use BatchNorm in first layer!

        # Block 2: (64, 32, 32) -> (128, 16, 16)
        # Remember BatchNorm here

        # Block 3: (128, 16, 16) -> (256, 8, 8)

        # Block 4: (256, 8, 8) -> (512, 4, 4)

        # Final: Flatten and Dense to single output
        # Don't forget Sigmoid()!

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Changed input channels from 3 to 1 for MNIST
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.output = nn.Sequential(
            nn.Linear(512*4*4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = x.view(-1, 512*4*4)  # flatten
        x = self.output(x)
        return x
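The Discriminator mirrors this. For Conv2d (dilation=1) the output size is $\lfloor (H_{in} + 2\cdot\text{padding} - \text{kernel\_size}) / \text{stride} \rfloor + 1$, so each block halves the size: 64 → 32 → 16 → 8 → 4. That is why the final Linear layer expects `512*4*4` inputs, and why the MNIST images are resized to 64×64 later on. The sketch traces both a 64×64 and a raw 28×28 input; `conv_out` and `trace` are hypothetical helper names.

```python
import torch
import torch.nn as nn

def conv_out(h_in, kernel_size=4, stride=2, padding=1):
    # Output size formula for Conv2d with dilation=1
    return (h_in + 2 * padding - kernel_size) // stride + 1

channels = [1, 64, 128, 256, 512]  # same channel progression as the Discriminator above

def trace(h):
    """Return the spatial size after each block and the flattened feature count."""
    x = torch.randn(2, 1, h, h)
    sizes = [h]
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        x = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)(x)
        assert x.shape[-1] == conv_out(sizes[-1])
        sizes.append(x.shape[-1])
    return sizes, x.flatten(start_dim=1).shape[1]

print(trace(64))  # ([64, 32, 16, 8, 4], 8192): matches Linear(512*4*4, 1)
print(trace(28))  # ([28, 14, 7, 3, 1], 512): raw MNIST would not fit Linear(512*4*4, 1)
```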

In [23]:
def train_step(real_images, generator, discriminator, criterion,
               optimizer_G, optimizer_D, latent_dim, device):
    batch_size = real_images.size(0)

    # 1. Train Discriminator
    optimizer_D.zero_grad()

    # 1a. Train on real batch
    # Hint: Create labels for real images (all ones)
    labels_real = torch.ones(batch_size, 1).to(device)
    # Hint: Forward real images through D
    d_output_real = discriminator(real_images)
    # Hint: Calculate loss for real images
    loss_real = criterion(d_output_real, labels_real)

    # 1b. Train on fake batch
    # Hint: Generate noise z
    z = torch.randn(batch_size, latent_dim).to(device)
    # Hint: Generate fake images using G(z)
    g_output = generator(z)
    # Hint: Create labels for fake images (all zeros)
    labels_fake = torch.zeros(batch_size,1).to(device)
    # Hint: Forward fake images through D
    d_output_fake = discriminator(g_output.detach())
    # Why: Get discriminator's prediction on fake images
    # Note: detach() stops gradients from flowing back into the generator during D's update
    # Doc: tensor.detach() returns a new tensor that shares storage but is cut off from the autograd graph
    # Hint: Calculate loss for fake images
    loss_fake = criterion(d_output_fake,labels_fake)

    # 1c. Add both losses and update D
    # Hint: d_loss = loss_real + loss_fake
    d_loss = loss_real + loss_fake
    # Hint: backward and step
    d_loss.backward()
    optimizer_D.step()


    # 2. Train Generator
    # 2a. Clear generator gradients
    optimizer_G.zero_grad()

    # 2b. Forward the same z through G again
    # Note: G has not been updated yet, so this gives essentially the same images as g_output;
    # reusing g_output directly (without detach()) would also work
    fake_images = generator(z)

    # 2c. Get discriminator's prediction on fake images
    # Note: This time we DON'T use detach() because we want to train G
    d_output_fake = discriminator(fake_images)

    # 2d. Calculate generator loss
    # Important: We use REAL labels (ones) here, not zeros!
    # Why? We want the generator to trick discriminator into thinking fakes are real
    g_loss = criterion(d_output_fake, labels_real)  # try to make D think these are real!

    # 2e. Backpropagate and update G only
    # (backward() also fills D's .grad, but optimizer_D.zero_grad() clears it next step)
    g_loss.backward()
    optimizer_G.step()

    return d_loss.item(), g_loss.item()
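Using `labels_real` for the generator loss is the non-saturating variant that the paper itself suggests: early in training D rejects fakes confidently, so $\log(1 - D(G(z)))$ gives G very little gradient, and the paper proposes maximizing $\log D(G(z))$ instead. With BCE and label 1, the generator loss is exactly $-\log D(G(z))$. The sketch compares the gradient of each objective with respect to the discriminator's pre-sigmoid logit, assuming D ends in a sigmoid as in the Discriminator above; the logit values are arbitrary illustrations.

```python
import torch

# Discriminator logits on fake samples; very negative means D is confident they are fake
logits = torch.tensor([-6.0, -3.0, 0.0, 3.0], requires_grad=True)
d_fake = torch.sigmoid(logits)  # D(G(z))

# Original minimax generator objective: minimize log(1 - D(G(z)))
saturating = torch.log(1 - d_fake).sum()
grad_sat, = torch.autograd.grad(saturating, logits, retain_graph=True)

# Non-saturating objective used in train_step: minimize -log D(G(z)) (BCE with label 1)
non_saturating = -torch.log(d_fake).sum()
grad_ns, = torch.autograd.grad(non_saturating, logits)

for a, d, gs, gn in zip(logits.tolist(), d_fake.tolist(), grad_sat.tolist(), grad_ns.tolist()):
    print(f"logit={a:5.1f}  D(G(z))={d:.4f}  saturating grad={gs:+.4f}  non-saturating grad={gn:+.4f}")
```

Analytically the saturating gradient is $-D(G(z))$, which vanishes when D confidently rejects a fake, while the non-saturating gradient is $-(1 - D(G(z)))$, which stays close to $-1$ in that regime.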

In [24]:
import torchvision
import torchvision.transforms as transforms

# Define transforms
transform = transforms.Compose([
    transforms.Resize(64),  # Resize images to 64x64
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

# Load MNIST dataset
dataset = torchvision.datasets.MNIST(root='./data',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Create dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                       shuffle=True, num_workers=2)
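`transforms.Normalize((0.5,), (0.5,))` computes $(x - 0.5)/0.5$ per channel, mapping the $[0, 1]$ output of `ToTensor()` onto $[-1, 1]$, the same range as the generator's Tanh. A minimal sketch of that arithmetic in plain torch, plus a hypothetical `denormalize` helper for viewing samples, and the batch count behind the `938` in the training log (60,000 MNIST training images, batch size 64, with the last partial batch kept because `drop_last` defaults to `False`):

```python
import math

import torch

def normalize(x, mean=0.5, std=0.5):
    # Same arithmetic as transforms.Normalize((0.5,), (0.5,)) on a single channel
    return (x - mean) / std

def denormalize(x, mean=0.5, std=0.5):
    # Hypothetical helper: maps [-1, 1] back to [0, 1] for display
    return x * std + mean

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # ToTensor output lies in [0, 1]
scaled = normalize(pixels)
print(scaled.tolist())                # [-1.0, -0.5, 0.0, 1.0]
print(denormalize(scaled).tolist())   # [0.0, 0.25, 0.5, 1.0]

# Batches per epoch: ceil(60000 / 64)
print(math.ceil(60_000 / 64))         # 938
```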

In [25]:
# Hyperparameters
latent_dim = 100
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Initialize optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

criterion = nn.BCELoss()

In [None]:
# Lists to store losses for plotting
G_losses = []
D_losses = []

print("Starting Training...")

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # Move the batch to the same device as the models
        real_images = real_images.to(device)

        # Train models
        d_loss, g_loss = train_step(real_images, generator, discriminator,
                                  criterion, optimizer_G, optimizer_D, latent_dim, device)

        # Save Losses for plotting later
        G_losses.append(g_loss)
        D_losses.append(d_loss)

        # Print progress
        if i % 100 == 0:
            print(f'[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss:.4f}] [G loss: {g_loss:.4f}]')

    # Generate (and later save) sample images every 5 epochs
    if epoch % 5 == 0:
        with torch.no_grad():
            fake = generator(torch.randn(16, latent_dim, device=device))
            # Save the images (we'll implement this next)

Starting Training...
[Epoch 0/100] [Batch 0/938] [D loss: 1.4458] [G loss: 2.0241]
[Epoch 0/100] [Batch 100/938] [D loss: 0.0007] [G loss: 8.0421]
[Epoch 0/100] [Batch 200/938] [D loss: 0.0003] [G loss: 9.0054]
[Epoch 0/100] [Batch 300/938] [D loss: 100.0000] [G loss: 0.0000]
[Epoch 0/100] [Batch 400/938] [D loss: 0.7979] [G loss: 1.9469]
[Epoch 0/100] [Batch 500/938] [D loss: 1.3899] [G loss: 1.8981]
[Epoch 0/100] [Batch 600/938] [D loss: 0.5952] [G loss: 4.0113]
[Epoch 0/100] [Batch 700/938] [D loss: 0.3971] [G loss: 2.7752]
[Epoch 0/100] [Batch 800/938] [D loss: 0.7386] [G loss: 1.3431]
[Epoch 0/100] [Batch 900/938] [D loss: 0.4414] [G loss: 1.9620]
[Epoch 1/100] [Batch 0/938] [D loss: 0.9388] [G loss: 1.0397]
[Epoch 1/100] [Batch 100/938] [D loss: 0.7561] [G loss: 2.5330]
[Epoch 1/100] [Batch 200/938] [D loss: 0.1322] [G loss: 6.2223]
[Epoch 1/100] [Batch 300/938] [D loss: 0.3843] [G loss: 1.3346]
[Epoch 1/100] [Batch 400/938] [D loss: 0.2888] [G loss: 2.3259]
[Epoch 1/100] [Batch 
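One detail in the sampling step of the loop above: `torch.no_grad()` only disables gradient tracking. The generator is still in training mode, so its BatchNorm layers normalize with the statistics of the 16-sample batch and keep updating their running averages. A minimal check on a standalone `nn.BatchNorm2d` (illustrative only, not the Generator itself):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(1)               # running_mean starts at 0
x = torch.randn(16, 1, 8, 8) + 3.0   # batch whose mean (about 3) is far from 0

bn.train()
with torch.no_grad():
    bn(x)
mean_after_train = bn.running_mean.clone()  # moved toward the batch mean despite no_grad

bn.eval()
with torch.no_grad():
    bn(x)
mean_after_eval = bn.running_mean.clone()   # unchanged: eval mode reads running stats, does not update them

print(mean_after_train.item(), mean_after_eval.item())
```

Whether to sample with `generator.eval()` is a judgment call for GANs; the sketch only shows what each mode does to BatchNorm state.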

In [None]:
import matplotlib.pyplot as plt

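def save_generator_output(fake_images, epoch):
    """Save a 4x4 grid of generated samples to fake_images_epoch_{epoch}.png.

    fake_images: tensor of shape (N, 1, H, W) with N >= 16 and values in [-1, 1] (Tanh output).
    """
    plt.figure(figsize=(10, 10))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        # Take the single grayscale channel; fix the color range to Tanh's [-1, 1]
        plt.imshow(fake_images[i][0].cpu().detach().numpy(), cmap='gray', vmin=-1, vmax=1)
        plt.axis('off')
    plt.savefig(f'fake_images_epoch_{epoch}.png')
    plt.close()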