# Generative Adversarial Networks

Generative Adversarial Networks (GANs) are a type of deep learning model that is used for generating new, synthetic data that resembles existing data. GANs consist of two neural networks:

1. Generator (G): This network takes a random noise vector as input and generates a synthetic data sample that is similar to the existing data.
2. Discriminator (D): This network takes a data sample (either real or synthetic) as input and outputs a probability that the sample is real.

The two networks are trained simultaneously, with the goal of improving the generator's ability to produce realistic data samples and the discriminator's ability to correctly classify real and synthetic data samples.

<img src="../images/gan.png" alt="GANs" width="600"/>


Source: https://tikz.net/gan/

The training process works as follows:

1. The generator takes a random noise vector as input and generates a synthetic data sample.
2. The discriminator takes the synthetic data sample and outputs a probability that it is real.
3. The discriminator also takes a real data sample and outputs a probability that it is real.
4. The generator's goal is to produce synthetic data samples that are indistinguishable from real data samples, so it tries to maximize the probability that the discriminator will classify its output as real.
5. The discriminator's goal is to correctly classify real and synthetic data samples, so it tries to assign a high probability to real samples and a low probability to synthetic ones.
6. The generator and discriminator are updated in alternation, so that each network keeps improving against the latest version of the other.
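
The training steps above are usually summarized as a two-player minimax game. The notation here is an assumption, since the text does not define symbols: $p_{\text{data}}$ is the distribution of real samples, $p_z$ is the noise distribution, $D(x)$ is the discriminator's probability that $x$ is real, and $G(z)$ is the generator's output.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

Step 4 describes the commonly used non-saturating variant of the generator update, in which the generator maximizes $\mathbb{E}_{z \sim p_z}[\log D(G(z))]$ instead of minimizing the second term directly.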

GANs have several applications, including:

1. Data augmentation: GANs can generate additional samples to enlarge an existing dataset, which can improve the performance of machine learning models trained on it.
2. Data generation: GANs can produce samples similar to existing data for simulating real-world scenarios or for testing and evaluation.
3. Image and video synthesis: GANs can generate new images and videos that resemble existing ones, for applications such as video editing.
4. Style transfer: GANs can transfer the style of one image to another, for applications such as image editing or artistic rendering.

GANs have several advantages, including:

1. Ability to generate high-quality, realistic data samples.
2. Ability to learn complex, non-linear relationships between the input data and the output data.
3. Ability to generate new data samples that are similar to existing data, but not identical.

However, GANs also have some limitations, including:

1. Difficulty in training: GANs can be difficult to train, especially for large datasets or complex tasks.
2. Mode collapse: GANs can suffer from mode collapse, which is a phenomenon where the generator produces a limited number of output samples that are similar to each other.
3. Unstable training: GANs can be prone to unstable training, which can cause the generator and discriminator to diverge or the training process to fail.
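
One rough way to watch for mode collapse is to measure how different the generated samples in a batch are from each other. The sketch below is only a heuristic and not a standard metric: the helper name `mean_pairwise_distance` is hypothetical, and random tensors stand in for generator outputs.

```python
import torch


def mean_pairwise_distance(samples: torch.Tensor) -> float:
    """Average L2 distance between all distinct pairs in a batch of samples."""
    flat = samples.flatten(start_dim=1)      # (n, features)
    dists = torch.cdist(flat, flat, p=2)     # (n, n) matrix of pairwise distances
    n = flat.size(0)
    # The diagonal holds each sample's distance to itself (zero), so divide
    # by the number of off-diagonal entries only.
    return (dists.sum() / (n * (n - 1))).item()


torch.manual_seed(0)
diverse = torch.randn(16, 1, 28, 28)             # stand-in for varied generator outputs
collapsed = diverse[:1].repeat(16, 1, 1, 1)      # every sample identical

print(mean_pairwise_distance(diverse))           # large value
print(mean_pairwise_distance(collapsed))         # close to zero
```

A value that keeps falling toward zero over the course of training would suggest that the generator is producing near-identical samples.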

Overall, GANs are a powerful tool for generating new, synthetic data that resembles existing data, and have many applications in fields such as computer vision, natural language processing, and audio processing.

Training a Conditional Generative Adversarial Network (cGAN) on the FashionMNIST dataset enables the generation of fashion images conditioned on specific clothing categories. Here's a step-by-step guide to implementing a cGAN using PyTorch.

## 1. Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

## 2. Load and Preprocess the FashionMNIST Dataset

Define the image transformations and load the training set. Pixel values are rescaled to [-1, 1] to match the range of the generator's `Tanh` output.

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.FashionMNIST(
    root='./data', train=True, transform=transform, download=True
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
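
As a small illustration of the preprocessing, `transforms.Normalize((0.5,), (0.5,))` applies `(x - mean) / std` per channel. The sketch below reproduces that arithmetic in plain PyTorch on a made-up three-pixel image, and shows the inverse mapping that is useful when displaying generated images.

```python
import torch

mean, std = 0.5, 0.5

# A made-up 1x1x3 "image" with pixel values in [0, 1], as produced by ToTensor()
pixels = torch.tensor([[[0.0, 0.5, 1.0]]])

scaled = (pixels - mean) / std       # same arithmetic as Normalize((0.5,), (0.5,))
print(scaled)                        # values become -1, 0 and 1

restored = scaled * std + mean       # inverse mapping back to [0, 1]
print(restored)
```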


## 3. Define the Generator and Discriminator Architectures

In a cGAN, both the generator and discriminator are conditioned on additional information, such as class labels.

### Generator

In [3]:
class Generator(nn.Module):
    def __init__(self, latent_dim, num_classes, img_shape):
        super().__init__()
        # Embed each class label as a num_classes-dimensional vector
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Start from an (H/4 x W/4) feature map; two 2x upsamplings restore the full size
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim + num_classes, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # Note: the second positional argument of BatchNorm2d is eps, not momentum
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized training images
        )

    def forward(self, noise, labels):
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        out = self.l1(gen_input)
        out = out.view(out.size(0), 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
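
To make the generator's data flow easier to follow, here is a minimal shape walk-through. It assumes the same sizes the notebook sets later (`latent_dim = 100`, `num_classes = 10`, 28x28 images) and a batch of 4. The layers are freshly initialized stand-ins, not the trained model.

```python
import torch
import torch.nn as nn

latent_dim, num_classes, batch = 100, 10, 4
init_size = 28 // 4                          # 7, i.e. img_shape[1] // 4

label_emb = nn.Embedding(num_classes, num_classes)
noise = torch.randn(batch, latent_dim)
labels = torch.randint(0, num_classes, (batch,))

# The label embedding is concatenated with the noise vector
gen_input = torch.cat((label_emb(labels), noise), dim=-1)
print(gen_input.shape)                       # torch.Size([4, 110])

# The linear layer output is reshaped into a 128-channel 7x7 feature map
project = nn.Linear(latent_dim + num_classes, 128 * init_size ** 2)
feature_map = project(gen_input).view(batch, 128, init_size, init_size)
print(feature_map.shape)                     # torch.Size([4, 128, 7, 7])

# Two 2x upsampling steps take the spatial size from 7 to 14 to 28
upsample = nn.Upsample(scale_factor=2)
upsampled = upsample(upsample(feature_map))
print(upsampled.shape)                       # torch.Size([4, 128, 28, 28])
```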

### Discriminator

In [4]:
class Discriminator(nn.Module):
    def __init__(self, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # MLP over the flattened image concatenated with the label embedding
        self.model = nn.Sequential(
            nn.Linear(num_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2, inplace=True),
            # No sigmoid: the output is an unbounded score, not a probability
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        d_in = torch.cat((img.view(img.size(0), -1), self.label_emb(labels)), -1)
        validity = self.model(d_in)
        return validity
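
The discriminator's conditioning can be checked in the same way. This sketch assumes 28x28 single-channel images, 10 classes, and a batch of 4. A single untrained `nn.Linear` layer stands in for the full MLP, only to show the input and output shapes.

```python
import torch
import torch.nn as nn

num_classes, batch = 10, 4
img_shape = (1, 28, 28)

label_emb = nn.Embedding(num_classes, num_classes)
imgs = torch.randn(batch, *img_shape)
labels = torch.randint(0, num_classes, (batch,))

# Flatten each image to 784 values and append the 10-dimensional label embedding
d_in = torch.cat((imgs.view(batch, -1), label_emb(labels)), dim=-1)
print(d_in.shape)                    # torch.Size([4, 794])

# Stand-in for the MLP: one raw score per sample, with no sigmoid applied
score = nn.Linear(d_in.size(1), 1)(d_in)
print(score.shape)                   # torch.Size([4, 1])
```

Because the last layer has no sigmoid, the score is not restricted to [0, 1], which is why it pairs with a regression-style loss rather than plain binary cross-entropy on probabilities.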

## 4. Initialize Models and Optimizers

Set up the models, loss function, and optimizers.

In [5]:
latent_dim = 100
num_classes = 10
img_shape = (1, 28, 28)

# Initialize models
generator = Generator(latent_dim, num_classes, img_shape)
discriminator = Discriminator(num_classes, img_shape)

# Loss function
adversarial_loss = nn.MSELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)

MSELoss()
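
Using `nn.MSELoss` on raw discriminator scores resembles a least-squares GAN objective. The sketch below compares it with `nn.BCEWithLogitsLoss`, which is an alternative and not what this notebook uses, on three made-up scores for samples labeled as real.

```python
import torch
import torch.nn as nn

# Made-up raw discriminator scores for three samples that should count as real
scores = torch.tensor([[2.0], [0.0], [-1.0]])
valid = torch.ones_like(scores)

mse = nn.MSELoss()(scores, valid)            # mean of (score - 1)^2
bce = nn.BCEWithLogitsLoss()(scores, valid)  # applies sigmoid internally, then log loss

print(f"MSE loss: {mse.item():.4f}")         # (1 + 1 + 4) / 3 = 2.0
print(f"BCE-with-logits loss: {bce.item():.4f}")
```

The squared error also penalizes the score of 2.0 for overshooting the target of 1, whereas the cross-entropy loss keeps rewarding larger scores for real samples.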

## 5. Training the cGAN

Train the generator and discriminator.

In [6]:
# Training parameters
n_epochs = 200
sample_interval = 400

# Function to generate labels
def generate_labels(n, num_classes):
    return torch.randint(0, num_classes, (n,))

# Training loop
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(train_loader):

        batch_size = imgs.size(0)

        # Adversarial ground truths: target 1 for real samples, 0 for fake samples
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = torch.randn((batch_size, latent_dim)).to(device)
        gen_labels = generate_labels(batch_size, num_classes).to(device)

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Print training progress
        if i % sample_interval == 0:
            print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_loader)}] "
                  f"[D loss: {d_loss.item():.6f}]")

[Epoch 0/200] [Batch 0/938] [D loss: 0.411009]
[Epoch 0/200] [Batch 400/938] [D loss: 0.226945]
[Epoch 0/200] [Batch 800/938] [D loss: 0.235815]
[Epoch 1/200] [Batch 0/938] [D loss: 0.200266]
[Epoch 1/200] [Batch 400/938] [D loss: 0.205378]
[Epoch 1/200] [Batch 800/938] [D loss: 0.184599]
[Epoch 2/200] [Batch 0/938] [D loss: 0.198941]
[Epoch 2/200] [Batch 400/938] [D loss: 0.205372]
[Epoch 2/200] [Batch 800/938] [D loss: 0.222276]
[Epoch 3/200] [Batch 0/938] [D loss: 0.189940]
[Epoch 3/200] [Batch 400/938] [D loss: 0.205558]
[Epoch 3/200] [Batch 800/938] [D loss: 0.180928]
[Epoch 4/200] [Batch 0/938] [D loss: 0.217088]
[Epoch 4/200] [Batch 400/938] [D loss: 0.219984]
[Epoch 4/200] [Batch 800/938] [D loss: 0.207456]
[Epoch 5/200] [Batch 0/938] [D loss: 0.225905]
[Epoch 5/200] [Batch 400/938] [D loss: 0.213363]
[Epoch 5/200] [Batch 800/938] [D loss: 0.205734]
[Epoch 6/200] [Batch 0/938] [D loss: 0.200807]
[Epoch 6/200] [Batch 400/938] [D loss: 0.226347]
[Epoch 6/200] [Batch 800/938] [D l

KeyboardInterrupt: 
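
Once training has run for long enough, the generator can be asked for images of specific classes. The following is a minimal sketch, not part of the original notebook: `sample_per_class` is a hypothetical helper, and `DummyGenerator` is a stand-in with the same call signature as `Generator`, so the example runs without a trained model.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def sample_per_class(generator, latent_dim, num_classes, n_per_class, device="cpu"):
    """Generate n_per_class images for every label 0..num_classes-1."""
    generator.eval()  # switch BatchNorm and Dropout layers to inference behavior
    labels = torch.arange(num_classes, device=device).repeat_interleave(n_per_class)
    z = torch.randn(labels.size(0), latent_dim, device=device)
    imgs = generator(z, labels)
    # Undo Normalize((0.5,), (0.5,)) so pixel values fall back into [0, 1]
    return imgs * 0.5 + 0.5, labels


class DummyGenerator(nn.Module):
    """Hypothetical stand-in that returns random Tanh-range images."""

    def forward(self, noise, labels):
        return torch.tanh(torch.randn(noise.size(0), 1, 28, 28))


imgs, labels = sample_per_class(
    DummyGenerator(), latent_dim=100, num_classes=10, n_per_class=2
)
print(imgs.shape)        # torch.Size([20, 1, 28, 28])
print(labels[:4])        # labels 0, 0, 1, 1
```

With the trained model, passing `generator` in place of `DummyGenerator()` and `device=device` would produce two images per clothing category, which can then be shown with `plt.imshow`.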