<a href="https://colab.research.google.com/github/essegibi/isml/blob/main/WassersteinGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generative Adversarial Networks

GANs share the same basic construction and differ mostly in their loss functions [^1].

## Introduction

The compute_gradient_penalty function calculates the gradient penalty term for the Wasserstein GAN (Generative Adversarial Network) with gradient penalty. This term is added to the discriminator loss only, not to the generator loss, to improve the stability and convergence of the GAN training process.

## Key Concepts

Before diving into the code, let's understand some key concepts related to the Wasserstein GAN and gradient penalty:

### Wasserstein GAN

The Wasserstein GAN is a variant of the traditional GAN that uses the Wasserstein distance (also known as Earth Mover's distance) to measure how far apart the distributions of real and generated samples are. Training minimizes this distance, which tends to give more stable training and a loss value that is more meaningful to monitor.
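
To make the Earth Mover's distance concrete, here is a minimal sketch for the special case of two one-dimensional samples of equal size, where the Wasserstein-1 distance reduces to the mean absolute difference between the sorted values. The function name `wasserstein_1d` and the toy data are illustrative and not part of the notebook.

```python
def wasserstein_1d(xs, ys):
    """Wasserstein-1 distance between two equally sized 1-D samples."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes samples of equal size")
    # For equal-size empirical distributions in 1-D, the optimal transport
    # plan matches the i-th smallest x with the i-th smallest y.
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


real = [0.0, 1.0, 2.0]
generated = [1.0, 2.0, 3.0]  # the same sample shifted by 1
print(wasserstein_1d(real, generated))  # 1.0
```

Shifting every point by 1 costs exactly 1 unit of "earth moving" per point, which is why the distance is 1.0. For images the distance cannot be computed this directly, so the WGAN approximates it with the discriminator.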

### Gradient Penalty

The gradient penalty is a regularization technique used in the Wasserstein GAN to enforce the Lipschitz constraint on the discriminator. It penalizes the discriminator when the norm of its gradients with respect to the interpolated samples (random points on the lines between real and generated samples) deviates from 1. It replaces the weight clipping of the original Wasserstein GAN and improves the overall training stability.
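
Written out, the discriminator loss with the penalty term has the following form. The symbols are notation introduced here for illustration: $P_r$ and $P_g$ are the real and generated distributions, $D$ is the discriminator, and $\lambda$ is the penalty weight (the training loop in this notebook uses 10).

$$
L_D = \mathbb{E}_{\tilde{x} \sim P_g}\left[D(\tilde{x})\right] - \mathbb{E}_{x \sim P_r}\left[D(x)\right] + \lambda \, \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right],
\qquad \hat{x} = \alpha x + (1 - \alpha)\,\tilde{x}, \quad \alpha \sim U[0, 1]
$$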

### Lipschitz Constraint

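The Lipschitz constraint is a mathematical property that bounds how fast the discriminator's output can change when its input changes: a function f is K-Lipschitz if |f(x) - f(y)| <= K * |x - y| for all inputs x and y. In the Wasserstein GAN, a 1-Lipschitz discriminator is what makes its output a valid estimate of the Wasserstein distance. In practice, enforcing the constraint tends to stabilize training and can reduce mode collapse.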
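
The definition can be checked numerically. The sketch below is an illustration, not part of the notebook: the hypothetical helper `estimate_lipschitz` samples random pairs of points and records the largest observed slope, which gives a lower bound on the Lipschitz constant of a scalar function.

```python
import math
import random


def estimate_lipschitz(f, low, high, n_pairs=10_000, seed=0):
    """Lower-bound the Lipschitz constant of f on [low, high] by sampling."""
    rng = random.Random(seed)
    best = 0.0
    for _ in range(n_pairs):
        x, y = rng.uniform(low, high), rng.uniform(low, high)
        if x != y:
            # Slope of the secant line between the two sampled points
            best = max(best, abs(f(x) - f(y)) / abs(x - y))
    return best


tanh_estimate = estimate_lipschitz(math.tanh, -3.0, 3.0)            # tanh is 1-Lipschitz
linear_estimate = estimate_lipschitz(lambda x: 5.0 * x, -3.0, 3.0)  # slope 5 everywhere
print(tanh_estimate, linear_estimate)
```

The estimate for `tanh` stays at or below 1, while the linear function with slope 5 violates a 1-Lipschitz constraint everywhere. The gradient penalty pushes the discriminator toward the first kind of behavior.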

## Code Structure

The compute_gradient_penalty function takes three parameters:

* discriminator: The discriminator model used in the GAN.
* real_samples: The batch of real samples used for training.
* fake_samples: The batch of generated (fake) samples produced by the generator.

The function calculates the gradient penalty loss for the Wasserstein GAN with gradient penalty and returns the computed gradient penalty.
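
A quick sanity check makes the return value tangible. The sketch below uses a compact stand-in for the penalty computation (named `gradient_penalty` here so it does not clash with the notebook's function) and a hypothetical linear critic whose gradient norm is known in advance. The critic, the tensor shapes and the random inputs are assumptions made for this example.

```python
import torch
from torch.autograd import grad


def gradient_penalty(critic, real, fake):
    """Sketch of the penalty: mean of (||grad D(x_hat)||_2 - 1)^2."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = critic(x_hat)
    grads = grad(outputs=out, inputs=x_hat, grad_outputs=torch.ones_like(out),
                 create_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


# Hypothetical toy critic: D(x) = w . x with ||w||_2 = 3, so the gradient
# norm is 3 everywhere and the penalty should be (3 - 1)^2 = 4.
w = torch.zeros(3 * 4 * 4)
w[0] = 3.0
critic = lambda x: x.view(x.size(0), -1) @ w.unsqueeze(1)

real = torch.randn(8, 3, 4, 4)
fake = torch.randn(8, 3, 4, 4)
penalty = gradient_penalty(critic, real, fake)
print(penalty.item())  # approximately 4.0
```

Because the critic is linear, the result does not depend on the random inputs or on the interpolation factor, which makes it a convenient test case.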

## Conclusion

The compute_gradient_penalty function plays a crucial role in training the Wasserstein GAN with gradient penalty. It computes the penalty term that enforces the Lipschitz constraint on the discriminator: by penalizing deviations from unit norm in the gradients of the discriminator's output with respect to the interpolated samples, it encourages a smooth discriminator and improves the stability and convergence of the training process.


[^1]: https://zzzcode.ai/

## CUDA

How to use a T4 GPU in Google Colab

https://medium.com/analytics-vidhya/ml06-893e4cb389c6

In [None]:
import torch

torch.cuda.is_available()

True

In [None]:
!nvidia-smi

Sun Oct 15 13:52:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P8     9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
print(torch.ones(3,2))

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])


In [None]:
!cat /etc/*-release

DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=22.04
DISTRIB_CODENAME=jammy
DISTRIB_DESCRIPTION="Ubuntu 22.04.2 LTS"
PRETTY_NAME="Ubuntu 22.04.2 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.2 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy


## Dataset

In [None]:
!unzip file.zip # zip them, upload, then unzip it.

## GANs model

In [None]:
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable, grad
import torch
import torch.nn as nn

In [None]:
# Functions: Gradient Penalty Function, Encode Label, Decode Target

def compute_gradient_penalty(discriminator, real_samples, fake_samples):

    """Calculates the gradient penalty loss for the Wasserstein GAN with gradient penalty."""

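    # Generate random interpolation

    """
    In this step, a random interpolation factor alpha is drawn per sample from U[0, 1) using the torch.rand function.
    This step creates a smooth transition between the real and fake samples.
    The interpolation is calculated as
        interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    alpha is created on the same device as real_samples to avoid a device mismatch on the GPU.
    """

    alpha = torch.rand(real_samples.size(0), 1, 1, 1, dtype=torch.float32,
                       device=real_samples.device, requires_grad=True)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples))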

    # Calculate Discriminator Output for Interpolated Samples

    """
    The interpolated samples are then passed through the discriminator model
    to obtain the discriminator output for each sample.
    """

    d_interpolates = discriminator(interpolates)

    # Create the grad_outputs tensor

    """
    A tensor of ones is created using the torch.full function.
    It has the same shape as the discriminator output, (batch_size, 1),
    and lives on the same device as the samples.
    It is used as the grad_outputs argument in the grad function,
    which calculates the gradients of d_interpolates with respect to interpolates.
    """

    fake = torch.full((real_samples.shape[0], 1), 1.0, dtype=torch.float32,
                      device=real_samples.device, requires_grad=False)

    # Calculate gradients w.r.t. interpolates

    """
    The gradients of the discriminator outputs with respect to the interpolated samples are calculated using the grad function from the torch.autograd module.
    The create_graph=True and retain_graph=True arguments ensure that the gradients can be used to compute higher-order gradients if needed.
    """

    gradients = grad(outputs=d_interpolates, inputs=interpolates, grad_outputs=fake,
                     create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Reshape and calculate norm

    """
    The gradients are reshaped to have a size of (batch_size, num_features) using the view method.
    Then, the L2 norm of the gradients along the second dimension (dim=1) (across the features) is calculated
    using the norm method with p=2.
    1 is subtracted from each norm and the result is squared, so any deviation from unit norm is penalized.
    Finally, the squared deviations are averaged across the batch to obtain the gradient penalty.
    """
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty


def encode_label(label):
    target = torch.zeros(6)
    for l in str(label).split(" "):
        target[int(l)] = 1.
    return target


def decode_target(target, text_labels=False, threshold=0.5):
    result = []
    for i, x in enumerate(target):
        if (x >= threshold):
            if text_labels:
                result.append(labels[i] + "(" + str(i) + ")")
            else:
                result.append(str(i))
    return " ".join(result)
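
A short round trip shows how the two label helpers fit together. The sketch below repeats compact copies of `encode_label` and `decode_target` so that it runs on its own. The `labels` mapping is the one defined later in the notebook, and the label string `"0 3"` is made up for illustration.

```python
import torch

labels = {0: "buildings", 1: "forest", 2: "glaciers",
          3: "mountain", 4: "sea", 5: "street"}


def encode_label(label):
    # Multi-hot encoding: set a 1 at every class index named in the string.
    target = torch.zeros(6)
    for l in str(label).split(" "):
        target[int(l)] = 1.
    return target


def decode_target(target, text_labels=False, threshold=0.5):
    # Keep every class whose score reaches the threshold.
    result = []
    for i, x in enumerate(target):
        if x >= threshold:
            result.append(labels[i] + "(" + str(i) + ")" if text_labels else str(i))
    return " ".join(result)


target = encode_label("0 3")
print(target)                                   # tensor([1., 0., 0., 1., 0., 0.])
print(decode_target(target))                    # 0 3
print(decode_target(target, text_labels=True))  # buildings(0) mountain(3)
```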

In [None]:
# %% CUDA switch
"""The traditional Wasserstein GAN suffers from exploding and vanishing gradients
due to interactions between the weight constraints and the cost function,
which can only be remedied by a carefully chosen clipping threshold."""

cuda = torch.cuda.is_available()

print(cuda)

True


In [None]:
# %% Case and Image Shape
cases = ["Normal min-max", "Wasserstein loss", "Wasserstein with Gradient"]
case = cases[2]
img_shape = (3, 32, 32)  # NOTE: (image_channel, image_size, image_size)

In [None]:
# %% Generator
"""Captures the data distribution"""


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        # Define the neural network
        self.model = nn.Sequential(
            nn.Linear(100, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            # NOTE: the second positional argument of BatchNorm1d is eps, so 0.8 sets eps here, not momentum
            nn.BatchNorm1d(256, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, int(np.prod(img_shape))),  # 3 * 32 * 32 = 3072
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)  # *img_shape unpacks the tuple: view(batch_size, 3, 32, 32)
        return img
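
The reshaping step at the end of `forward` can be tried in isolation. This is a small sketch with a made-up batch of zeros, showing that unpacking `img_shape` with `*` is equivalent to writing the three dimensions out by hand.

```python
import torch

img_shape = (3, 32, 32)
flat = torch.zeros(5, 3 * 32 * 32)  # a batch of 5 flattened images

# `*img_shape` unpacks the tuple into separate arguments, so the two calls
# below are equivalent.
a = flat.view(flat.size(0), *img_shape)
b = flat.view(flat.size(0), 3, 32, 32)
print(a.shape)  # torch.Size([5, 3, 32, 32])
```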

In [None]:
# %% Discriminator
""" Estimates the probability that a sample came from the training data rather than from the Generator
(in the Wasserstein cases it acts as a critic that scores samples instead)"""


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

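        # Define the model. The sigmoid is only appropriate for the min-max
        # (BCE) case; a Wasserstein critic outputs an unbounded score.
        layers = [
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        ]
        if case == cases[0]:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)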

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

## Preprocess unconditioned GANs

In [None]:
# %% Move Generator model, Discriminator Model to Cuda, FloatTensor
""" The generator tries to minimize the adversarial objective,
    while the discriminator tries to maximize it"""
generator = Generator()
if cuda:
    generator.cuda()

discriminator = Discriminator()
if cuda:
    discriminator.cuda()

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
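
The typed-constructor style above (together with `Variable`) is a legacy idiom. As an alternative, and not what this notebook uses, here is a minimal sketch of the `torch.device` approach for creating the noise and target tensors directly on the right device. The batch size of 4 is an arbitrary value chosen for the example.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, latent_dim = 4, 100
z = torch.randn(batch_size, latent_dim, device=device)  # latent noise
valid = torch.ones(batch_size, 1, device=device)        # "real" targets
fake = torch.zeros(batch_size, 1, device=device)        # "fake" targets
print(z.shape, valid.device.type == device.type)
```

With this idiom, models are moved with `model.to(device)` and the same code runs unchanged with or without a GPU.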



In [None]:
# %% Unconditioned: Create a directory for the sample images
if case == cases[0]:
    os.makedirs("images/gan_uncond_images", exist_ok=True)
elif case == cases[1]:
    os.makedirs("images/wgan_uncond_images", exist_ok=True)
elif case == cases[2]:
    os.makedirs("images/wgan_gp_uncond_images", exist_ok=True)

# NOTE: Keep the size at 32x32 to make it amenable to the hardware
transform_ds = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

# %% Unconditioned: Dataset, Batch Size, Epochs, Dataloader, Image Shape
PATH_TO_DATASETS = "/Volumes/SGB256/PycharmProjects/ISML"
ROOT = PATH_TO_DATASETS + "/datasets/image_classification/seg_part/"

buildings_dataset = torchvision.datasets.ImageFolder(root=ROOT, transform=transform_ds)

batch_size = 512
num_epochs = 1000
buildings_dataloader = DataLoader(buildings_dataset, batch_size, shuffle=True, num_workers=0, pin_memory=True)

FileNotFoundError: ignored

In [None]:
# %% Unconditioned: Initialize the loss function and the optimizers
if case == cases[0]:
    adversarial_loss = torch.nn.BCELoss()  # Binary Cross Entropy Loss
    if cuda:
        adversarial_loss.cuda()
    # betas are Adam's decay rates for the running averages of the gradient and of its square
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
elif case == cases[1]:
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=0.00005)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=0.00005)
elif case == cases[2]:
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
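
To see what the `betas` passed to Adam control, here is a hand-written sketch of a single Adam update for one scalar parameter. It follows the standard Adam update rule and is meant as an illustration only; the function name `adam_step` and the toy gradient value are assumptions.

```python
def adam_step(param, grad, m, v, t, lr=0.0002, betas=(0.5, 0.999), eps=1e-8):
    """One Adam update for a single scalar parameter at step t (t starts at 1)."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad       # running average of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2  # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (v_hat ** 0.5 + eps)
    return param, m, v


param, m, v = 1.0, 0.0, 0.0
param, m, v = adam_step(param, grad=2.0, m=m, v=v, t=1)
print(param)  # approximately 0.9998
```

A smaller first beta, such as the 0.5 used here instead of the default 0.9, makes the gradient average forget old gradients faster, a common choice for GAN training.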


## Train unconditioned GANs

In [None]:
# %% Unconditioned: Loop
if case == cases[0]:
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(buildings_dataloader):

            # Adversarial ground truths: (batch_size, 1) target tensors filled with
            # 1.0 for real samples and 0.0 for generated samples
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Create variable for the inputs; .type(Tensor) casts the batch to float32
            # and moves it to the GPU when CUDA is enabled
            real_imgs = Variable(imgs.type(Tensor))

            # Generator training begins here
            optimizer_G.zero_grad()

            # Sample the latent vector z from a standard normal distribution; the generator maps it to an image
            # NOTE: `100` is the size of the latent dimensionality of the model, adjusting it can have huge effects on the training dynamics
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100))))

            # Generate images
            fake_imgs = generator(z)

            # Take the loss
            g_loss = adversarial_loss(discriminator(fake_imgs), valid)
            g_loss.backward()  # Gradient of loss wrt parameters
            optimizer_G.step()  # Optimizes parameters

            # Discriminator training begins here
            optimizer_D.zero_grad()

            # Assess the discriminator's classification ability
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)

            # Loss of the discriminator
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, num_epochs, i, len(buildings_dataloader), d_loss.item(), g_loss.item()))

            batches_done = epoch * len(buildings_dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake_imgs.data[:25], "images/gan_uncond_images/%d.png" % batches_done, nrow=5,
                           normalize=True)
elif case == cases[1]:
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(buildings_dataloader):

            # Create a variable for the inputs
            real_imgs = Variable(imgs.type(Tensor))

            # Begin with the discriminator training here
            optimizer_D.zero_grad()

            # NOTE: `100` is the size of the latent dimensionality of the model, adjusting it can have huge effects on the training dynamics
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100))))

            fake_imgs = generator(z).detach()
            d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
            d_loss.backward()

            optimizer_D.step()

            # Clip weights of the discriminator at the end of the iteration
            for p in discriminator.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Train the generator only every fifth iteration
            if i % 5 == 0:
                optimizer_G.zero_grad()

                gen_imgs = generator(z)
                g_loss = -torch.mean(discriminator(gen_imgs))
                g_loss.backward()
                optimizer_G.step()

                print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, num_epochs, i, len(buildings_dataloader), d_loss.item(), g_loss.item()))

            batches_done = epoch * len(buildings_dataloader) + i
            if batches_done % 400 == 0:
                save_image(gen_imgs.data[:25], "images/wgan_uncond_images/%d.png" % batches_done, nrow=5, normalize=True)

elif case == cases[2]:
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(buildings_dataloader):

            # Create a variable for the inputs
            real_imgs = Variable(imgs.type(Tensor), requires_grad=True)

            # Begin with the discriminator training here
            optimizer_D.zero_grad()

            # NOTE: `100` is the size of the latent dimensionality of the model, adjusting it can have huge effects on the training dynamics
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100))))
            fake_imgs = generator(z).detach()
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
            d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(
                discriminator(fake_imgs)) + 10 * gradient_penalty
            d_loss.backward()
            optimizer_D.step()

            # NOTE: No weight clipping here; in WGAN-GP the gradient penalty
            # replaces clipping as the way to enforce the Lipschitz constraint

            # Train the generator only every fifth iteration
            if i % 5 == 0:
                optimizer_G.zero_grad()

                gen_imgs = generator(z)
                g_loss = -torch.mean(discriminator(gen_imgs))
                g_loss.backward()
                optimizer_G.step()

                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                      % (epoch, num_epochs, i, len(buildings_dataloader), d_loss.item(), g_loss.item()))
            batches_done = epoch * len(buildings_dataloader) + i
            if batches_done % 400 == 0:
                save_image(gen_imgs.data[:25],
                           "images/wgan_gp_uncond_images/%d.png" % batches_done,
                           nrow=5, normalize=True)

## Preprocess conditioned GAN

In [None]:
# %% Conditional GANs: Labels

labels = {
    0: "buildings",
    1: "forest",
    2: "glaciers",
    3: "mountain",
    4: "sea",
    5: "street"
}

# %% Conditional GANs: Transforms TODO: Why two different transforms?
transform_ds = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

transform_dt = transforms.Compose([transforms.ToTensor(), transforms.Normalize([.5, .5, .5], [.5, .5, .5])])

# %% Conditional GANs: Whole Dataset; Train + Test

ROOT = PATH_TO_DATASETS + "/datasets/image_classification/seg_train"
dataset = torchvision.datasets.ImageFolder(root=ROOT, transform=transform_ds)

ROOT = PATH_TO_DATASETS + "/datasets/image_classification/seg_test"
test_dataset = torchvision.datasets.ImageFolder(root=ROOT, transform=transform_dt)

# %% Conditional GANs: Batch Size, Epochs, Dataloader
batch_size = 512

# Define the length of the training for the conditional GANs
num_epochs = 1000

# Define the dataloader
whole_dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)

# %% Conditional GANs: Create a directory for the sample images
if case == cases[0]:
    os.makedirs("images/gan_cond_images", exist_ok=True)
elif case == cases[1]:
    os.makedirs("images/wgan_cond_images", exist_ok=True)
elif case == cases[2]:
    os.makedirs("images/wgan_gp_cond_images", exist_ok=True)

# %% Conditional GANs: Initialize optimizer
if case == cases[0]:
    adversarial_loss = torch.nn.BCELoss()

    if cuda:
        adversarial_loss.cuda()

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
elif case == cases[1]:
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=0.00005)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=0.00005)
elif case == cases[2]:
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
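
The conditional training loop feeds the class label to the generator by replacing one of the 100 latent dimensions with a scaled label. The sketch below shows only that step, with made-up labels and on the CPU. Dividing by the fixed value `num_classes - 1` is an assumption made here so that the scaling is the same for every batch.

```python
import torch

num_classes, latent_dim = 6, 100
labels = torch.tensor([0, 3, 5])                         # integer class indices
scaled = labels.float().view(-1, 1) / (num_classes - 1)  # shape (3, 1), values in [0, 1]
z = torch.randn(labels.size(0), latent_dim - 1)          # 99 noise dimensions
z_cond = torch.cat((z, scaled), dim=1)                   # shape (3, 100)
print(z_cond.shape)  # torch.Size([3, 100])
print(z_cond[:, -1])
```

The last column of `z_cond` carries the class information, so the generator's input size stays at 100 and the `Generator` class can be reused unchanged.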


## Train conditioned GANs

In [None]:
# %% Conditional GANs: Loop
if case == cases[0]:

    for epoch in range(num_epochs):
        for i, (imgs, labels) in enumerate(whole_dataloader):

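            # Normalize and reshape the labels: scale the class indices to [0, 1]
            # (6 classes, so the largest index is 5) and move them to the same
            # device and dtype as the noise
            labels = labels.type(Tensor) / 5.0
            labels = labels.view(imgs.shape[0], 1)  # NOTE: This is required for the later concatenation [imgs.shape[0]] -> [imgs.shape[0], 1]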

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Create variable for the inputs
            real_imgs = Variable(imgs.type(Tensor))

            # Generator training begins here
            optimizer_G.zero_grad()

            # Sample random noise as input and condition with labels
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100 - 1))))
            z_cond = torch.cat((z, labels), dim=1)

            # Generate images
            fake_imgs = generator(z_cond)

            # Take the loss
            g_loss = adversarial_loss(discriminator(fake_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # Discriminator training begins here
            optimizer_D.zero_grad()

            # Assess the discriminator's classification ability
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)

            # Loss of the discriminator
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, num_epochs, i, len(whole_dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(whole_dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake_imgs.data[:25], "images/gan_cond_images/%d.png" % batches_done, nrow=5, normalize=True)

elif case == cases[1]:

    for epoch in range(num_epochs):
        for i, (imgs, labels) in enumerate(whole_dataloader):

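            # Normalize and reshape the labels: scale the class indices to [0, 1]
            # (6 classes, so the largest index is 5) and move them to the same
            # device and dtype as the noise
            labels = labels.type(Tensor) / 5.0
            labels = labels.view(imgs.shape[0], 1)  # NOTE: This is required for the later concatenation [imgs.shape[0]] -> [imgs.shape[0], 1]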

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Create variable for the inputs
            real_imgs = Variable(imgs.type(Tensor))

            # Discriminator training begins here
            optimizer_D.zero_grad()

            # Sample random noise as input and condition with labels
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100 - 1))))
            z_cond = torch.cat((z, labels), dim=1)

            # Generate images
            fake_imgs = generator(z_cond).detach()

            # Take the loss
            d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
            d_loss.backward()

            optimizer_D.step()

            # Clip weights of the discriminator at the end of the iteration
            for p in discriminator.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Train the generator only every fifth iteration
            if i % 5 == 0:
                optimizer_G.zero_grad()

                gen_imgs = generator(z_cond)
                g_loss = -torch.mean(discriminator(gen_imgs))
                g_loss.backward()
                optimizer_G.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                  % (epoch, num_epochs, i, len(whole_dataloader), d_loss.item(), g_loss.item()))

            batches_done = epoch * len(whole_dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake_imgs.data[:25], "images/wgan_cond_images/%d.png" % batches_done, nrow=5, normalize=True)

elif case == cases[2]:
    for epoch in range(num_epochs):
        for i, (imgs, labels) in enumerate(whole_dataloader):

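            # Normalize and reshape the labels: scale the class indices to [0, 1]
            # (6 classes, so the largest index is 5) and move them to the same
            # device and dtype as the noise
            labels = labels.type(Tensor) / 5.0
            labels = labels.view(imgs.shape[0], 1)  # NOTE: This is required for the later concatenation [imgs.shape[0]] -> [imgs.shape[0], 1]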

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Create variable for the inputs
            real_imgs = Variable(imgs.type(Tensor))

            # Discriminator training begins here
            optimizer_D.zero_grad()

            # Sample random noise as input and condition with labels
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100 - 1))))
            z_cond = torch.cat((z, labels), dim=1)

            # Generate images
            fake_imgs = generator(z_cond).detach()

            # Take the loss
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
            d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs)) \
                     + 10 * gradient_penalty

            d_loss.backward()
            optimizer_D.step()

            # NOTE: No weight clipping here; in WGAN-GP the gradient penalty
            # replaces clipping as the way to enforce the Lipschitz constraint

            # Train the generator only every fifth iteration
            if i % 5 == 0:
                optimizer_G.zero_grad()

                gen_imgs = generator(z_cond)
                g_loss = -torch.mean(discriminator(gen_imgs))
                g_loss.backward()
                optimizer_G.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                  % (epoch, num_epochs, i, len(whole_dataloader), d_loss.item(), g_loss.item()))

            batches_done = epoch * len(whole_dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake_imgs.data[:25], "images/wgan_gp_cond_images/%d.png" % batches_done, nrow=5, normalize=True)
