# Lab 2: Generative Models (Generative adversarial networks)
```
- [S25] Advanced Machine Learning, Innopolis University
- Teaching Assistant: Gcinizwe Dlamini
```
<hr>


```
Lab Plan
1. Conditional Generative Adversarial Networks
2. Bidirectional Generative Adversarial Network
3. Task 2
```

<hr>

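## 1. Conditional Generative Adversarial Network

A Conditional Generative Adversarial Network (cGAN) extends the vanilla GAN by feeding extra information $y$ (here, the class label) to both the generator and the discriminator. With this conditioning, we can choose which class the generator produces. The discriminator then judges whether an image is realistic *for the given label*. Apart from this extra input, the architecture is the same as that of a vanilla GAN.

In this lab the condition is the digit label (0 to 9). Each label is mapped to a learned 10-dimensional embedding. That embedding is concatenated to the noise vector in the generator and to the flattened image in the discriminator.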

![Conditional GAN](https://www.researchgate.net/profile/Gerasimos-Spanakis/publication/330474693/figure/fig1/AS:956606955139072@1605084279074/GAN-conditional-GAN-CGAN-and-auxiliary-classifier-GAN-ACGAN-architectures-where-x_Q320.jpg)
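For reference, here is the standard cGAN objective from the original cGAN paper (Mirza and Osindero). It is the vanilla GAN minimax game with both players conditioned on the label $y$:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]
$$

The training code below does not minimize the second term directly. For the generator, it uses the common non-saturating variant instead: binary cross-entropy against a target of 1.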

### Imports


```python
!pip install tensorboardX
```



```python
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder, MNIST
from torchvision import transforms
from torch import autograd
from torchvision.utils import make_grid
import torchvision
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
```

### 1.1 Dataset preparation

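In this part we use the MNIST dataset: 28×28 grayscale images of handwritten digits in 10 classes. The images are normalized to $[-1, 1]$ to match the range of the generator's `Tanh` output.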

```python
batch_size = 32
# ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


def load_dataset(batch_size=128, root='./data', transform=transforms.ToTensor()):
    train_dataset = torchvision.datasets.MNIST(root=root + '/MNIST', train=True,
                                               transform=transform, download=True)

    test_dataset = torchvision.datasets.MNIST(root=root + '/MNIST', train=False,
                                              transform=transform)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

train_loader, _ = load_dataset(batch_size=batch_size, transform=transform)
```
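Below is a quick NumPy check of what the normalization does to pixel values. It reproduces the arithmetic of `transforms.Normalize([0.5], [0.5])` (subtract the mean, divide by the std) rather than calling torchvision. The pixel values are made up for illustration. It also shows that `0.5 * x + 0.5`, used later to display generated digits, undoes the normalization.

```python
import numpy as np

# Hypothetical pixel values in [0, 1], as produced by transforms.ToTensor()
pixels = np.array([0.0, 0.25, 0.5, 1.0])

# Normalize([0.5], [0.5]) computes (x - mean) / std for the single channel
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std   # now in [-1, 1], the range of Tanh

# Inverse mapping, as used when turning generator output back into an image
restored = 0.5 * normalized + 0.5

print(normalized)
print(restored)
```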




### 1.2 Define Discriminator model

```python
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        # Learnable 10-d embedding for each of the 10 digit classes
        self.label_emb = nn.Embedding(10, 10)

        # Input: 784 flattened pixels + 10 label-embedding values = 794
        self.model = nn.Sequential(
            nn.Linear(794, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the (image, label) pair is real
        )

    def forward(self, x, labels):
        x = x.view(x.size(0), 784)
        c = self.label_emb(labels)
        x = torch.cat([x, c], 1)
        out = self.model(x)
        # squeeze(1) keeps shape (batch,) even when the batch has a single sample
        return out.squeeze(1)
```
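To make the input width 794 concrete, here is a NumPy sketch of how the discriminator builds its input. A random matrix stands in for the learned `nn.Embedding(10, 10)` table, and the batch and labels are made up. The point is only the shapes and the lookup-then-concatenate step. Samples with the same label receive the same embedding row.

```python
import numpy as np

rng = np.random.default_rng(0)

batch = 4
num_classes, emb_dim = 10, 10  # same sizes as nn.Embedding(10, 10)

# Stand-in for the learnable embedding table: one 10-d row per digit class
embedding_table = rng.normal(size=(num_classes, emb_dim))

images = rng.uniform(-1, 1, size=(batch, 1, 28, 28))  # fake normalized MNIST batch
labels = np.array([3, 0, 7, 3])

flat = images.reshape(batch, 784)             # x.view(x.size(0), 784)
c = embedding_table[labels]                   # self.label_emb(labels): row lookup
d_input = np.concatenate([flat, c], axis=1)   # torch.cat([x, c], 1)

print(d_input.shape)  # (4, 794)
```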

### 1.3 Define Generator model

```python
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        # Learnable 10-d embedding for each of the 10 digit classes
        self.label_emb = nn.Embedding(10, 10)

        # Input: 100-d noise + 10-d label embedding = 110
        self.model = nn.Sequential(
            nn.Linear(110, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()  # pixels in [-1, 1], matching the normalized data
        )

    def forward(self, z, labels):
        z = z.view(z.size(0), 100)
        c = self.label_emb(labels)
        x = torch.cat([z, c], 1)
        out = self.model(x)
        return out.view(x.size(0), 28, 28)
```
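Here is a NumPy sketch of the generator's forward pass. The weights are random, untrained stand-ins, and biases are omitted for brevity. The sketch traces how the 110-dimensional input (100 noise values plus a 10-dimensional label embedding, also a random stand-in) becomes a 28×28 image. Because of the final `Tanh`, the output values lie in $(-1, 1)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def leaky_relu(x, slope=0.2):
    # Same nonlinearity as nn.LeakyReLU(0.2)
    return np.where(x > 0, x, slope * x)

# Layer widths of the Generator above: 110 -> 256 -> 512 -> 1024 -> 784
widths = [110, 256, 512, 1024, 784]
# Random stand-in weights (no biases); the real values are learned in training
weights = [rng.normal(scale=0.05, size=(i, o)) for i, o in zip(widths[:-1], widths[1:])]

batch = 5
z = rng.normal(size=(batch, 100))
c = rng.normal(size=(batch, 10))     # stand-in for self.label_emb(labels)
h = np.concatenate([z, c], axis=1)   # (batch, 110)

for w in weights[:-1]:
    h = leaky_relu(h @ w)
out = np.tanh(h @ weights[-1])       # final Tanh squashes values into (-1, 1)
img = out.reshape(batch, 28, 28)     # out.view(x.size(0), 28, 28)

print(img.shape)  # (5, 28, 28)
```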

### 1.4 Define Conditional GAN

```python
generator = Generator().to(device)
discriminator = Discriminator().to(device)
```

### 1.5 Define Training params

```python
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

writer = SummaryWriter()
```
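`nn.BCELoss` is what turns discriminator probabilities into the GAN losses used below. The NumPy function here is a sketch of the same mean binary cross-entropy. PyTorch itself clamps the log terms at -100 instead of clipping probabilities, which only matters for outputs of exactly 0 or 1. The probabilities are made up. With target 1 the loss reduces to $-\log D$, and with target 0 it reduces to $-\log(1 - D)$.

```python
import numpy as np

def bce(p, y, eps=1e-12):
    # Mean binary cross-entropy (the default 'mean' reduction of nn.BCELoss)
    p = np.clip(p, eps, 1 - eps)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

d_out = np.array([0.9, 0.8, 0.6])  # hypothetical discriminator probabilities

loss_real = bce(d_out, np.ones(3))   # target 1: "real"
loss_fake = bce(d_out, np.zeros(3))  # target 0: "fake"

print(round(loss_real, 3), round(loss_fake, 3))
```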

### 1.6 Generator Training procedure

```python
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, device='cpu'):
    g_optimizer.zero_grad()
    # Sample noise and random target digits for the fake images
    z = torch.randn(batch_size, 100, device=device)
    fake_labels = torch.randint(0, 10, (batch_size,), device=device)
    fake_images = generator(z, fake_labels)
    validity = discriminator(fake_images, fake_labels)
    # The generator wants its images to be classified as real (target 1)
    g_loss = criterion(validity, torch.ones(batch_size, device=device))
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
```
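`generator_train_step` scores fake images against a target of 1. That means it minimizes $-\log D(G(z \mid y) \mid y)$ rather than the minimax term $\log(1 - D(G(z \mid y) \mid y))$. The NumPy snippet below compares the derivatives of the two losses with respect to the discriminator output $D$ (not the logits), for a few hypothetical values. When the discriminator confidently rejects fakes ($D$ near 0), the minimax loss gives a gradient of about $-1$, while the non-saturating loss gives a much larger one. That difference is the usual motivation for this choice.

```python
import numpy as np

d = np.array([0.01, 0.1, 0.5, 0.9])  # D(G(z)) values; small early in training

# Minimax generator loss log(1 - D): derivative w.r.t. D
grad_minimax = -1.0 / (1.0 - d)
# Non-saturating loss -log D (BCE with target 1): derivative w.r.t. D
grad_non_saturating = -1.0 / d

for di, g1, g2 in zip(d, grad_minimax, grad_non_saturating):
    print(f"D={di:.2f}  minimax grad={g1:8.3f}  non-saturating grad={g2:8.3f}")
```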

### 1.7 Discriminator Training procedure

```python
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels, device='cpu'):
    d_optimizer.zero_grad()

    # train with real images: target 1
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, torch.ones(batch_size, device=device))

    # train with fake images: target 0
    z = torch.randn(batch_size, 100, device=device)
    fake_labels = torch.randint(0, 10, (batch_size,), device=device)
    # detach() stops this loss from back-propagating into the generator
    fake_images = generator(z, fake_labels).detach()
    fake_validity = discriminator(fake_images, fake_labels)
    fake_loss = criterion(fake_validity, torch.zeros(batch_size, device=device))

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
```

### 1.8 Conditional GAN training loop

```python
num_epochs = 50
n_critic = 5       # discriminator updates per generator update
display_step = 10  # log a sample grid every `display_step` iterations
for epoch in range(num_epochs):
    print('Epoch {} ...'.format(epoch), end=' ')
    for i, (images, labels) in enumerate(train_loader):

        step = epoch * len(train_loader) + i + 1
        real_images = images.to(device)
        labels = labels.to(device)
        generator.train()

        # Accumulate the discriminator loss so it can be averaged over n_critic steps
        d_loss = 0
        for _ in range(n_critic):
            d_loss += discriminator_train_step(len(real_images), discriminator,
                                               generator, d_optimizer, criterion,
                                               real_images, labels, device=device)

        g_loss = generator_train_step(len(real_images), discriminator, generator,
                                      g_optimizer, criterion, device=device)

        writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': d_loss / n_critic}, step)

        if step % display_step == 0:
            generator.eval()
            with torch.no_grad():
                z = torch.randn(9, 100, device=device)
                sample_labels = torch.arange(9, device=device)  # digits 0..8
                sample_images = generator(z, sample_labels).unsqueeze(1)
            grid = make_grid(sample_images, nrow=3, normalize=True)
            writer.add_image('sample_image', grid, step)
    print('Done!')
```


### 1.9 Generate Data

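```python
def generate_digit(generator, digit):
    z = torch.randn(1, 100, device=device)
    label = torch.tensor([digit], device=device)
    with torch.no_grad():
        img = generator(z, label).cpu()
    # Undo Normalize([0.5], [0.5]): map [-1, 1] back to [0, 1]
    img = 0.5 * img + 0.5
    return transforms.ToPILImage()(img)  # (1, 28, 28) tensor -> grayscale image
```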

```python
generator.eval()
generate_digit(generator, 2)
```

## 2. Structure of a Bidirectional Generative Adversarial Network (BiGAN)
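A BiGAN (Bidirectional GAN) adds an encoder $E$ to the usual GAN. The generator $G$ maps latent samples $z$ to data, and the encoder learns the inverse mapping from data $x$ back to the latent space. The discriminator no longer sees data alone. Instead, it receives joint (data, latent) pairs and must tell $(x, E(x))$ apart from $(G(z), z)$.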
![](https://ar5iv.labs.arxiv.org/html/1801.04271/assets/bigan.png)
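For reference, here is the BiGAN minimax objective from the original BiGAN paper (listed under Resources). It is written here for the deterministic encoder and generator used in this lab:

$$
\min_{G, E} \max_D V(D, E, G) = \mathbb{E}_{x \sim p_X}\big[\log D(x, E(x))\big] + \mathbb{E}_{z \sim p_Z}\big[\log\big(1 - D(G(z), z)\big)\big]
$$

The discriminator maximizes $V$, while $G$ and $E$ jointly minimize it. The paper shows that at the optimum, $E$ and $G$ invert each other.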

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```

### 2.1 Load Data

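Here we keep the default `ToTensor()` transform, so pixel values stay in $[0, 1]$. This matches the `Sigmoid` output of the BiGAN generator.

```python
batch_size = 128
mnist_train, mnist_test = load_dataset(batch_size=batch_size)
```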

### 2.2 Define Generator

- **Role:** Takes a latent vector $z$ (a 50-dimensional sample from a standard Gaussian prior) and generates a synthetic data sample $G(z)$ (here, a flattened 28×28 image).
- **Goal:** Produce realistic outputs so that the discriminator cannot distinguish them from real data.

```python
# Note: this redefines the cGAN Generator from Section 1
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 50-d latent z -> 784 pixels; Sigmoid keeps outputs in [0, 1] like ToTensor()
        self.layers = nn.Sequential(
            nn.Linear(50, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.layers(z)
```

### 2.3 Define Discriminator

- **Role:** Receives pairs of data and latent codes. It sees:
  - **Real pair:** $(x, E(x))$, where $x$ is a real data sample and $E(x)$ is its encoded latent representation.
  - **Fake pair:** $(G(z), z)$, where $z$ is a sampled latent vector and $G(z)$ is the generated data sample.
- **Goal:** Distinguish between real and fake pairs by outputting the probability that a given pair is real.
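Here is a NumPy sketch of the two kinds of pairs the discriminator receives. Random arrays stand in for real images, $E(x)$, and $G(z)$. The sketch shows where the `28 * 28 + 50` input width of the first `nn.Linear` comes from.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, latent_dim = 4, 50

x_real = rng.uniform(0, 1, size=(batch, 28 * 28))   # flattened real images
e_x = rng.normal(size=(batch, latent_dim))          # stand-in for E(x)
z = rng.normal(size=(batch, latent_dim))            # z ~ N(0, I)
g_z = rng.uniform(0, 1, size=(batch, 28 * 28))      # stand-in for G(z)

real_pair = np.concatenate([x_real, e_x], axis=1)   # (x, E(x)): target 1
fake_pair = np.concatenate([g_z, z], axis=1)        # (G(z), z): target 0

print(real_pair.shape, fake_pair.shape)  # (4, 834) (4, 834)
```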


```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input: flattened image (784) concatenated with its latent code (50)
        self.layers = nn.Sequential(
            nn.Linear(28 * 28 + 50, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, X, z):
        Xz = torch.cat([X, z], dim=1)
        return self.layers(Xz)
```

### 2.4 Define Encoder

- **Role:** Maps a real data sample $x$ (e.g., an image) to a latent representation $E(x)$.
- **Goal:** Learn the inverse mapping of the generator, so that the latent code captures meaningful information about the data.

```python
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # No output activation: codes should cover the same unbounded range as z ~ N(0, I)
        self.layers = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 50)
        )

    def forward(self, X):
        return self.layers(X)
```


### 2.5 Define Loss Functions for the Discriminator and Encoder-Generator

1. **Discriminator training:**
   - **Objective:** Update $D$ so that it assigns a high probability (close to 1) to real pairs $(x, E(x))$ and a low probability (close to 0) to fake pairs $(G(z), z)$.

2. **Generator and encoder training:**
   - **Objective:** Update $G$ and $E$ together so that the fake pairs $(G(z), z)$ become indistinguishable from the real pairs $(x, E(x))$, effectively "fooling" $D$.
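In this notebook's code, the two objectives correspond to the following batch losses. Here $B$ is the batch size, and the small constant `eps` used in the code is omitted. $\mathcal{L}_{EG}$ swaps the targets instead of simply negating $\mathcal{L}_D$. This is the same non-saturating trick used for the cGAN generator in Section 1:

$$
\mathcal{L}_D = -\frac{1}{B} \sum_{i=1}^{B} \Big[\log D\big(x_i, E(x_i)\big) + \log\Big(1 - D\big(G(z_i), z_i\big)\Big)\Big]
$$

$$
\mathcal{L}_{EG} = -\frac{1}{B} \sum_{i=1}^{B} \Big[\log D\big(G(z_i), z_i\big) + \log\Big(1 - D\big(x_i, E(x_i)\big)\Big)\Big]
$$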

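```python
# DG = D(G(z), z): discriminator output on fake pairs
# DE = D(x, E(x)): discriminator output on real pairs
# eps keeps the logarithms finite when an output is exactly 0 or 1
def D_loss(DG, DE, eps=1e-6):
    loss = torch.log(DE + eps) + torch.log(1 - DG + eps)
    return -torch.mean(loss)
```

```python
def EG_loss(DG, DE, eps=1e-6):
    # Same terms with the targets swapped: fake pairs -> 1, real pairs -> 0
    loss = torch.log(DG + eps) + torch.log(1 - DE + eps)
    return -torch.mean(loss)
```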
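A NumPy mirror of `D_loss` and `EG_loss` can help sanity-check the signs. It uses the same formulas and `eps`, with made-up discriminator outputs. When the discriminator cannot tell pairs apart ($D = 0.5$ everywhere), both losses equal $2\log 2 \approx 1.386$. A discriminator that separates the pairs well drives `D_loss` down and `EG_loss` up.

```python
import numpy as np

def d_loss(dg, de, eps=1e-6):
    # NumPy mirror of D_loss: real pairs -> 1, fake pairs -> 0
    return -np.mean(np.log(de + eps) + np.log(1 - dg + eps))

def eg_loss(dg, de, eps=1e-6):
    # NumPy mirror of EG_loss: same terms with the targets swapped
    return -np.mean(np.log(dg + eps) + np.log(1 - de + eps))

# Hypothetical outputs of a discriminator that separates the pairs well
dg_good, de_good = np.array([0.05, 0.1]), np.array([0.95, 0.9])
# Hypothetical outputs of a discriminator that cannot tell the pairs apart
dg_conf, de_conf = np.full(2, 0.5), np.full(2, 0.5)

print(d_loss(dg_good, de_good), eg_loss(dg_good, de_good))
print(d_loss(dg_conf, de_conf), eg_loss(dg_conf, de_conf))
```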

### 2.6 Define BiGAN and Training Params

```python
n_epochs = 400
l_rate = 2e-5

E = Encoder().to(device)
G = Generator().to(device)
D = Discriminator().to(device)

# Adam optimizers with weight decay; E and G share a single optimizer
optimizer_EG = torch.optim.Adam(list(E.parameters()) + list(G.parameters()),
                                lr=l_rate, betas=(0.5, 0.999), weight_decay=1e-5)
optimizer_D = torch.optim.Adam(D.parameters(),
                               lr=l_rate, betas=(0.5, 0.999), weight_decay=1e-5)
```

## 3. Task 2
```
Task 2.1
- Train the BiGAN defined above
- Add TensorBoard logging to the training procedure and, every 10 epochs, visualize 10
  - generated images
  - reconstructed images
```
<hr>

```
Task 2.2
- Implement and train a conditional BiGAN on the CIFAR10 dataset
```

```python
writer = SummaryWriter()

for epoch in range(n_epochs):
    for batch_idx, (real_images, _) in enumerate(mnist_train):
        real_images = real_images.view(-1, 28 * 28).to(device)
        batch_size_real = real_images.shape[0]
        z_random = torch.randn(batch_size_real, 50, device=device)

        # Discriminator step: detach so this backward pass does not reach E and G
        z_encoded = E(real_images).detach()
        X_generated = G(z_random).detach()
        D_real = D(real_images, z_encoded)
        D_fake = D(X_generated, z_random)

        loss_D = D_loss(D_fake, D_real)
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Encoder-generator step: recompute the outputs with the updated discriminator
        D_fake = D(G(z_random), z_random)
        D_real = D(real_images, E(real_images))

        loss_EG = EG_loss(D_fake, D_real)
        optimizer_EG.zero_grad()
        loss_EG.backward()
        optimizer_EG.step()

    # Log the losses of the last batch of each epoch to TensorBoard
    writer.add_scalar('Loss/Discriminator', loss_D.item(), epoch)
    writer.add_scalar('Loss/Generator', loss_EG.item(), epoch)

    if epoch % 10 == 0:
        # Eval mode so BatchNorm uses its running statistics while sampling
        E.eval()
        G.eval()
        with torch.no_grad():
            generated_images = G(z_random).view(-1, 1, 28, 28)
            reconstructed_images = G(E(real_images)).view(-1, 1, 28, 28)
            writer.add_images('Generated', generated_images[:10], epoch)
            writer.add_images('Reconstructed', reconstructed_images[:10], epoch)
        E.train()
        G.train()

writer.close()
```



## Resources

* [A simple explanation of BiGAN](https://youtu.be/rzpA0H-q_HY)
* [Adversarial feature learning](https://arxiv.org/pdf/1605.09782v7.pdf) -- original BiGAN paper