# Generative Adversarial Networks (GANs)

Author: https://www.github.com/deburky

Generative Adversarial Networks (GANs) are a class of artificial intelligence algorithms used in unsupervised machine learning, implemented by a system of two neural networks contesting with each other in a zero-sum game framework. They were introduced by Ian Goodfellow et al. in 2014.

* Consists of two neural networks: a generator and a discriminator.
* The generator generates new data instances, while the discriminator evaluates them for authenticity.
* The generator is trained to fool the discriminator, while the discriminator is trained to distinguish real data from generated (fake) data.

Reference implementations
---
PyTorch GAN: [PyTorch GANs](https://github.com/eriklindernoren/PyTorch-GAN)

PyTorch Lightning GAN: [PyTorch Lightning Basic GAN Tutorial](https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html)

## Digits Dataset

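The MNIST database (Modified National Institute of Standards and Technology database) is a large database of 28×28 grayscale handwritten digits that is commonly used for training various image processing systems. To keep training fast, this notebook uses a random 10% subset of the 60,000-image training split.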

In [None]:
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset

import matplotlib.pyplot as plt
import numpy as np

# Preprocessing: convert each PIL image to a tensor in [0, 1], then map it to
# [-1, 1] so real images share the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(28),  # no-op for 28x28 MNIST; keeps the pipeline reusable
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Dataset / loader settings
batch_size = 16
num_workers = 4
# download=True only fetches MNIST when it is missing under ./mnist_data;
# with download=False a fresh checkout fails with "Dataset not found".
mnist_dataset = datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)

def imshow(img):
    """Display one (C, H, W) image tensor that was normalized to [-1, 1].

    The tensor is mapped back to [0, 1], copied to host memory and transposed
    to the (H, W, C) layout matplotlib expects.
    """
    img = img / 2 + 0.5  # undo Normalize([0.5], [0.5])
    # detach/cpu so tensors on MPS/CUDA or tensors tracking gradients also work
    np_img = img.detach().cpu().numpy()
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.axis('off')
    plt.show()

# Train on a random 10% subset of MNIST to keep the demo fast.
subset_size = int(0.1 * len(mnist_dataset))
# Sample WITHOUT replacement: replace=True puts duplicate images in the subset,
# so it covers fewer distinct digits than subset_size suggests.
indices = np.random.choice(len(mnist_dataset), subset_size, replace=False)
subset = Subset(mnist_dataset, indices)

# Create a DataLoader for the subset
data_loader = DataLoader(
    subset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # persistent_workers requires worker processes; DataLoader raises
    # ValueError if it is True while num_workers == 0.
    persistent_workers=num_workers > 0,
)

# Preview the first four images of one batch.
data_iter = iter(data_loader)
images, labels = next(data_iter)
imshow(torchvision.utils.make_grid(images[:4]))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import os
import gc

torch.manual_seed(0)
torch.set_num_threads(2)

# Prefer the Apple-silicon MPS backend, then CUDA, and fall back to CPU so the
# notebook also runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    torch.mps.manual_seed(0)
    # Cap this process at half of the recommended MPS memory budget.
    torch.mps.set_per_process_memory_fraction(0.5)
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Create a directory for saving images
os.makedirs("gan_training_images", exist_ok=True)

# Weight initialization function
def weights_init(m):
    """Re-initialize one module in place; intended for ``model.apply(weights_init)``.

    - Linear / Conv2d: weight ~ N(0, 0.02); biases are left untouched.
    - BatchNorm1d / BatchNorm2d (affine only): weight ~ N(1, 0.02), bias = 0.
    Other module types are ignored.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # BatchNorm2d pairs with the Conv2d branch above; affine=False layers
        # have no weight/bias tensors to initialize.
        if m.affine:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)

# Hyperparameters shared by the models and the training loop.
latent_dim: int = 128           # length of the noise vector z fed to the generator
img_shape: tuple = (1, 28, 28)  # (channels, height, width) of a single digit
num_epochs: int = 100           # passes over the training subset
lr: float = 0.001               # Adam learning rate used for both networks

# Generator Model
class Generator(nn.Module):
    """Fully-connected generator mapping a latent vector to an image.

    Architecture: latent_dim -> 128 -> 256 (BN) -> 512 (BN) -> prod(img_shape),
    with ReLU activations and a final Tanh, so pixel values lie in [-1, 1].

    Args:
        latent_dim: length of each input noise vector.
        img_shape: shape of one output image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        n_pixels = 1
        for size in img_shape:
            n_pixels *= int(size)

        layers = [nn.Linear(latent_dim, 128), nn.ReLU(inplace=True)]
        for fan_in, fan_out in ((128, 256), (256, 512)):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(inplace=True),
            ]
        layers += [nn.Linear(512, n_pixels), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map a (batch, latent_dim) tensor to (batch, *img_shape) images."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)

# Discriminator Model
class Discriminator(nn.Module):
    """Fully-connected discriminator returning the probability an image is real.

    Architecture: prod(img_shape) -> 512 -> 256 -> 1 with LeakyReLU(0.01)
    between layers and a final Sigmoid.

    Args:
        img_shape: shape of one input image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, img_shape):
        super().__init__()

        in_features = 1
        for size in img_shape:
            in_features *= int(size)

        widths = (in_features, 512, 256)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each image in the batch and return a (batch, 1) probability."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))

# Initialize models on the selected device
generator = Generator(latent_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# The custom N(0, 0.02) initialization is available but disabled: with the
# flag False both networks keep PyTorch's default layer initialization.
use_custom_weight_init = False
if use_custom_weight_init:
    generator.apply(weights_init)
    discriminator.apply(weights_init)

# BCE on the discriminator's sigmoid output; Adam with beta1=0.5 for both nets.
adversarial_loss = nn.BCELoss().to(device)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Latent vectors sampled ONCE, so every saved snapshot (and the GIF built from
# them) shows how the same inputs evolve over training.
fixed_noise = torch.randn(64, latent_dim, device=device)

# Training loop: each step updates the generator first, then the discriminator.
for epoch in range(num_epochs):
    for imgs, _ in data_loader:
        imgs = imgs.to(device)
        batch = imgs.size(0)
        # BCE targets: 1 = real, 0 = fake
        real_labels = torch.ones((batch, 1), device=device)
        fake_labels = torch.zeros((batch, 1), device=device)

        # ---- Train Generator ----
        optimizer_G.zero_grad()
        z = torch.randn(batch, latent_dim, device=device)
        gen_imgs = generator(z)
        # G wants D to label its samples as real. This backward also fills the
        # discriminator's .grad; optimizer_D.zero_grad() below clears it.
        g_loss = adversarial_loss(discriminator(gen_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), real_labels)
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}] \t Real Loss: {real_loss.item():.4f} \t G Loss: {g_loss.item():.4f} \t D Loss: {d_loss.item():.4f}")

    # Save generated images every few epochs
    if (epoch + 1) % 2 == 0:  # Save every 2 epochs
        # Sample in eval mode so BatchNorm uses its running statistics and does
        # not update them from this snapshot batch; then resume training mode.
        generator.eval()
        with torch.no_grad():
            generated_images = generator(fixed_noise)
        generator.train()
        save_path = f"gan_training_images/epoch_{epoch + 1}.png"
        vutils.save_image(generated_images, save_path, nrow=8, normalize=True)
        print(f"Saved generated images to {save_path}")

        gc.collect()
        if device.type == "mps":
            # Return cached, unused MPS allocations to the system.
            torch.mps.empty_cache()

In [44]:
import imageio
import os

images_dir = "gan_training_images"


def _epoch_number(filename):
    """Sort key for ``epoch_<N>.png`` snapshots: numeric epoch first, others last."""
    stem = os.path.splitext(filename)[0]
    number = stem.rpartition("_")[2]
    return (0, int(number), filename) if number.isdigit() else (1, 0, filename)


# Sort numerically: plain sorted() puts epoch_10 before epoch_2, which scrambles
# the frame order of the animation.
image_files = sorted(
    (name for name in os.listdir(images_dir) if name.endswith(".png")),
    key=_epoch_number,
)
images = [imageio.v3.imread(os.path.join(images_dir, name)) for name in image_files]

imageio.mimsave(os.path.join(images_dir, "training_animation.gif"), images, fps=2)

In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import numpy as np

# Function to display real and generated images
def display_real_and_fake_images(real_images, fake_images, num_images=16):
    """Plot real and generated images side by side, each as a square-ish grid.

    Args:
        real_images: batch of real images (a 4-D batch as accepted by
            ``torchvision.utils.make_grid``); only the first ``num_images``
            are drawn.
        fake_images: batch of generated images, same convention.
        num_images: how many images of each batch to show.

    Raises:
        ValueError: if ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # make_grid's ``nrow`` is the number of images per row; ~sqrt(n) per row
    # gives a roughly square grid.
    per_row = max(1, int(np.sqrt(num_images)))

    plt.figure(figsize=(5, 5))
    panels = (("Real Images", real_images), ("Generated Images", fake_images))
    for position, (title, batch) in enumerate(panels, start=1):
        grid = vutils.make_grid(
            batch[:num_images], nrow=per_row, padding=2, normalize=True
        )
        plt.subplot(1, 2, position)
        plt.axis("off")
        plt.title(title)
        # (C, H, W) grid -> (H, W, C) array on the host for matplotlib.
        plt.imshow(grid.detach().cpu().permute(1, 2, 0).numpy())
    plt.show()

# Generate images with the trained generator. Use eval mode so BatchNorm relies
# on its running statistics (and does not update them), then restore whatever
# mode the generator was in before.
was_training = generator.training
generator.eval()
fixed_noise = torch.randn(64, latent_dim, device=device)
with torch.no_grad():
    generated_images = generator(fixed_noise)
generator.train(was_training)

# Obtain a batch of real images for comparison
real_batch = next(iter(data_loader))[0]

# Display real and generated images
display_real_and_fake_images(real_batch, generated_images, num_images=4)