In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets


# WGAN with gradient penalty (WGAN-GP)
* Paper Link: [https://arxiv.org/abs/1704.00028](https://arxiv.org/abs/1704.00028)
* Best explained in this [blog](https://machinelearningmastery.com/how-to-code-a-wasserstein-generative-adversarial-network-wgan-from-scratch/) and math is explained better in this [video](https://www.youtube.com/watch?v=QJOEmwvnmTM) and this [video](https://www.youtube.com/watch?v=pG0QZ7OddX4) 
* The history of GANs is best explained in this [blog](https://lilianweng.github.io/posts/2017-08-20-gan/)
* The Wasserstein Generative Adversarial Network, or WGAN for short, is an extension to the generative adversarial network that both improves the stability when training a GAN and provides a loss function that correlates with the quality of generated images.

## Algorithm for WGAN
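1. Update the critic n times (e.g. n = 5). Each time, sample a batch of real images and a batch of noise, generate fake images, compute the critic loss, and take an optimizer step on the critic's parameters only.
2. Sample a fresh batch of noise, compute the generator loss, and take one optimizer step on the generator's parameters.
3. Repeat steps 1 and 2 until training ends.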
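The schedule above can be sketched on a toy problem. This is a minimal illustration only, assuming 1-D Gaussian "real" data and tiny linear models (`toy_gen`, `toy_critic` and `toy_outer_step` are hypothetical names, not part of this notebook). The Lipschitz constraint is deliberately left out here; it is discussed further down.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
TOY_Z_DIM = 4
toy_gen = nn.Linear(TOY_Z_DIM, 1)  # toy generator: noise -> one 1-D sample
toy_critic = nn.Linear(1, 1)       # toy critic: sample -> unbounded score (no sigmoid)
toy_gen_opt = torch.optim.Adam(toy_gen.parameters(), lr=1e-4, betas=(0.0, 0.9))
toy_critic_opt = torch.optim.Adam(toy_critic.parameters(), lr=1e-4, betas=(0.0, 0.9))

def toy_outer_step(real, n_critic=5):
    """Run n_critic critic updates, then one generator update.
    Returns (critic_updates, generator_updates)."""
    critic_updates = 0
    for _ in range(n_critic):
        # Detach so this step only changes the critic
        fake = toy_gen(torch.randn(real.shape[0], TOY_Z_DIM)).detach()
        critic_loss = -(toy_critic(real).mean() - toy_critic(fake).mean())
        toy_critic_opt.zero_grad()
        critic_loss.backward()
        toy_critic_opt.step()
        critic_updates += 1

    # Generator step on a fresh batch; it aims to raise the critic's scores for fakes
    gen_loss = -toy_critic(toy_gen(torch.randn(real.shape[0], TOY_Z_DIM))).mean()
    toy_gen_opt.zero_grad()
    gen_loss.backward()
    toy_gen_opt.step()
    return critic_updates, 1

real_batch = 3.0 + torch.randn(32, 1)  # toy "real" data drawn from N(3, 1)
print(toy_outer_step(real_batch))      # (5, 1)
```

The critic gradients produced by `gen_loss.backward()` are harmless because `toy_critic_opt.zero_grad()` runs before the next critic step.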


* Loss function for critic model:
    * critic_loss = -mean(critic_output_real) + mean(critic_output_fake)
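    * The critic tries to maximize E[critic(real)] - E[critic(fake)], where E is the expectation (a batch mean in practice). Because optimizers minimize a loss, we minimize the negative of this quantity.
    * The generator has the opposite goal: it tries to make E[critic(real)] - E[critic(fake)] as small as possible.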
  
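* The critic loss also indicates how well the generator is doing. Its negative, E[critic(real)] - E[critic(fake)] (ignoring any penalty term), is an estimate of the Wasserstein distance between the real and generated distributions, and it is only as good as the critic is well trained. As the generated images get closer to the real ones, this estimate shrinks toward 0, so the critic loss, which is usually negative, rises toward 0. A strongly negative critic loss means the critic can still separate real from fake images easily.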

* Loss function for generator model:
    * generator_loss = -mean(critic_output_fake)
    * The generator tries to maximize E[critic(fake)]; since losses are minimized, we take the negative of it.
    * In other words, the generator tries to produce images that the critic scores highly, as it scores real images. So, the generator tries to maximize the critic's output for fake images.
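Both losses can be computed directly from critic scores. The scores below are made-up numbers used only to check the arithmetic, and `critic_loss_fn` / `generator_loss_fn` are illustrative names; the last lines use autograd to confirm that minimizing the generator loss pushes the fake scores upward.

```python
import torch

def critic_loss_fn(scores_real, scores_fake):
    # Negative of the Wasserstein estimate E[D(real)] - E[D(fake)]
    return -(scores_real.mean() - scores_fake.mean())

def generator_loss_fn(scores_fake):
    # Minimizing -E[D(fake)] is the same as maximizing E[D(fake)]
    return -scores_fake.mean()

scores_real = torch.tensor([1.0, 3.0])                       # mean 2.0
scores_fake = torch.tensor([-1.0, 1.0], requires_grad=True)  # mean 0.0

print(critic_loss_fn(scores_real, scores_fake).item())  # -2.0

gen_loss = generator_loss_fn(scores_fake)
gen_loss.backward()
print(scores_fake.grad)  # tensor([-0.5000, -0.5000])
```

A gradient-descent step moves each fake score by `-lr * grad`, i.e. upward, which is exactly what the generator wants.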


For every n updates to the critic, we update the generator once. n is a tunable hyperparameter (`critic_iter = 5` in the code below).

* WGAN-GP is an improved version of WGAN.
* WGAN-GP uses gradient penalty to enforce the Lipschitz constraint on the critic.

* The critic outputs an unbounded scalar score for each input image (there is no sigmoid), and it is trained to give higher scores to real images and lower scores to fake images.
* To keep the critic within the Lipschitz constraint, WGAN-GP adds a gradient penalty term to the critic loss.

* What is the Lipschitz constraint?<br>
    A function f is K-Lipschitz if |f(x1) - f(x2)| <= K * |x1 - x2| for all inputs, i.e. its rate of change is bounded by K. For example, y = 2x + 3 has Lipschitz constant 2: increasing x by 1 increases y by exactly 2, and y never changes faster than that.
    WGAN requires the critic to be 1-Lipschitz; this bound is what makes E[critic(real)] - E[critic(fake)] a valid estimate of the Wasserstein distance. Without it, the critic could push its scores for real images arbitrarily high and its scores for fake images arbitrarily low, so the critic loss would diverge and the gradients passed to the generator would explode.

* In the original WGAN, the Lipschitz constraint is enforced by clipping the critic weights to a small range (e.g. [-0.01, 0.01]) after each mini-batch update. The WGAN-GP paper shows that weight clipping can cause vanishing or exploding gradients and pushes the critic toward overly simple functions, so WGAN-GP replaces it with a gradient penalty and does not clip the weights.

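* The gradient penalty is computed at random points interpolated between real and fake images: it is the mean of (||∇ critic(x̂)||₂ - 1)², where the gradient is taken with respect to the interpolated input x̂. Multiplied by a coefficient λ (10 in the paper), it is added to the critic loss so that the norm of the critic's gradient stays close to 1, which is a soft version of the 1-Lipschitz constraint.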
  
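* The implementation differences between a standard GAN and the original WGAN are as follows:
    1. Use a linear activation function in the output layer of the critic model (instead of sigmoid).
    2. Use -1 labels for real images and 1 labels for fake images (instead of 1 and 0).
    3. Use Wasserstein loss to train the critic and generator models.
    4. Constrain critic model weights to a limited range after each mini-batch update (e.g. [-0.01, 0.01]).
    5. Update the critic model more times than the generator each iteration (e.g. 5).
    6. Use the RMSProp version of gradient descent with a small learning rate and no momentum (e.g. a learning rate of 0.00005).
* WGAN-GP keeps the linear critic output, the Wasserstein loss, and the extra critic updates, but replaces the weight clipping in item 4 with the gradient penalty and uses Adam (lr = 1e-4, betas = (0.0, 0.9)) instead of RMSProp. The code below uses these settings and computes the loss from mean critic scores directly rather than through ±1 labels.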
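Written out in full, the WGAN-GP objectives described above are as follows, in the paper's notation, with $D$ the critic, $G$ the generator, $\mathbb{P}_r$ the real data distribution and $\mathbb{P}_g$ the generator's distribution:

$$
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big] + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]
$$

$$
L_G = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
$$

In the code, the expectations become batch means, $\lambda$ is `lambda_gp`, and $\hat{x}$ is `interpolated_images` inside `calc_gradient_penalty`.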

What problems does WGAN solve compared with earlier GANs such as DCGAN?

Wasserstein Generative Adversarial Networks (WGANs) improve upon previous versions like Deep Convolutional GANs (DCGANs) in several ways:

1. Stability of Training: Traditional GANs, including DCGANs, can suffer from mode collapse and training instability. WGANs address this by training on an objective based on the Wasserstein distance (also known as the Earth Mover's Distance), which gives a smoother and more meaningful measure of the difference between two probability distributions, even when they barely overlap. In practice this makes training more reliable and mode collapse less frequent.
2. Meaningful Loss Metric: In traditional GAN training, the discriminator's loss says little about sample quality. In a WGAN, the difference between the critic's mean score on real data and its mean score on generated data approximates the Wasserstein distance between the two distributions, so the critic loss can be tracked as a rough measure of how close the generated distribution is to the real one.
3. Improved Gradient Flow: Because the critic has no sigmoid and is kept (approximately) 1-Lipschitz, it keeps providing useful gradients even when it is well trained, whereas a confident standard discriminator saturates and its gradients vanish. This leads to more stable training dynamics.
4. More Meaningful Generator Updates: In every GAN, the generator is updated using the gradients of the discriminator's output with respect to the generated samples. In a WGAN these gradients do not vanish as the critic improves, so the critic can be trained close to optimality without starving the generator of signal, which can lead to better sample quality.
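The gradient-flow point above can be illustrated with a one-line calculation. This sketch assumes a standard GAN discriminator that outputs a logit `s` and a generator trained with the original saturating loss log(1 - sigmoid(s)); many DCGAN implementations use the non-saturating variant instead. When the discriminator confidently rejects a fake (`s = -10`, an arbitrary illustrative value), the gradient almost vanishes, while the WGAN generator loss `-s` always has gradient -1.

```python
import torch
import torch.nn.functional as F

s = torch.tensor(-10.0, requires_grad=True)  # discriminator/critic output for a fake image

# Saturating GAN generator loss: log(1 - sigmoid(s)), written as -softplus(s) for stability
gan_loss = -F.softplus(s)
(gan_grad,) = torch.autograd.grad(gan_loss, s)

# WGAN generator loss: -s (no sigmoid, so no saturation)
wgan_loss = -s
(wgan_grad,) = torch.autograd.grad(wgan_loss, s)

print(f"GAN grad:  {gan_grad.item():.2e}")   # -4.54e-05
print(f"WGAN grad: {wgan_grad.item():.2e}")  # -1.00e+00
```

The saturating gradient equals -sigmoid(s), which shrinks exponentially as the discriminator grows more confident.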







In [None]:
# Implementing the Generator for WGAN

class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _blocks the feature map is 4x4; the Conv2d below reduces it to 1x1.
            # No sigmoid: the critic outputs an unbounded real-valued score.
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
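The spatial sizes noted in the comments above follow from the output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (with dilation 1 and no output padding). The two helpers below are a small sketch for checking them by hand; they are not used elsewhere in the notebook.

```python
def conv_out(size, kernel, stride, padding):
    # nn.Conv2d output size (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    # nn.ConvTranspose2d output size (dilation=1, output_padding=0)
    return (size - 1) * stride - 2 * padding + kernel

# Generator: 1x1 noise -> 4 -> 8 -> 16 -> 32 -> 64
g_sizes = [1]
for k, s, p in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    g_sizes.append(conv_transpose_out(g_sizes[-1], k, s, p))

# Critic: 64 -> 32 -> 16 -> 8 -> 4 -> 1
d_sizes = [64]
for k, s, p in [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 0)]:
    d_sizes.append(conv_out(d_sizes[-1], k, s, p))

print(g_sizes)  # [1, 4, 8, 16, 32, 64]
print(d_sizes)  # [64, 32, 16, 8, 4, 1]
```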

In [None]:
# Implementing the WGAN
# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-4
z_dim = 100
image_dim = 64
batch_size = 64
num_epochs = 5
features_disc = 16
features_gen = 16
channels_img = 1
critic_iter = 5
lambda_gp = 10

In [None]:
# Load the data
transform = transforms.Compose([
    transforms.Resize(image_dim),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(channels_img)], [0.5 for _ in range(channels_img)]
    )
])

dataloader = DataLoader(
    datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    ),
    batch_size=batch_size,
    shuffle=True,
)


In [None]:
# Initialize the generator and discriminator
gen = Generator(z_dim, channels_img, features_gen).to(device)
critic = Discriminator(channels_img, features_disc).to(device)

# Initialize weights as in DCGAN: conv weights ~ N(0, 0.02), norm-layer scales ~ N(1, 0.02)
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    # The critic uses InstanceNorm2d(affine=True), so it is initialized here as well
    if isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)) and m.weight is not None:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

gen = gen.apply(weights_init)
critic = critic.apply(weights_init)

In [None]:

# Optimizers
gen_opt = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.9))
critic_opt = optim.Adam(critic.parameters(), lr=lr, betas=(0.0, 0.9))

In [None]:

# Gradient Penalty
def calc_gradient_penalty(critic, real, fake, device='cpu'):
    BATCH_SIZE = real.shape[0]
    # One interpolation coefficient in [0, 1) per image, broadcast over C, H and W
    epsilon = torch.rand((BATCH_SIZE, 1, 1, 1), device=device)
    # Detach fake so the penalty only trains the critic; the interpolated batch needs
    # requires_grad so we can differentiate the critic's scores with respect to it
    interpolated_images = (real * epsilon + fake.detach() * (1 - epsilon)).requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
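A quick way to sanity-check a gradient penalty is a linear critic D(x) = w · x, whose input gradient is w everywhere, so the penalty should equal (||w||₂ - 1)² regardless of the interpolation. The sketch below re-implements the same recipe in self-contained form (`toy_gradient_penalty` and `linear_critic` are illustration-only helpers, not part of the notebook) and checks it for ||w|| = 1 and ||w|| = 2.

```python
import torch
import torch.nn as nn

def toy_gradient_penalty(critic, real, fake):
    # Same recipe as calc_gradient_penalty: interpolate, differentiate, penalize (||grad|| - 1)^2
    eps = torch.rand(real.shape[0], *([1] * (real.dim() - 1)))  # one eps per sample
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    (grad,) = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores), create_graph=True,
    )
    return ((grad.flatten(1).norm(2, dim=1) - 1) ** 2).mean()

def linear_critic(weight_norm, dim=4):
    # D(x) = w . x with every entry of w equal, scaled so that ||w||_2 == weight_norm
    critic = nn.Linear(dim, 1, bias=False)
    with torch.no_grad():
        critic.weight.fill_(weight_norm / dim ** 0.5)
    return critic

torch.manual_seed(0)
toy_real, toy_fake = torch.randn(8, 4), torch.randn(8, 4)
gp_unit = toy_gradient_penalty(linear_critic(1.0), toy_real, toy_fake).item()
gp_double = toy_gradient_penalty(linear_critic(2.0), toy_real, toy_fake).item()
print(gp_unit, gp_double)  # approximately 0.0 and 1.0
```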

In [None]:

# Training the WGAN
gen.train()
critic.train()

cur_step = 0
generator_losses = []
discriminator_losses = []

for epoch in range(num_epochs):
    for real, _ in dataloader:
        real = real.to(device)
        cur_batch_size = len(real)

        for _ in range(critic_iter):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1) # Flatten the critic scores for the real images
            critic_fake = critic(fake).reshape(-1) # Flatten the critic scores for the fake images
            gp = calc_gradient_penalty(critic, real, fake, device=device)
            # The critic maximizes E[critic(real)] - E[critic(fake)], so we minimize the negative.
            # lambda_gp * gp is a soft 1-Lipschitz constraint: it penalizes points between real and
            # fake images where the norm of the critic's gradient w.r.t. its input deviates from 1.
            # (The original WGAN clipped the critic's weights to e.g. [-0.01, 0.01] instead, which
            # can cause vanishing or exploding gradients.)
            critic_loss = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp
            )

            critic.zero_grad()
            # A fresh `fake` is generated on every iteration, so the graph need not be retained.
            # Gradients that reach the generator here are cleared by gen.zero_grad() below.
            critic_loss.backward()
            critic_opt.step()

        # Generator update
        gen.zero_grad()
        noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
        fake = gen(noise)
        critic_fake = critic(fake).reshape(-1)
        gen_loss = -torch.mean(critic_fake)
        gen_loss.backward()
        gen_opt.step()

        # Record the critic loss from the last critic iteration and this step's generator loss
        discriminator_losses += [critic_loss.item()]
        generator_losses += [gen_loss.item()]

        # Visualization code
        if cur_step % 500 == 0 and cur_step > 0:
            print(
                f"Step {cur_step}: Generator loss: {gen_loss.item():.4f}, critic loss: {critic_loss.item():.4f}"
            )
            with torch.no_grad():  # no autograd graph is needed just to look at samples
                noise = torch.randn(9, z_dim, 1, 1).to(device)
                fake = gen(noise).cpu()
            # normalize=True rescales the Tanh output from [-1, 1] to [0, 1] for imshow
            img_grid = torchvision.utils.make_grid(fake, nrow=3, normalize=True)
            plt.imshow(img_grid.permute(1, 2, 0).numpy())
            plt.show()

        cur_step += 1

# Save the model
torch.save(gen.state_dict(), "WGAN_Generator.pth")
torch.save(critic.state_dict(), "WGAN_Discriminator.pth")
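For comparison, the original WGAN (weight clipping) would drop the `lambda_gp * gp` term and instead clamp every critic parameter after each critic step. A minimal sketch, assuming the clip value 0.01 and the RMSProp learning rate 0.00005 mentioned earlier, with a small stand-in critic (`clip_critic`) rather than the `Discriminator` above:

```python
import torch
import torch.nn as nn

clip_value = 0.01
clip_critic = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
clip_opt = torch.optim.RMSprop(clip_critic.parameters(), lr=5e-5)

toy_real, toy_fake = torch.randn(16, 4) + 2.0, torch.randn(16, 4)
loss = -(clip_critic(toy_real).mean() - clip_critic(toy_fake).mean())  # no gradient penalty
clip_opt.zero_grad()
loss.backward()
clip_opt.step()

# Weight clipping: a crude way to bound the Lipschitz constant by clamping every parameter
with torch.no_grad():
    for p in clip_critic.parameters():
        p.clamp_(-clip_value, clip_value)

print(max(p.abs().max().item() for p in clip_critic.parameters()))  # at most clip_value
```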

In [None]:

# Plotting the losses
plt.figure(figsize=(10, 5))
plt.plot(generator_losses, label='Generator Loss')
plt.plot(discriminator_losses, label='Discriminator Loss')
plt.legend()
plt.title('Losses')
plt.show()

Observation:
Even 5 epochs of training can generate good images.

Results after 5 epochs of training:

![image](../images/wgangp-result.png)