# Training a Generative Adversarial Network on MNIST

The goal of this assignment is to train a GAN (Generative Adversarial Network) model using PyTorch on the MNIST dataset to generate digit images. The MNIST dataset consists of grayscale handwritten digit images with a resolution of 28×28 pixels.

For this task, you are required to complete the following steps:

> 1. Implement the main components of the Generator and Discriminator (in Step 2).
> 2. Define the loss function and optimizers (in Step 3).
> 3. Complete the training loop (in Step 4), including: \
    (a) Labels for real images and fake images when training the discriminator, \
    (b) Calculation of the discriminator’s loss for real images, \
    (c) Calculation of the discriminator’s loss for fake images, \
    (d) Calculation of the generator’s loss.
> 4. Plot the training loss curves (in Step 5).
> 5. Generate and save 25 images using the trained generator (in Step 6).

In [None]:
""" 
Step-1: import packages and load datasets
        For MNIST dataset, we can directly load it from torchvision
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL.Image as Image 

# Hyper-parameters
batch_size = 128
latent_dim = 100   # size of the generator's input noise vector z
lr = 0.0002
epochs = 50
# Use the GPU automatically when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Create directories for saving the outputs
os.makedirs("gan_images", exist_ok=True)
os.makedirs("gan_models", exist_ok=True)

# Data processing: ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5)
# then maps them to [-1, 1]. The generator's output range must match this.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load MNIST dataset (downloaded to ./data on first run)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [None]:
"""
Step 2: implement the generator and discriminator
"""

# Define the Generator Network
class Generator(nn.Module):
    """Map latent noise vectors to 1x28x28 images.

    A multilayer perceptron expands a ``latent_dim``-sized noise vector to
    28 * 28 = 784 values, which ``forward`` reshapes into single-channel
    28x28 images. The final ``Tanh`` keeps every output in [-1, 1], the same
    range as the real images after ``Normalize((0.5,), (0.5,))``.

    Args:
        latent_dim: Length of each input noise vector ``z``.
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()

        ############# Start of your code ###################
        # LeakyReLU (slope 0.2) lets some gradient through for negative
        # pre-activations; Tanh bounds the output to the data range [-1, 1].
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 28 * 28),
            nn.Tanh(),
        )
        ############## End of your code ####################

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


