# GANs
In this assignment, you will first implement the original GAN and the Wasserstein GAN (WGAN) on a toy problem to see how relatively small changes can lead to big differences during training. Then, you will train a conditional GAN (cGAN) on the same dataset as in the previous assignment, namely MNIST.

## Setup
To simplify setup, we use the same environment as in the previous assignments. If you already installed it for the previous assignment, you can simply run `conda activate vae`. Otherwise, run one of the following, depending on whether you have a GPU:
```
conda env create -f environments/environment-gpu.yml
conda activate vae
```

or

```
conda env create -f environments/environment-cpu-only.yml
conda activate vae
```

## How to complete this assignment

Throughout this assignment there are several places where you will need to fill in code. These are marked with `YOUR CODE HERE` comments. Further, there are several places where you will need to answer questions. These are marked with `YOUR ANSWER HERE` comments. You should replace the `YOUR CODE HERE` and `YOUR ANSWER HERE` comments with your code and answers. 

---

In [None]:
import os
import os.path as osp
import numpy as np
import time
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(1337)
np.random.seed(1337)


# Part 2: Conditional GAN
With a GAN, it is often useful to generate images of a particular type rather than arbitrary samples from the full data distribution. In this part, we will implement a conditional GAN (cGAN) that can generate images of specific handwritten digits using the MNIST dataset.

First, we want you to implement the conditional generator. It should be an MLP with the following architecture:
- An input layer that takes the concatenated noise vector and embedded label as input.
- A LeakyReLU activation with a negative slope of 0.2.
- 3 fully connected layers with 256, 512 and 1024 nodes, respectively, each followed by a BatchNorm and a LeakyReLU activation.
- An output layer with `n_channels * img_size * img_size` nodes and a Tanh activation.
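The hidden layers share the same pattern: fully connected, then BatchNorm, then LeakyReLU. One option is a small helper that returns these layers as a list that can be unpacked into `nn.Sequential`. This is a sketch. The helper name `linear_block` and the input width 42 in the usage line are arbitrary choices for illustration. In the list above, the input layer is followed by a LeakyReLU without a BatchNorm, which corresponds to `batch_norm=False`.

```python
import torch
import torch.nn as nn


def linear_block(in_features: int, out_features: int, batch_norm: bool = True) -> list:
    """Hypothetical helper: Linear -> optional BatchNorm1d -> LeakyReLU(0.2)."""
    layers = [nn.Linear(in_features, out_features)]
    if batch_norm:
        layers.append(nn.BatchNorm1d(out_features))
    layers.append(nn.LeakyReLU(0.2))
    return layers


# Chain two blocks; unpacking the returned lists keeps nn.Sequential flat.
mlp = nn.Sequential(*linear_block(42, 256), *linear_block(256, 512))
out = mlp(torch.randn(8, 42))  # BatchNorm1d in training mode needs batch size > 1
print(out.shape)  # torch.Size([8, 512])
```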

The forward pass should:
- Embed the label.
- Concatenate the noise vector and the embedded label.
- Pass the concatenated vector through the MLP.
- Reshape the output to the image shape `(n_channels, img_size, img_size)`.
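The sketch below traces the tensor shapes through these four steps. To keep the focus on the data flow, it uses `placeholder_mlp`, a single linear layer followed by a Tanh. This is a hypothetical stand-in, not the architecture specified above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_classes, latent_dim = 10, 100
img_shape = (1, 28, 28)  # (n_channels, img_size, img_size)
n_pixels = img_shape[0] * img_shape[1] * img_shape[2]

label_embedding = nn.Embedding(n_classes, n_classes)  # as in the skeleton
# Hypothetical placeholder for the MLP; NOT the architecture specified above.
placeholder_mlp = nn.Sequential(nn.Linear(latent_dim + n_classes, n_pixels), nn.Tanh())

noise = torch.randn(4, latent_dim)
labels = torch.randint(0, n_classes, (4,))

emb = label_embedding(labels)               # (4, 10): one learned vector per label
z = torch.cat([noise, emb], dim=1)          # (4, 110)
flat = placeholder_mlp(z)                   # (4, 784), values in [-1, 1]
img = flat.view(flat.size(0), *img_shape)   # (4, 1, 28, 28)
print(img.shape)  # torch.Size([4, 1, 28, 28])
```

The Tanh output range [-1, 1] matches the data-loading cell, where `transforms.Normalize((0.5,), (0.5,))` maps pixel values from [0, 1] to [-1, 1].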

In [None]:
class ConditionalGenerator(nn.Module):
    def __init__(self, n_classes: int, n_channels: int, img_size: int, latent_dim: int):
        super(ConditionalGenerator, self).__init__()
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.n_channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.n_classes, self.n_classes)

        # YOUR CODE HERE
        raise NotImplementedError()

    def forward(self, noise: torch.Tensor, labels: torch.Tensor):
        """Generates an image given a noise vector and a label.

        Args:
            noise (torch.Tensor): A noise vector of shape (batch_size, latent_dim)
            labels (torch.Tensor): A label vector of shape (batch_size)
        Returns:
            torch.Tensor: A generated image of shape (batch_size, n_channels, img_size, img_size)
        """
        # YOUR CODE HERE
        raise NotImplementedError()
        return x

In [None]:
tmp_generator = ConditionalGenerator(10, 1, 28, 100)
tmp_noise = torch.randn(10, 100)
tmp_labels = torch.randint(0, 10, (10,))
tmp_img = tmp_generator(tmp_noise, tmp_labels)
assert tmp_img.shape == (10, 1, 28, 28), "Wrong shape of generated image"

Next, we want you to implement the conditional discriminator. It should be an MLP with the following architecture:
- An input layer that takes the flattened image concatenated with the embedded label as input, and outputs 1024 nodes.
- A LeakyReLU activation with a negative slope of 0.2.
- 2 fully connected layers with 512 and 256 nodes, respectively, each followed by a Dropout layer with a probability of 0.4 and a LeakyReLU activation.
- 1 fully connected layer with 128 nodes.
- 1 output layer with 1 node, followed by a Sigmoid activation.
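The Sigmoid turns each score into a probability in (0, 1). This is the input that `nn.BCELoss` expects, and it is the loss used by the discriminator's `loss` method. `BCELoss` also requires the prediction and the target to have the same shape. The sketch below uses a random stand-in tensor in place of the 128-node layer's output. It checks that a `(batch_size, 1)` score matches the `(batch_size, 1)` targets built in the training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 4
# Last entry of the list above: a 1-node output layer followed by a Sigmoid.
head = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())

features = torch.randn(batch_size, 128)        # stand-in for the 128-node layer's output
probs = head(features)                         # (4, 1), each value in (0, 1)

real_label = torch.full((batch_size, 1), 1.0)  # same construction as in the training loop
loss = nn.BCELoss()(probs, real_label)

# With every target equal to 1, BCE reduces to the mean of -log(p).
manual = -torch.log(probs).mean()
print(probs.shape, torch.allclose(loss, manual))  # torch.Size([4, 1]) True
```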

The forward pass should:
- Embed the label.
- Flatten the image and concatenate it with the embedded label.
- Pass the concatenated vector through the MLP.
- Return the output, a score of shape `(batch_size, 1)`.
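The width of the concatenated input depends on the image size. The shape-check cell uses 28×28 images, but the training cell sets `IMG_SIZE = 32`. For that reason, it is safer to derive the first layer's `in_features` from the image shape and `n_classes` than to hard-code it. The sketch below shows the flatten-and-concatenate step for both sizes. Its loop and the `shapes` dictionary are illustrative only.

```python
import torch
import torch.nn as nn

n_classes = 10
label_embedding = nn.Embedding(n_classes, n_classes)  # as in the skeleton
shapes = {}

for img_size in (28, 32):  # 28 in the shape-check cell, 32 in the training cell
    image = torch.randn(4, 1, img_size, img_size)
    labels = torch.randint(0, n_classes, (4,))
    flat = image.flatten(start_dim=1)                         # (4, img_size * img_size)
    d_in = torch.cat([flat, label_embedding(labels)], dim=1)  # (4, img_size**2 + 10)
    first_layer = nn.Linear(d_in.size(1), 1024)               # input layer sized from data
    out = first_layer(d_in)                                   # (4, 1024)
    shapes[img_size] = tuple(d_in.shape)

print(shapes)  # {28: (4, 794), 32: (4, 1034)}
```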


In [None]:
class ConditionalDiscriminator(nn.Module):
    def __init__(self, n_classes: int, n_channels: int, img_size: int):
        super(ConditionalDiscriminator, self).__init__()
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.img_size = img_size
        self.img_shape = (self.n_channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.n_classes, self.n_classes)
        self.adv_loss = torch.nn.BCELoss()

        # YOUR CODE HERE
        raise NotImplementedError()

    def forward(self, image: torch.Tensor, labels: torch.Tensor):
        """Classifies an image given a label.

        Args:
            image (torch.Tensor): An image of shape (batch_size, n_channels, img_size, img_size)
            labels (torch.Tensor): A label vector of shape (batch_size)
        Returns:
            torch.Tensor: A classification score of shape (batch_size, 1)
        """
        # YOUR CODE HERE
        raise NotImplementedError()
        return self.model(x)

    def loss(self, output, label):
        return self.adv_loss(output, label)

In [None]:
tmp_discriminator = ConditionalDiscriminator(10, 1, 28)
tmp_img = torch.randn(10, 1, 28, 28)
tmp_labels = torch.randint(0, 10, (10,))
tmp_score = tmp_discriminator(tmp_img, tmp_labels)
assert tmp_score.shape == (10, 1), "Wrong shape of classification score"

We provide a training loop below to train the model. Feel free to modify it.

In [None]:
class ConditionalGan:
    def __init__(
        self,
        device: str,
        data_loader: torch.utils.data.DataLoader,
        n_classes: int,
        n_channels: int,
        img_size: int,
        latent_dim: int,
    ):
        self.device = device
        self.data_loader = data_loader
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.generator = ConditionalGenerator(
            self.n_classes, self.n_channels, self.img_size, self.latent_dim
        )
        self.generator.to(self.device)
        self.discriminator = ConditionalDiscriminator(
            self.n_classes, self.n_channels, self.img_size
        )
        self.discriminator.to(self.device)
        self.optim_G = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.generator.parameters()),
            lr=1e-4,
            betas=(0.5, 0.999),
        )
        self.optim_D = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.discriminator.parameters()),
            lr=1e-4,
            betas=(0.5, 0.999),
        )

    def train(self, n_epochs: int, log_interval: int = 200):
        self.generator.train()
        self.discriminator.train()
        viz_noise = torch.randn(10, self.latent_dim).to(self.device)
        viz_labels = torch.arange(10).to(self.device)

        for epoch in range(n_epochs):
            batch_time = time.time()
            for batch_idx, (data, target) in enumerate(self.data_loader):
                data, target = data.to(self.device), target.to(self.device)
                batch_size = data.size(0)
                real_label = torch.full((batch_size, 1), 1.0, device=self.device)
                fake_label = torch.full((batch_size, 1), 0.0, device=self.device)

                # train generator
                self.generator.zero_grad()
                z_noise = torch.randn(batch_size, self.latent_dim, device=self.device)
                x_fake_labels = torch.randint(
                    0, self.n_classes, (batch_size,), device=self.device
                )
                x_fake = self.generator(z_noise, x_fake_labels)
                y_fake_g = self.discriminator(x_fake, x_fake_labels)
                g_loss = self.discriminator.loss(y_fake_g, real_label)
                g_loss.backward()
                self.optim_G.step()

                # train discriminator
                self.discriminator.zero_grad()
                y_real = self.discriminator(data, target)
                d_real_loss = self.discriminator.loss(y_real, real_label)
                y_fake_d = self.discriminator(x_fake.detach(), x_fake_labels)
                d_fake_loss = self.discriminator.loss(y_fake_d, fake_label)
                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward()
                self.optim_D.step()

                if batch_idx % log_interval == 0 and batch_idx > 0:
                    _, axs = plt.subplots(1, 10, figsize=(15, 15))
                    with torch.no_grad():
                        # move to CPU: matplotlib cannot plot CUDA tensors
                        generated_images = self.generator(viz_noise, viz_labels).cpu()
                    for i in range(10):
                        axs[i].imshow(generated_images[i].squeeze(), cmap="gray")
                        axs[i].axis("off")
                    plt.show()

                    print(
                        "Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}".format(
                            epoch,
                            batch_idx,
                            len(self.data_loader),
                            d_loss.mean().item(),
                            g_loss.mean().item(),
                            time.time() - batch_time,
                        )
                    )

                    batch_time = time.time()

Next, we set up the data and the model and run the training. Feel free to experiment with both the architecture and the hyperparameters; however, the trained model should generate images that look like the ones in the dataset.

In [None]:
MNIST_DATA_DIR = "./mnist_data"
OUTPUT_DIR = "./output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
EPOCHS = 30
IMG_SIZE = 32
BATCH_SIZE = 128
LATENT_DIM = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device: {}".format(DEVICE))

print("Loading data...\n")
dataset = dset.MNIST(
    root=MNIST_DATA_DIR,
    download=True,  # no-op if the files are already present in MNIST_DATA_DIR
    transform=transforms.Compose(
        [
            transforms.Resize(IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    ),
)
assert dataset
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)
print("Creating model...\n")
model = ConditionalGan(
    device=DEVICE,
    data_loader=dataloader,
    n_classes=10,
    n_channels=1,
    img_size=IMG_SIZE,
    latent_dim=LATENT_DIM,
)

# Train
model.train(n_epochs=EPOCHS, log_interval=200)

Now that we've trained the model, let's see how it performs. The cell below plots a 10x10 grid of real MNIST images and a 10x10 grid of generated images conditioned on the same labels, so you can check whether they look alike.

In [None]:

# show some real data 
images, labels = next(iter(dataloader))
images = images[:100]
labels = labels[:100]
fig, axs = plt.subplots(10, 10, figsize=(8, 8))
for i in range(10):
    for j in range(10):
        axs[i, j].imshow(images[i * 10 + j].squeeze(), cmap="gray")
        axs[i, j].axis("off")
plt.savefig(osp.join(OUTPUT_DIR, "real_gridview.png"))
plt.show()

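# generate a 10x10 grid of images conditioned on the same labels as the real grid
model.generator.eval()  # use BatchNorm running statistics at inference time
noise = torch.randn(100, model.latent_dim, device=DEVICE)
with torch.no_grad():
    # labels come from the CPU data loader; move them to the generator's device,
    # then bring the images back to the CPU for matplotlib
    generated_images = model.generator(noise, labels.to(DEVICE)).cpu()

fig, axs = plt.subplots(10, 10, figsize=(8, 8))
for i in range(10):
    for j in range(10):
        axs[i, j].imshow(generated_images[i * 10 + j].squeeze(), cmap="gray")
        axs[i, j].axis("off")
plt.savefig(osp.join(OUTPUT_DIR, "fake_gridview.png"))
plt.show()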

print("Can you tell the difference?")


# Looking back
Now you have also implemented a conditional GAN, which can generate samples of a particular class from the data distribution. As of now, you know absolutely everything there is to know about GANs, and therefore, it is time to look back at the previous assignment on VAEs. Here, we want you to contrast GANs with VAEs in the following aspects:
1. What are they trained to do, respectively?
2. Latent space: How are the latent spaces of GANs and VAEs different? 
3. Mode collapse: Are they equally prone to mode collapse? If not, which one is more prone?
4. Applications: What are some applications of GANs and VAEs? What can you do with a VAE that is not suitable for a GAN?
5. Quality of samples: Which one tends to generate higher-quality samples?


YOUR ANSWER HERE