# MLRS2 - Exercise GAN

Please follow this notebook and fill in the missing parts based on the description in the README.md.

In [None]:
# Install dependencies (only when using Google Colab)
!pip install torch torchvision matplotlib tensorboard torchsummary array2gif jupyter tqdm

In [None]:
import os
import numpy as np
import time
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt

from PIL import Image
from io import BytesIO

from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchsummary import summary
from array2gif import write_gif

%load_ext tensorboard

In [None]:
def show_image_batch(batch, name=""):
    """Display up to the first 64 images of ``batch`` as a normalized grid.

    Args:
        batch: Image tensor to visualise; only the first 64 entries are shown.
            It may live on any device and may still require grad (e.g. raw
            generator output). Presumably shaped (N, C, H, W) as expected by
            ``make_grid`` — TODO confirm for new callers.
        name: Prefix for the figure title ("<name> images").
    """
    # Detach so grad-tracking tensors can be converted to NumPy, and move to the
    # CPU directly instead of relying on a global ``device``.
    images = batch[:64].detach().cpu()
    grid = vutils.make_grid(images, padding=2, normalize=True)

    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title(f"{name} images")
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))

# Function to convert a tensor to an image plot
def tensor_to_plot_image(tensor):
    """Render a 1-D tensor as a line plot and return the rendered plot as pixels.

    Args:
        tensor: Values to plot (e.g. one latent vector). It may live on any
            device and may require grad.

    Returns:
        ``np.ndarray`` with the pixels decoded from the PNG rendering of the
        plot, presumably shaped (H, W, C).
    """
    values = tensor.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.plot(values)
        buf = BytesIO()
        # Save *this* figure explicitly rather than whatever pyplot considers current.
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if saving fails, to avoid leaking figures.
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as image:
        # np.array forces PIL's lazy decode before the image handle is closed.
        return np.array(image)

In [None]:
# Training hyperparameters
n_epochs = 100
batch_size = 64
learning_rate = 0.0002
momentum = 0.5          # passed to Adam as beta1 (first-moment decay), not SGD momentum
z_dim = 100             # length of the latent noise vector fed to the generator
image_size = (28, 28)   # height and width of the MNIST images
n_conv = 64             # number of convolution filters passed to D and G
real_label = 1
fake_label = 0
seed = 42               # fixed seed for weight init, noise sampling and data shuffling

# Seed PyTorch's default generators so runs are repeatable (GPU kernels may still
# introduce some nondeterminism).
torch.manual_seed(seed)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory for the model/optimizer checkpoints written during training
os.makedirs("./ckpts", exist_ok=True)

In [None]:
# Download the MNIST training split and wrap it in a shuffling DataLoader
mnist_dataset = MNIST('./data', train=True, download=True, transform=ToTensor())
# Pinned host memory only speeds up host-to-GPU copies; enabling it on a CPU-only
# machine just triggers a warning, so tie it to the selected device.
train_loader = torch.utils.data.DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=(device.type == 'cuda'),
)

### Building the Discriminator

Now, let's define the **discriminator** model. It acts as a binary classifier to distinguish real MNIST digits from the fake ones generated by the generator. Pay attention to the use of convolutional layers, `LeakyReLU`, and the final sigmoid activation.