# Define the Discriminator Network
class Discriminator(nn.Module):
    """Score 1x28x28 images with the probability that each one is real.

    Images are flattened to 784-dim vectors and passed through a multilayer
    perceptron that ends in a ``Sigmoid``, so each score lies in [0, 1].
    This is the input range ``nn.BCELoss`` expects.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        ############# Start of your code ###################
        # LeakyReLU avoids dead units so the generator keeps receiving
        # gradient; Sigmoid turns the final logit into a probability.
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        ############## End of your code ####################

    def forward(self, img):
        """Return one realness probability per image.

        Args:
            img: Tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch,) with values in [0, 1].
        """
        # flatten() (unlike view()) also accepts non-contiguous inputs.
        flattened = img.flatten(1)
        validity = self.model(flattened)
        return validity.squeeze(1)


# Pass the configured latent size explicitly so the generator's first layer
# always matches the noise vectors sampled with `latent_dim` during training.
generator = Generator(latent_dim=latent_dim).to(device)
discriminator = Discriminator().to(device)

In [None]:
"""
Step 3: define the loss functions and optimizers for the generator and discriminator
"""

############# Start of your code ###################
# Binary cross-entropy on probabilities: nn.BCELoss expects inputs in [0, 1],
# so the discriminator must end in a Sigmoid.
adv_loss_fun = nn.BCELoss()

# Adam with beta1 = 0.5 (instead of the default 0.9), as recommended for GANs
# in the DCGAN paper; it damps oscillations between the two players.
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

############## End of your code ####################


In [None]:
"""
Step 4: implement entire process of training a GAN
"""

def _set_requires_grad(module, flag):
    """Make every parameter of *module* trainable (True) or frozen (False)."""
    for param in module.parameters():
        param.requires_grad_(flag)


d_losses = []  # discriminator loss, one entry per training iteration
g_losses = []  # generator loss, one entry per training iteration

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        batch = imgs.size(0)  # the last batch of an epoch may be smaller

        ############# Start of your code ###################
        # define labels for real images and fake images
        # Shape (batch,) matches Discriminator.forward, which squeezes dim 1.
        real = torch.ones(batch, device=device)
        fake = torch.zeros(batch, device=device)
        ############## End of your code ####################

        real_imgs = imgs.to(device)

        # Training the discriminator
        ############# Start of your code ###################
        # freeze the parameters in the generator and activate the parameters in the discriminator
        _set_requires_grad(generator, False)
        _set_requires_grad(discriminator, True)
        ############## End of your code ####################

        optimizer_D.zero_grad()

        ############# Start of your code ###################
        # Loss of discriminating real images: they should be scored as real.
        real_loss = adv_loss_fun(discriminator(real_imgs), real)
        ############## End of your code ####################

        # Generating fake images
        z = torch.randn(batch, latent_dim, device=device)
        gen_imgs = generator(z)

        ############# Start of your code ###################
        # Loss of discriminating generated images: they should be scored as
        # fake. detach() guarantees this step never backpropagates into the
        # generator, even if the freeze above is removed.
        fake_loss = adv_loss_fun(discriminator(gen_imgs.detach()), fake)
        ############## End of your code ####################

        # Overall loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Training the generator
        ############# Start of your code ###################
        # freeze the parameters in the discriminator and activate the parameters in the generator
        # Gradients still flow *through* the frozen discriminator to gen_imgs;
        # only the discriminator's own weights stop accumulating gradients.
        _set_requires_grad(discriminator, False)
        _set_requires_grad(generator, True)
        ############## End of your code ####################

        optimizer_G.zero_grad()

        z = torch.randn(batch, latent_dim, device=device)
        gen_imgs = generator(z)  # Generating fake images

        ############# Start of your code ###################
        # Loss of generating images: the generator wins when its fakes are
        # classified as real (the non-saturating GAN objective).
        g_loss = adv_loss_fun(discriminator(gen_imgs), real)
        ############## End of your code ####################

        g_loss.backward()
        optimizer_G.step()

        # Recording the training process
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                  f"[D loss: {d_losses[-1]:.4f}] [G loss: {g_losses[-1]:.4f}]")

    # Save the images generated by the last batch of this epoch (5x5 grid).
    save_image(gen_imgs.detach()[:25], f"gan_images/epoch_{epoch}.png", nrow=5, normalize=True)

    # Save the checkpoint
    torch.save(generator.state_dict(), f"gan_models/generator_epoch_{epoch}.pth")
    torch.save(discriminator.state_dict(), f"gan_models/discriminator_epoch_{epoch}.pth")

# The last generator step left the discriminator frozen; make it trainable
# again so later cells see both networks in their normal state.
_set_requires_grad(discriminator, True)


In [None]:
"""
Step 5: Plot the training loss curve
"""

############# Start of your code ###################
# Plot the training loss curve: one point per training iteration (batch),
# taken from the d_losses / g_losses lists filled during training.
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label="Discriminator loss")
plt.plot(g_losses, label="Generator loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("GAN training loss")
plt.legend()
plt.tight_layout()
plt.savefig("gan_images/loss_curve.png")
plt.show()
############## End of your code ####################

In [None]:
"""
Step 6: Generate images using a trained generator, and save them
"""

############# Start of your code ###################
num_examples = 25
grid_cols = 5  # 25 samples -> a 5x5 grid

# Inference only: eval mode, and no autograd graph is needed.
generator.eval()
with torch.no_grad():
    z = torch.randn(num_examples, latent_dim, device=device)
    samples = generator(z).cpu()

# The generator outputs values in [-1, 1]; save_image expects [0, 1] when
# normalize=False. Map explicitly so every image shares one intensity scale.
images = ((samples + 1) / 2).clamp(0, 1)

os.makedirs("gan_images/final", exist_ok=True)
save_image(images, "gan_images/final_samples.png", nrow=grid_cols)
for idx, image in enumerate(images):
    save_image(image, f"gan_images/final/sample_{idx}.png")

# Show the generated digits in a grid.
grid_rows = (num_examples + grid_cols - 1) // grid_cols
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols, grid_rows))
for ax in np.ravel(axes):
    ax.axis("off")
for ax, image in zip(np.ravel(axes), images):
    ax.imshow(image.squeeze(0).numpy(), cmap="gray", vmin=0, vmax=1)
plt.tight_layout()
plt.show()
############## End of your code ####################
