#### Deep Convolutional GANs

Original [paper](https://arxiv.org/pdf/1511.06434)

In [3]:
%%HTML
<style>
    body {
        --vscode-font-family: "Sherif",;
    }
</style>

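In this notebook we'll train a DCGAN, which is essentially an unsupervised convolutional neural network
(CNN) model.

Both the generator and the discriminator in a DCGAN are built from convolutional layers. The original design avoids fully
connected hidden layers; the implementation in this notebook keeps a single linear layer at the generator input and another at the discriminator output.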

At a high level, in any GAN that is used to generate realistic data, the generator takes random noise
as input and produces an output with the same dimensions as the real data. We call this generated
output fake data. The discriminator, on the other hand, works as a binary classifier. It takes in the
generated fake data and the real data (one at a time) as input and predicts whether the input data is
real or fake.
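
The shapes involved can be illustrated with a deliberately tiny stand-in. This is only a sketch: `toy_gen` and `toy_disc` are hypothetical single-layer models invented for this example, not the DCGAN built later in this notebook, and the random "real" batch is a placeholder for actual images.

```python
import torch
import torch.nn as nn

toy_latent_dim = 64                 # illustrative noise length
toy_img_shape = (1, 64, 64)         # illustrative (channels, height, width)
toy_n_pixels = 1 * 64 * 64

# hypothetical toy models: one linear layer each, just to show inputs and outputs
toy_gen = nn.Sequential(nn.Linear(toy_latent_dim, toy_n_pixels), nn.Tanh())
toy_disc = nn.Sequential(nn.Flatten(), nn.Linear(toy_n_pixels, 1), nn.Sigmoid())

toy_noise = torch.randn(8, toy_latent_dim)                # batch of 8 noise vectors
toy_fake = toy_gen(toy_noise).view(8, *toy_img_shape)     # fake data, same shape as real data
toy_real = torch.rand(8, *toy_img_shape) * 2 - 1          # placeholder "real" batch in [-1, 1]

p_fake = toy_disc(toy_fake)   # one probability of "real" per fake image
p_real = toy_disc(toy_real)   # one probability of "real" per real image
print(toy_fake.shape, p_fake.shape, p_real.shape)
```

The generator output matches the real data shape, and the discriminator returns a single probability per image regardless of where the image came from.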

<img src=../images/GAN-arch.png width=750 style='display:block; margin:auto;'>

The discriminator network is optimized like any binary classifier, that is, using the binary cross-entropy
loss. The discriminator's goal is therefore to correctly classify real images as real and
fake images as fake. The generator has the opposite goal: it tries to make the discriminator classify its fake images as real. The generator loss is expressed as $-\log(D(G(x)))$, where $x$ is the random noise fed into the generator model $G$, $G(x)$ is the fake image produced by the generator, and $D(G(x))$ is the output probability of
the discriminator model $D$, that is, the probability that the image is real.
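
As a quick check of the generator loss, the sketch below compares binary cross-entropy with an all-ones target against $-\log(D(G(x)))$ computed by hand. The three probabilities in `d_of_g` are made-up discriminator outputs chosen for illustration, not results from a trained model.

```python
import torch
import torch.nn as nn

# hypothetical discriminator outputs D(G(x)) for three generated images
d_of_g = torch.tensor([[0.9], [0.5], [0.1]])
real_labels = torch.ones_like(d_of_g)        # the generator wants these judged "real"

bce = nn.BCELoss()(d_of_g, real_labels)      # mean binary cross-entropy against target 1
per_sample = -torch.log(d_of_g)              # -log(D(G(x))) for each image
manual = per_sample.mean()

print(bce.item(), manual.item())             # the two values agree
print(per_sample.squeeze().tolist())         # small loss when D is fooled, large when it is not
```

With a target of 1, the second term of the cross-entropy vanishes, so only $-\log(D(G(x)))$ remains. This is why the training loop can reuse one `BCELoss` object for both networks.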

#### Joint optimization

In practice, these two loss functions are backpropagated alternately. That is, at every iteration of
training, the discriminator is first frozen, and the parameters of the generator network are optimized
by backpropagating the gradients from the generator loss. Then, the tuned generator is frozen while the discriminator is optimized by backpropagating the gradients from the discriminator loss. This is what we call joint optimization. The original GAN paper describes it as equivalent to a two-player **minimax** game.
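
For reference, the minimax game from the original GAN paper can be written as follows. This keeps the notation used above ($x$ for the noise) and introduces $r$ for a real image, $p_{\text{data}}$ for the data distribution and $p_{\text{noise}}$ for the noise distribution. These three symbols are assumptions of this sketch; the paper itself writes $z$ for the noise and $x$ for real data.

$$
\min_G \max_D \; \mathbb{E}_{r \sim p_{\text{data}}}\big[\log D(r)\big] + \mathbb{E}_{x \sim p_{\text{noise}}}\big[\log\big(1 - D(G(x))\big)\big]
$$

The discriminator maximizes this value, which is the same as minimizing binary cross-entropy on real and fake batches. The generator loss $-\log(D(G(x)))$ used in this notebook is the non-saturating alternative suggested in the same paper: instead of minimizing $\log(1 - D(G(x)))$, the generator maximizes $\log D(G(x))$, which gives stronger gradients early in training.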

<img src=../images/generator_arch.png width=900 style='display:block; margin:auto;'>

**Upsampling** in CNNs refers to the process of increasing the spatial resolution of feature maps by inserting additional rows and columns of zeros or by using interpolation methods, such as **bilinear** or
**nearest-neighbor interpolation**. This is commonly used in tasks such as image segmentation, where
the final output needs to have the same spatial dimensions as the input image.
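
A minimal sketch of the two interpolation modes on a 2x2 feature map, using `nn.Upsample`. The tensor values are arbitrary and chosen only to make the result easy to read. When no mode is given, `nn.Upsample` uses nearest-neighbor interpolation.

```python
import torch
import torch.nn as nn

# one 2x2 feature map, shape (batch=1, channels=1, height=2, width=2)
fmap = torch.tensor([[[[1.0, 2.0],
                       [3.0, 4.0]]]])

# nearest-neighbor: every value is repeated in a 2x2 block
nearest = nn.Upsample(scale_factor=2, mode="nearest")(fmap)

# bilinear: new values are interpolated between neighbouring pixels
bilinear = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)(fmap)

print(nearest.shape, bilinear.shape)   # both are (1, 1, 4, 4)
print(nearest[0, 0])
print(bilinear[0, 0])
```

Both modes double the height and width without adding learnable parameters. In the generator, the learnable part comes from the convolution that follows each upsampling step.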

<img src=../images/discriminator_arch.png width=900 style="display:block; margin:auto">

A stride of 2 at every convolutional layer in this architecture halves the spatial dimensions, while the depth (that is, the number of feature maps) keeps growing. This is a classic
CNN-based binary classification architecture, used here to distinguish real images from
generated fake images.
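
The effect of the stride can be checked with a short sketch. It assumes a 64x64 single-channel input and the channel widths 16, 32, 64 and 128 used by the discriminator later in this notebook; the layers here are freshly initialized and serve only to show the shapes.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 64, 64)            # one 64x64 grayscale image
channels = [1, 16, 32, 64, 128]          # depth grows at every layer

shapes = []
for c_in, c_out in zip(channels[:-1], channels[1:]):
    # kernel 3, stride 2, padding 1 halves the height and width
    x = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)(x)
    shapes.append(tuple(x.shape))
    print(tuple(x.shape))
```

After four such layers the 64x64 input is reduced to a 4x4 map with 128 channels, which is why the final linear layer of the discriminator expects `128 * 4 * 4` inputs.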

In the rest of this notebook, we will build, train, and test a DCGAN model using PyTorch in the form of an exercise.
We will use an image dataset to train the model and test how well the generator of the trained DCGAN
model performs when producing fake images.

In [None]:
# !pip install torch==2.2
# !pip install torchvision==0.17
# !pip install matplotlib==3.5.2
!pip install scikit-image==0.19.3

In [None]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable


import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets

In [None]:
# define constants and hyperparameters
num_eps=10
bsize=32
lrate=0.001
# length of the random noise vector, which essentially means that we will draw the 
# random noise from a 64-dimensional latent space as input to the generator model
lat_dimension=64
image_sz=64
chnls=1
logging_intv=200

In [None]:
class GANGenerator(nn.Module):
    def __init__(self):
        super(GANGenerator, self).__init__()
        self.inp_sz = image_sz // 4
        self.lin = nn.Linear(lat_dimension, 128 * self.inp_sz ** 2) # project latent vector to feature map
        self.bn1 = nn.BatchNorm2d(128)
        self.up1 = nn.Upsample(scale_factor=2)
        self.cn1 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128, 0.8) # note: the second positional argument is eps, not momentum
        self.rl1 = nn.LeakyReLU(0.2, inplace=True)
        self.up2 = nn.Upsample(scale_factor=2)
        self.cn2 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64, 0.8)
        self.rl2 = nn.LeakyReLU(0.2, inplace=True)
        self.cn3 = nn.Conv2d(64, chnls, 3, stride=1, padding=1)
        self.act = nn.Tanh() # [-1, 1] range for pixel values

    def forward(self, x):
        x = self.lin(x)
        x = x.view(x.shape[0], 128, self.inp_sz, self.inp_sz)
        x = self.bn1(x)
        x = self.up1(x)
        x = self.cn1(x)
        x = self.bn2(x)
        x = self.rl1(x)
        x = self.up2(x)
        x = self.cn2(x)
        x = self.bn3(x)
        x = self.rl2(x)
        x = self.cn3(x)
        out = self.act(x)
        return out

In [None]:
class GANDiscriminator(nn.Module):
    def __init__(self):
        super(GANDiscriminator, self).__init__()

        def disc_module(ip_chnls, op_chnls, bnorm=True):
            mod = [nn.Conv2d(ip_chnls, op_chnls, 3, 2, 1), 
                   nn.LeakyReLU(0.2, inplace=True), 
                   nn.Dropout2d(0.25)]
            if bnorm:
                mod += [nn.BatchNorm2d(op_chnls, 0.8)] # note: 0.8 is passed as eps, not momentum
            return mod

        self.disc_model = nn.Sequential(
            *disc_module(chnls, 16, bnorm=False),
            *disc_module(16, 32),
            *disc_module(32, 64),
            *disc_module(64, 128),
        )

        # width and height of the down-sized image
        ds_size = image_sz // 2 ** 4   # 4 conv layer with down-sampling
        self.adverse_lyr = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        

    def forward(self, x):
        x = self.disc_model(x)
        x = x.view(x.shape[0], -1) # flatten the output of conv layers
        out = self.adverse_lyr(x)
        return out

In [None]:
# instantiate the discriminator and generator models
gen = GANGenerator()
disc = GANDiscriminator()
# define the loss function
adv_loss_func = torch.nn.BCELoss()

#### Loading image dataset

In [None]:
dloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        root='./data/MNIST/',
        train=True,
        transform=transforms.Compose([
            transforms.Resize((image_sz, image_sz)), # size must be a single (height, width) tuple
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]),
        download=True
    ),
    batch_size=bsize,
    shuffle=True
)

# define the optimizers for both G and D
opt_gen = torch.optim.Adam(gen.parameters(), lr=lrate)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lrate)

#### Training loop

In [None]:
os.makedirs("./images_mnist", exist_ok=True)

for ep in range(num_eps):
    for idx, (images, _) in enumerate(dloader):

        # ground-truth labels: 1.0 for real images, 0.0 for fake images
        good_img = torch.ones(images.shape[0], 1)
        bad_img = torch.zeros(images.shape[0], 1)

        # get a batch of real images
        actual_images = images.float()

        # train the generator model
        opt_gen.zero_grad()

        # generate a batch of images from standard normal noise
        noise = torch.randn(images.shape[0], lat_dimension)
        gen_images = gen(noise)

        # generator model optimization - how well can it fool the discriminator
        generator_loss = adv_loss_func(disc(gen_images), good_img)
        generator_loss.backward()
        opt_gen.step()

        # train the discriminator model
        opt_disc.zero_grad()

        # calculate discriminator loss as average of mistakes(losses) in confusing real images as fake and vice versa
        actual_image_loss = adv_loss_func(disc(actual_images), good_img)
        fake_image_loss = adv_loss_func(disc(gen_images.detach()), bad_img)
        discriminator_loss = (actual_image_loss + fake_image_loss) / 2

        # discriminator model optimization
        discriminator_loss.backward()
        opt_disc.step()

        batches_completed = ep * len(dloader) + idx
        if batches_completed % logging_intv == 0:
            print(f"epoch number {ep} | batch number {idx} | generator loss = {generator_loss.item()} | discriminator loss = {discriminator_loss.item()}")
            save_image(gen_images.data[:25], f"images_mnist/{batches_completed}.png", nrow=5, normalize=True)