In [None]:
class Discriminator(nn.Module):
    """Binary classifier that scores images as real (label 1) or fake (label 0).

    The layers and the forward pass are left as an exercise (see the TODOs);
    until they are filled in, ``forward`` returns its input unchanged.
    """

    def __init__(self, n_conv):
        """Create the discriminator.

        Args:
            n_conv: Number of filters for the convolution layers.
        """
        super().__init__()
        ################## TODO ##################
        # Discriminator layers:                  #
        #    - 3 conv layers with 64 filters, a  #
        #      kernelsize of 3 and a stride of 2 #
        #      and a padding of 1                #
        #    - The last conv layer has 1 output  #
        #      channel                           #
        #    - 1 conv layer with 1 input and 1   #
        #      output channel, a kernelsize of 3 #
        #      and a padding of 0.               #
        ##########################################

    # forward method
    def forward(self, x):
        """Score a batch of images ``x`` (currently returns ``x`` unchanged)."""
        ################## TODO ##################
        # Discriminator forward pass:            #
        #     - Call the layers from the con-    #
        #       structor                         #
        #     - Leaky ReLU activations           #
        #     - After the 4. conv layer: expand  #
        #       dimensions from [N, F] to        #
        #       [N, F, 1, 1]                     #
        #     - Sigmoid output activation        #
        ##########################################
        return x

    def save_checkpoint(self, optimizer, filepath):
        """Write this model's and ``optimizer``'s state dicts to ``filepath``."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(checkpoint, filepath)
        #print(f"Discriminator checkpoint saved to {filepath}")

    def load_checkpoint(self, optimizer, filepath):
        """Restore model and optimizer state written by ``save_checkpoint``.

        The file is mapped onto the CPU first so a checkpoint saved on a GPU can
        also be restored on a CPU-only machine; ``load_state_dict`` then copies
        the values into the existing parameters on their current device.
        """
        checkpoint = torch.load(filepath, map_location="cpu")
        self.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        #print(f"Discriminator checkpoint loaded from {filepath}")


  

### Building the Generator

Next, we'll design the **generator** model. This model transforms a latent vector (sampled from a Gaussian distribution) into a 28×28 grayscale image. We'll use transposed convolutions to upscale the feature maps.

In [None]:
class Generator(nn.Module):
    """Maps latent noise vectors to 28x28 grayscale images.

    The layers and the forward pass are left as an exercise (see the TODOs);
    until they are filled in, ``forward`` returns its input unchanged.
    """

    def __init__(self, z_dim, n_conv):
        """Create the generator.

        Args:
            z_dim: Length of each latent input vector.
            n_conv: Number of convolution filters; stored as ``self.n_conv``,
                presumably for reshaping the fully connected output in ``forward``.
        """
        super().__init__()

        self.n_conv = n_conv
        ################## TODO ##################
        # Generator layers:                      #
        #    - A fully connected layer with z_dim#
        #      input nodes and feature_map_size_x#
        #      * feature_map_size_y *            #
        #      feature_map_size_z * n_conv output#
        #      nodes                             #
        #    - 3 transposed conv layers with 64  #
        #      filters, a kernelsize of 3 and a  #
        #      stride of 2, and a padding of 1.  #
        #    - The last conv layer has 1 output  #
        #      channel                           #
        #    - The number of feature maps is     #
        #      doubled in each subsequent layer  #
        #    - 1 conv layer with 1 input and 1   #
        #      output channel, a kernelsize of 3 #
        #      and a padding of 0.               #
        ##########################################

    def forward(self, x):
        """Generate images from a batch of latent vectors ``x``.

        Currently returns ``x`` unchanged until the TODO below is completed.
        """
        ################## TODO ######################
        # Generator forward pass:                    #
        #     - Use the fully connected layer to     #
        #       get the correct number of neurons    #
        #     - Leaky ReLU activation                #
        #     - Reshape the output of the first      #
        #       fully connected layer to             #
        #       [-1, n_conv * feature_mapsize_x,     #
        #       feature_mapsize_y, feature_mapsize_z]#
        #     - Call the (transposed) conv layers    #
        #       from the constructor                 #
        #     - Leaky ReLU activations in between    #
        #     - Sigmoid output activation            #
        ##############################################

        return x

    def save_checkpoint(self, optimizer, filepath):
        """Write this model's and ``optimizer``'s state dicts to ``filepath``."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(checkpoint, filepath)

        #print(f"Generator checkpoint saved to {filepath}")

    def load_checkpoint(self, optimizer, filepath):
        """Restore model and optimizer state written by ``save_checkpoint``.

        The file is mapped onto the CPU first so a checkpoint saved on a GPU can
        also be restored on a CPU-only machine; ``load_state_dict`` then copies
        the values into the existing parameters on their current device.
        """
        checkpoint = torch.load(filepath, map_location="cpu")
        self.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        #print(f"Generator checkpoint loaded from {filepath}")

