<div style="line-height:0.5">
<h1 style="color:#BF66F2 "> Generative Adversarial Network in PyTorch 1 </h1>
<h4> Example of GANs using the MNIST dataset with Binary Cross Entropy. </h4>

<span style="display: inline-block;">
    <h3 style="color: lightblue; display: inline;">Keywords:</h3>
    SummaryWriter + nn.BCELoss() + LeakyReLU
</span>
</div>

<h2 style="color:#BF66F2 "> GANs</h2>
<div style="margin-top: -25px;">



=> Train a generative model by framing the problem as a supervised learning problem with two sub-models:      
<div style="margin-top: -20px;">

+ Generator: a model trained to generate new, realistic samples that can fool the discriminator.
+ Discriminator (or Critic): a model that classifies examples as either real (from the domain) or fake (generated), <BR>trained to distinguish between real and fake samples.

The generator and discriminator models are trained in an adversarial manner (a zero-sum game!)
<br> => they are trained to compete against each other.

The generator aims to produce realistic samples that deceive the discriminator, while the discriminator aims to correctly classify real and fake samples. <br> 
As training proceeds, each model improves against the other: the discriminator gets better at telling real from fake, and the generator gets better <br> at producing samples that fool it.

This process continues iteratively until the generator produces samples that are indistinguishable from real data, and the discriminator <br> can no longer tell real from fake samples with high confidence.
In other words, training ideally ends when the discriminator is fooled about <br>  half the time, meaning the generator is producing plausible examples.   
</div>
</div>
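
As a rough numerical illustration of that stopping point (a sketch, not part of the original notebook): if the discriminator outputs 0.5 for every sample, the binary cross-entropy of each prediction is -log(0.5), whatever the label. The helper name `bce_term` is hypothetical and only used here.

```python
import math

def bce_term(p, y):
    """Binary cross-entropy of one prediction p in (0, 1) against a label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# A discriminator that cannot tell real from fake outputs 0.5 for everything.
loss_real = bce_term(0.5, 1)            # real sample, label 1
loss_fake = bce_term(0.5, 0)            # fake sample, label 0
loss_d = (loss_real + loss_fake) / 2    # averaged discriminator loss

print(f"{loss_d:.4f}")  # 0.6931, i.e. ln(2)
```

A discriminator loss hovering near ln(2) is therefore consistent with a discriminator that is guessing, although on its own it does not prove that the generated samples are good.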


<h3 style="color:#BF66F2 "> The GAN approach:</h3>
<div style="margin-top: -25px;">

Do not look for an explicit density model describing the manifold of natural images...<br>
Just find a model able to generate samples that «look like» the training samples.
<div style="margin-top: -18px;">

1. Sample a seed from a known distribution. This distribution is defined a priori, and the seed is also referred to as noise.    
2. Feed this seed to a learned transformation that generates realistic samples.    
3. Use a neural network to learn this transformation. The network is trained in an unsupervised manner: no labels are needed.    

Final GOAL => use the learned model as a regularization PRIOR in other problems, e.g. to perform anomaly detection!
</div>
</div>
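
The recipe above can be sketched in a few lines. This is only an illustration under assumptions: the sizes (64 and 784) and the single-layer `transformation` are placeholders, and the network is untrained, so its outputs have the right shape and range but are not realistic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)                   # only to make the sketch reproducible

z_dim, sample_dim = 64, 784            # seed size and sample size (illustrative values)
z = torch.randn(16, z_dim)             # 16 seeds ("noise") from the known prior N(0, I)

# A learnable transformation from seed space to sample space (untrained here).
transformation = nn.Sequential(nn.Linear(z_dim, sample_dim), nn.Tanh())
samples = transformation(z)

print(samples.shape)  # torch.Size([16, 784])
```

No labels appear anywhere in this pipeline; the only inputs are the random seeds.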

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
class Discriminator(nn.Module):
    """ Discriminator model for GAN.
    
    Attributes:
        Sequential layers of the discriminator [nn.Sequential].
    """
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),   # 128 output features
            nn.LeakyReLU(0.01),            # negative slope of 0.01
            nn.Linear(128, 1),             # 128 in => 1 output feature
            nn.Sigmoid(),
        )

    def forward(self, x):
        """ Forward pass of the discriminator. """
        return self.disc(x)


class Generator(nn.Module):
    """ Generator model for GAN.
    
    Attributes:
        Sequential layers of the generator [nn.Sequential].
    
    Details:
        The final activation is Tanh(): the training images are normalized to [-1, 1], so the outputs are bounded to [-1, 1] as well.
    """
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """ Forward pass of the generator. """
        return self.gen(x)
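
A quick shape and range check can help confirm that the two models fit together. The sketch below is an assumption-laden stand-in: it rebuilds the same layer stacks with `nn.Sequential` so that it runs on its own, and the sizes (64-dimensional noise, 784-dimensional images) are assumed values.

```python
import torch
import torch.nn as nn

z_dim, img_dim = 64, 784  # assumed sizes: 64-d noise, flattened 28x28 images

# Stand-ins mirroring the layer stacks of the two classes above.
disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                     nn.Linear(128, 1), nn.Sigmoid())
gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                    nn.Linear(256, img_dim), nn.Tanh())

with torch.no_grad():
    fake = gen(torch.randn(8, z_dim))   # 8 generated "images", values in [-1, 1]
    score = disc(fake)                  # one probability per image, values in [0, 1]

print(fake.shape, score.shape)  # torch.Size([8, 784]) torch.Size([8, 1])
```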

In [None]:
""" Hyperparameters """
lr = 3e-4                 #0.0003 (3 multiplied by 10 to the power of -4)
z_dim = 64
image_dim = 28 * 28 * 1   #total = 784
batch_size = 32
num_epochs = 50

<h2 style="color:#BF66F2 "> Initialize models and optimizers </h2>

In [None]:
# Create an instance of the Discriminator class with the specified input image dimension.
disc_obj = Discriminator(image_dim).to(device)
# Create an instance of the Generator class with input noise dimension (z_dim) and output image dimension (image_dim).
gen_obj = Generator(z_dim, image_dim).to(device)

# Generate fixed random noise tensor used during training for generating sample images, 
# initialized with random values drawn from a standard normal distribution.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Define the sequence of image transformations applied to the training images:
#   1. convert the input image to a PyTorch tensor with values in [0, 1]
#   2. normalize the pixel values to the range [-1, 1]
# Named mnist_transform so that it does not shadow the torchvision.transforms module.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 280865728.85it/s]
Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 16387404.47it/s]
Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 167386097.65it/s]
Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 15047811.03it/s]
Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw
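
The effect of `Normalize((0.5,), (0.5,))` on the loaded images can be checked by hand. The sketch below applies the same per-channel formula, `(x - mean) / std`, to a few made-up pixel values instead of calling torchvision, so the numbers are illustrative only.

```python
import torch

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])   # values as produced by ToTensor(), in [0, 1]
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std              # same formula Normalize applies per channel

print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```

The range [0, 1] is mapped onto [-1, 1], which matches the output range of the generator's final `Tanh()`.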



In [None]:
# Adam optimizers, one per model
opt_disc = optim.Adam(disc_obj.parameters(), lr=lr)
opt_gen = optim.Adam(gen_obj.parameters(), lr=lr)

In [None]:
# Binary Cross Entropy loss between the target and the input probabilities.
criterion = nn.BCELoss()

The Binary Cross Entropy (BCE) loss measures the difference between the predicted probabilities and the true labels, penalizing the model more when it makes confident but incorrect predictions. <br> 
This encourages the model to produce more accurate probabilities for the true class labels. <br>
The loss function is minimized during training to update the model's parameters and improve its performance. <br>


$$
\text{BCE Loss} = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right]
$$
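
The formula can be checked against `nn.BCELoss()` on a small batch. This is a minimal sketch; the probabilities and labels are made-up values.

```python
import torch
import torch.nn as nn

p = torch.tensor([0.9, 0.2, 0.6, 0.4])   # predicted probabilities p_i (made-up values)
y = torch.tensor([1.0, 0.0, 1.0, 0.0])   # true labels y_i

# Direct transcription of the formula: mean over the N samples of the per-sample terms.
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = nn.BCELoss()(p, y)              # default reduction is the mean

print(torch.allclose(manual, builtin))  # True
```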



The `SummaryWriter` class provides a high-level API to log data (scalars, images, histograms) in a given directory and add summaries and events to it.
The class updates the file contents asynchronously, so a training program can call its methods to add data directly from the training loop without slowing down training.    
The logged data can then be visualized in TensorBoard to gain insights into the model's behavior and performance over time.    

<h3 style="color:#BF66F2 "> Methods: </h3>
<div style="margin-top: -25px;">

- `add_scalar(tag, scalar_value, global_step=None, walltime=None)`: Log scalar data, such as loss or accuracy, with the specified tag.
- `add_image(tag, img_tensor, global_step=None, walltime=None)`: Log images with the specified tag.
- `add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None)`: Log histograms with the specified tag.
- `add_graph(model, input_to_model=None, verbose=False)`: Log the computational graph of a PyTorch model.
- `add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None)`: Log embeddings (e.g., for visualization using the Embedding Projector).
- `add_text(tag, text_string, global_step=None, walltime=None)`: Log text data with the specified tag.
- `add_figure(tag, figure, global_step=None, close=True, walltime=None)`: Log matplotlib figures with the specified tag.
</div>
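
As a minimal sketch of how these methods are typically called, the snippet below logs a placeholder scalar for a few steps. The temporary directory, the tag `loss/example`, and the logged values are assumptions made for illustration, and the `tensorboard` package must be installed.

```python
import os
import tempfile
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()            # throwaway directory, for illustration only
writer = SummaryWriter(log_dir)

for step in range(5):
    placeholder_loss = 1.0 / (step + 1)  # stands in for a real training loss
    writer.add_scalar("loss/example", placeholder_loss, global_step=step)

writer.close()                          # flush pending events to disk
event_files = os.listdir(log_dir)       # event files written by the writer
print(event_files)
```

Running `tensorboard --logdir` with that directory should then show the logged curve.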

In [None]:
""" Writes entries directly to event files in the log_dir to be consumed by TensorBoard. """
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

<h2 style="color:#BF66F2 "> Training </h2>

Discriminator: 

$
\max_D \; \log(D(x)) + \log(1 - D(G(z)))
$

Generator: 

$
\min_G \; \log(1 - D(G(z)))
$

In practice the generator uses a non-saturating variant (a common GAN training trick) that provides a stronger gradient during the early learning stages:

$ \max_G \; \log(D(G(z))) $
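
The difference in gradient strength between the two generator objectives can be checked numerically. The sketch below is an illustration under assumptions: it treats the discriminator's pre-sigmoid output on a fake sample as a single scalar `logit` set to -5, so that D(G(z)) is close to 0, as is typical early in training.

```python
import torch

# Pre-sigmoid discriminator output on a fake sample; very negative means
# D(G(z)) is close to 0, i.e. the discriminator easily rejects the fake.
logit = torch.tensor(-5.0, requires_grad=True)

# Saturating objective: minimize log(1 - D(G(z)))
saturating = torch.log(1 - torch.sigmoid(logit))
(grad_sat,) = torch.autograd.grad(saturating, logit)

# Non-saturating objective: maximize log(D(G(z))), i.e. minimize -log(D(G(z)))
non_saturating = -torch.log(torch.sigmoid(logit))
(grad_ns,) = torch.autograd.grad(non_saturating, logit)

print(f"{grad_sat.item():.4f} {grad_ns.item():.4f}")  # -0.0067 -0.9933
```

With the saturating loss almost no gradient reaches the generator when its samples are easily rejected, while the non-saturating loss gives a gradient of magnitude close to 1 in the same situation.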

In [None]:
""" Training.
N.B.1
The Discriminator is trained by maximizing the log-likelihood of classifying real data as real and fake data as fake.
The Generator is trained by minimizing the log-likelihood of the Discriminator classifying fake data as fake
(maximizing the log-likelihood of the Discriminator classifying fake data as real).

N.B.2
no_grad() => Context manager (thread-local) that disables gradient calculation, to reduce memory consumption.
Useful for inference, when the backward() method is not called.

N.B.3
TensorBoard log uses the global_step value to specify the iteration number at which this data is added to TensorBoard.

"""

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each 28x28 image in the batch into a 784-dimensional vector: (N, 1, 28, 28) -> (N, 784)
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        ######### Train Discriminator
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen_obj(noise)

        ## Discriminator's predictions for real and fake data (flattening with view(-1))
        disc_real = disc_obj(real).view(-1)
        disc_fake = disc_obj(fake).view(-1)   #=> D(G(z))
        ### Discriminator's loss for real, fake data and overall
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

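        # Set gradients of all Discriminator parameters to zero
        disc_obj.zero_grad()
        # Backpropagate the Discriminator's loss; retain_graph=True keeps the graph through
        # the Generator alive so that `fake` can be reused for the Generator update below
        lossD.backward(retain_graph=True)
        # Perform a single optimization step (parameter update)
        opt_disc.step()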

        ################# Train Generator
        # Non-saturating objective: maximize log(D(G(z))) instead of minimizing log(1 - D(G(z))),
        # which does not suffer from saturating gradients

        # Discriminator's predictions for fake data generated by the Generator
        output = disc_obj(fake).view(-1)
        # Generator loss
        lossG = criterion(output, torch.ones_like(output))
        ### Gradients to zero, backpropagate and update Generator's parameters
        gen_obj.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Print losses and log images to TensorBoard once per epoch (first batch only)
        if batch_idx == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                        Loss D: {lossD:.4f}, loss G: {lossG:.4f}")

            ####### Generate images using the fixed noise for visualization
            with torch.no_grad():
                #### Create image grids, the tensor needs to be reshaped to a 4-dimensional tensor with the shape:
                # batch_size: The number of images to display in a grid
                # 1:          The number of channels in the image => the images are grayscale
                # 28:         The height of the image
                # 28:         The width of the image
                fake = gen_obj(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                ## Log generated fake and real images to TensorBoard
                writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)
                writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)
                step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.6054, loss G: 0.6797
Epoch [1/50] Batch 0/1875                       Loss D: 0.2925, loss G: 1.3966
Epoch [2/50] Batch 0/1875                       Loss D: 0.4295, loss G: 1.0767
Epoch [3/50] Batch 0/1875                       Loss D: 0.6089, loss G: 0.9956
Epoch [4/50] Batch 0/1875                       Loss D: 0.6841, loss G: 0.8109
Epoch [5/50] Batch 0/1875                       Loss D: 0.3916, loss G: 1.4802
Epoch [6/50] Batch 0/1875                       Loss D: 0.7030, loss G: 1.2686
Epoch [7/50] Batch 0/1875                       Loss D: 0.7270, loss G: 1.6782
Epoch [8/50] Batch 0/1875                       Loss D: 0.7371, loss G: 0.9127
Epoch [9/50] Batch 0/1875                       Loss D: 0.5482, loss G: 1.1651
Epoch [10/50] Batch 0/1875                       Loss D: 0.4469, loss G: 1.1557
Epoch [11/50] Batch 0/1875                       Loss D: 0.4896, loss G: 1.3717
Epoch [12/50] Batch 0/1875                       L