In [None]:
# Initialize D and G
D = Discriminator(n_conv=n_conv).to(device)
G = Generator(z_dim=z_dim, n_conv=n_conv).to(device)

# torchsummary builds its dummy input on "cuda" unless told otherwise, so pass the
# active device type to keep this cell working on CPU-only machines.
summary(D, input_size=(1, *image_size), device=device.type)
# G consumes flat latent batches of shape [N, z_dim], like torch.randn(batch_size, z_dim).
summary(G, input_size=(z_dim,), device=device.type)

In [None]:
# Prepare some data samples and create fake images as a quick sanity check.
# z_noise is reused later as the fixed noise for the generator snapshots.
x_real = next(iter(train_loader))[0]
z_noise = torch.randn(batch_size, z_dim, device=device)

# No autograd graph is needed here; without no_grad, x_fake would require grad
# and could not be converted to NumPy for plotting.
with torch.no_grad():
    Dout = D(x_real.to(device))
    x_fake = G(z_noise)

In [None]:
# Display a grid of real MNIST samples with intensities inverted (1 - x)
show_image_batch(name="real", batch=1 - x_real)

In [None]:
# Transform 1D plots of each latent vector into image tensors for display.
# Stacking into one ndarray first avoids torch's slow list-of-ndarrays conversion.
plot_arrays = np.stack([tensor_to_plot_image(z) for z in z_noise])  # presumably (N, H, W, C) uint8
# Keep the RGB channels explicitly (drops the alpha channel of an RGBA rendering)
# and scale pixel values to [0, 1].
plot_images = torch.from_numpy(plot_arrays).float()[..., :3] / 255.
# (N, H, W, C) -> (N, C, H, W), the layout make_grid expects
plot_images = plot_images.permute(0, 3, 1, 2)

In [None]:
# Display the latent vectors (rendered as line plots) and the images G produced from them
show_image_batch(name="z", batch=plot_images)
show_image_batch(name="fake", batch=x_fake)

In [None]:
# Launch TensorBoard inside the notebook, reading event files from ./runs
%reload_ext tensorboard
%tensorboard --logdir ./runs

In [None]:
# Binary cross-entropy for the real/fake decision, plus one Adam optimizer per
# network; `momentum` is used as Adam's beta1.
loss_fn = nn.BCELoss()
adam_betas = (momentum, 0.999)
optimizerD = optim.Adam(D.parameters(), lr=learning_rate, betas=adam_betas)
optimizerG = optim.Adam(G.parameters(), lr=learning_rate, betas=adam_betas)

In [None]:
# TensorBoard writer plus training bookkeeping: img_list collects snapshot grids of
# G's output, total_iters counts processed batches across all epochs.
writer = SummaryWriter()
img_list, total_iters = [], 0

### Adversarial Training Loop

Time to bring both networks together! Implement the minimax training loop, in which the discriminator and generator are trained adversarially. Follow the three-step process: 
1. Train discriminator on real data
2. Train discriminator on fake data
3. Train generator to fool the discriminator

Make sure to monitor losses for both networks during training.

In [None]:
# perf_counter is a monotonic clock intended for measuring elapsed time.
start_time = time.perf_counter()
for epoch in range(n_epochs):
    loop = tqdm(enumerate(train_loader, 0), total=len(train_loader), desc=f"Epoch {epoch+1}/{n_epochs}", leave=False)
    for i, data in loop:
        # The logging code below reads these names, so the TODOs must define them:
        #   loss_D, loss_G  -> scalar loss tensors of D and G
        #   D_x             -> mean D output on the real batch (float, e.g. via .item())
        #   D_G_z1, D_G_z2  -> mean D output on the fake batch in step (2) / step (3)
        ######################### TODO #########################
        # (1) Update the discriminator with real data          #
        #     - Zero D's gradients                             #
        #     - Create real labels (1): create a tensor holding#
        #       the true labels [1, 1, 1, ..., 1]              #
        #     - D's forward pass                               #
        #     - calculate loss using D's output and real labels#
        #     - Backward pass for D for real labels            #
        #     - D_x = mean of D's output on the real batch     #
        ########################################################

        ######################### TODO #########################
        # (2) Update the discriminator with fake data          #
        #     - Generate batch of latent vectors               #
        #     - Generate fake images based on latent vectors   #
        #     - Classify all fake images with D (detach them   #
        #       so this step does not backpropagate into G)    #
        #     - Calculate D's loss on the all-fake batch using #
        #       fake labels ([0, 0, 0, ..., 0])                #
        #     - Backward pass for D for fake labels            #
        #     - D_G_z1 = mean of D's output on the fake batch  #
        #     - loss_D = D's real loss + D's fake loss         #
        #     - D optimizer step                               #
        ########################################################

        ######################### TODO #########################
        # (3) Update the generator with fake data              #
        #     - Zero G's gradients                             #
        #     - Perform another forward pass of all-fake batch #
        #       through D                                      #
        #     - Use real labels [1, 1, 1, ..., 1] for the fake #
        #       batch: G wants D to classify its images as real#
        #     - calculate G's loss (loss_G)                    #
        #     - Backward pass for G's loss                     #
        #     - D_G_z2 = mean of D's output on the fake batch  #
        #     - G optimizer step                               #
        ########################################################


        # Output training stats
        if i % 50 == 0:

            writer.add_scalar("Losses/G_loss", loss_G.item(), total_iters)
            writer.add_scalar("Losses/D_loss", loss_D.item(), total_iters)

        loop.set_postfix({
            'Loss_D': f'{loss_D.item():.4f}',
            'Loss_G': f'{loss_G.item():.4f}',
            'D(x)': f'{D_x:.4f}',
            'D(G(z))': f'{D_G_z1:.4f}/{D_G_z2:.4f}'
        })
        # Check how the generator is doing by saving G's output on the fixed noise z_noise.
        # NOTE(review): one grid is kept in memory every 50 iterations for the whole run;
        # consider a larger interval for long runs.
        if (total_iters % 50 == 0) or ((epoch == n_epochs-1) and (i == len(train_loader)-1)):
            with torch.no_grad():
                x_fake = G(z_noise).detach().cpu()
            img_list.append(vutils.make_grid(x_fake, padding=2, normalize=True))

        total_iters += 1
    D.save_checkpoint(optimizerD, "./ckpts/discriminator.ckpt")
    G.save_checkpoint(optimizerG, "./ckpts/generator.ckpt")
elapsed_time = time.perf_counter() - start_time
print("Training time (sec): ", elapsed_time)
# Make sure all logged scalars reach the event files TensorBoard reads.
writer.flush()

In [None]:
# Restore the most recent model/optimizer state saved by the training loop
D.load_checkpoint(optimizer=optimizerD, filepath="./ckpts/discriminator.ckpt")
G.load_checkpoint(optimizer=optimizerG, filepath="./ckpts/generator.ckpt")

In [None]:
# Sample the restored generator on the same fixed noise and show the result inverted
with torch.no_grad():
    x_fake = G(z_noise).detach().cpu()

show_image_batch(name="fake", batch=1. - x_fake)

In [None]:
# Generate and save a GIF from the snapshots of G's output collected during training.
# torch.stack builds one (frames, C, H, W) tensor explicitly instead of relying on
# NumPy to convert each tensor in the list on its own.
frames = torch.stack(img_list).numpy()
# Invert intensities (like the other plots) and scale the normalized grids to uint8.
img_list_arr = ((1. - frames) * 255).astype(np.uint8)

write_gif(img_list_arr, 'imgs_G.gif', fps=50)

## Summary

Great job completing the exercise notebook! The resulting GIF shows how the generator gradually learns to produce digit images over the course of training: [imgs_G.gif](imgs_G.gif)


In [5]:
import time

# Quick check of the elapsed-time measurement (sleeps for about one second).
# Uses its own names so it does not overwrite the training cell's
# start_time / elapsed_time, and perf_counter, the monotonic interval clock.
demo_start = time.perf_counter()
time.sleep(1)
demo_elapsed = time.perf_counter() - demo_start
print("Training time: ", demo_elapsed)

Training time:  1.0011022090911865
