In [None]:
# ==================== #
# Author: Kenneth Chen #
# Student ID: 2100072  #
# ==================== #

# Notebook Configuration
GPU_COUNT = 1
WINDOWS = True

In [None]:
!pip install -q pandas
!pip install -q seaborn
!pip install -q matplotlib
!pip install -q neptune-client

<img src="images/cifar10.PNG" width=700>

# Generative Adversarial Networks - CIFAR10
> Can we create a strong generative model for CIFAR10 using the GAN architecture?

## Background Research 📖

### Introduction 💡
The ultimate goal of Deep Learning is to create a function that can <strong>effectively model any form of data distribution</strong>. Discriminative models, which learn to divide the data distribution by mapping a high dimensional input to a lower dimensional output, have been impressively successful time and time again (Goodfellow et al., 2014). Image Classification, for instance, is a discriminative task: high dimensional images are mapped to low dimensional label probabilities.

What about generative modelling? Here, the goal is instead to learn from a given data distribution and <strong>generate new examples</strong> that follow this distribution while still being unique. Thus, a high performing generative model should be able to create examples that are both <strong>plausible</strong> (in that one can recognize what the generated example is supposed to depict) <strong>and indistinguishable</strong> from real data examples (Brownlee, 2019). Generative models can be Unsupervised or Semi-Supervised, depending on the exact task that one is trying to tackle.

There are several architectural approaches to building Generative Models:
<ul>
	<li>Generative Adversarial Networks (GAN) </li>
	<li>Diffusion Models</li>
	<li>Variational Auto Encoders (VAE)</li>
</ul>

<strong>GANs ⚔️</strong> <br />
GANs are the main focus of this notebook. Proposed by Ian Goodfellow in 2014, they have become one of the more popular types of Generative Models. For instance, the well-known website <a href="thispersondoesnotexist.com">thispersondoesnotexist.com</a> uses the StyleGAN2 architecture (Karras et al., 2020) to generate high fidelity images of human faces. The idea behind GANs is that <strong>there are two networks that work against each other in a game</strong>, each trying to one-up the other, which drives both networks to improve. More details are discussed under "What's inside a GAN? 🔍".

<strong>Diffusion Models ✨</strong><br /> 
These models have been not only successful but widely popular as well. For instance, OpenAI's DALL-E, Google's Imagen, Stable Diffusion and Midjourney all fall under the category of Diffusion Models (Muppalla and Hendryx, 2022). At a high level, they work like so:
<ul>
	<li>Noise is added to the original images</li>
	<li>Noise is added progressively until the image is pure noise</li>
	<li>The model then learns to remove the noise</li>
	<li>Guidance, e.g. a text prompt for text-to-image generation, can be added to direct the generation process</li>
</ul>
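
The first two steps in the list above can be sketched in a few lines of plain Python. This is only an illustration, not the procedure used by any of the models named above: it assumes a DDPM-style linear noise schedule, and the names `make_alpha_bars`, `add_noise`, `beta_start` and `beta_end` are hypothetical.

```python
import math
import random

def make_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed linear schedule: beta_t is the noise variance added at step t
    betas = [beta_start + (beta_end - beta_start) * t / (num_steps - 1)
             for t in range(num_steps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta  # cumulative product of (1 - beta_t)
        alpha_bars.append(running)
    return alpha_bars

def add_noise(pixels, t, alpha_bars, rng):
    # Noisy image at step t: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    signal = math.sqrt(alpha_bars[t])
    noise = math.sqrt(1.0 - alpha_bars[t])
    return [signal * x + noise * rng.gauss(0.0, 1.0) for x in pixels]

alpha_bars = make_alpha_bars()
rng = random.Random(0)
image = [0.5] * 8  # toy "image" made of 8 pixel values
slightly_noisy = add_noise(image, 10, alpha_bars, rng)
mostly_noise = add_noise(image, 999, alpha_bars, rng)
```

As `t` grows, `alpha_bars[t]` shrinks towards zero, so less and less of the original image survives and the result approaches pure Gaussian noise. The model's job is then to learn the reverse direction.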

<img src="images/diffusion1.jpg" width=400/><br />
*Noise is progressively added to the image. Image Credit: (Muppalla and Hendryx, 2022)*

<img src="images/diffusion2.jpg" width=400/><br />
*Model attempts to recreate the image. Image Credit: (Muppalla and Hendryx, 2022)*

<strong>Variational Auto Encoders 🎲</strong><br />
We first take a look at what Auto Encoders are. An Auto Encoder is trained to copy its input to its output (Goodfellow, Bengio and Courville, 2016). An <strong>encoder maps the image</strong> to a compressed representation of the image (the inner nodes), and a decoder then <strong>uses this compressed representation</strong> to generate an image similar to the original. Note that the trick is to <strong>restrict the number of inner nodes</strong> inside the network, such that it is <strong>not able to generate a 1-to-1 copy</strong>. This way, it is forced to learn the most prominent features, using the limited number of nodes.

The idea with Variational Auto Encoders is that <strong>instead of the encoder just mapping the image to a compressed representation</strong> (aka latent vector), we <strong>learn the distribution that the latent vector can take on</strong>. We can then randomly sample from the learned latent distribution and have the decoder turn that sample into a new image, generated in a controlled way.

<img src="images/vae.PNG" width=400/><br />
*VAE Architecture. We note the learning of the latent distribution. Image Credit: (Rocca, 2019)*
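
The sampling step described above can be illustrated without a trained network. In this sketch, `mu` and `log_var` stand in for the encoder's outputs and are made-up values, and `sample_latent` is a hypothetical helper. Drawing the sample as `mu + sigma * eps` is the usual reparameterisation used when training VAEs.

```python
import math
import random

def sample_latent(mu, log_var, rng):
    # z = mu + sigma * eps, with eps ~ N(0, 1) and sigma = exp(0.5 * log_var)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]

rng = random.Random(0)
mu = [0.0, 1.5, -2.0]       # hypothetical encoder means
log_var = [0.0, -1.0, 0.5]  # hypothetical encoder log-variances
z = sample_latent(mu, log_var, rng)

# To generate a brand new image, z is typically drawn from the prior N(0, I)
# and passed to the decoder (not shown here)
z_prior = [rng.gauss(0.0, 1.0) for _ in mu]
```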


### What's inside a GAN? 🔍
In a GAN, there are two networks, a <strong>generator</strong> and a <strong>discriminator</strong>, that improve each other by competing in a game scenario (Goodfellow, Bengio and Courville, 2016). The aim is to use the well established field of discriminators to assist the generator. The goal of the generator is to <strong>create realistic images</strong> that appear to come from the distribution of the training images, whereas the goal of the discriminator is to determine <strong>whether a given image is from the data distribution</strong>. The process goes as follows:
<ol>
	<li>Generator creates images</li>
	<li>Discriminator learns to distinguish real vs fake from a set of real images and these newly generated images</li>
	<li>Using the updated Discriminator, the Generator learns to trick the Discriminator</li>
</ol>
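
Steps 2 and 3 above can be made concrete through the loss each network tries to minimise. The sketch below uses plain Python and made-up discriminator outputs rather than real networks. The function names are hypothetical, and the generator loss shown is the commonly used non-saturating variant, which is not necessarily the one used later in this notebook.

```python
import math

def discriminator_loss(d_real, d_fake):
    # Binary cross-entropy: real images should score 1, generated images 0
    real_term = -sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = -sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

def generator_loss(d_fake):
    # Non-saturating loss: the generator wants its images to score 1
    return -sum(math.log(p) for p in d_fake) / len(d_fake)

# Hypothetical discriminator outputs (probability that an image is real)
d_real = [0.9, 0.8, 0.95]
caught = [0.1, 0.2, 0.05]  # the discriminator spots the fakes
fooled = [0.7, 0.6, 0.8]   # the discriminator is being tricked

print(discriminator_loss(d_real, caught), generator_loss(caught))
print(discriminator_loss(d_real, fooled), generator_loss(fooled))
```

When the discriminator is fooled, its own loss goes up while the generator's loss goes down, which is exactly the tug of war that the two networks play.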


### Types of GANs 🍐
There are many different types of GANs; however, I believe the most contrasting pair is the Vanilla GAN and the Conditional GAN. The Vanilla GAN is what was proposed by Ian Goodfellow in 2014, and consists of the basic architecture with multi-layer perceptrons (MLPs). Conditional GANs differ in that one can provide additional information to the model, which could be thought of as a form of guidance similar to that in diffusion models.

<img src="images/vanilla.jpg" width=400><br/>
*Vanilla GAN architecture (Tewari, N.d.)*

<img src="images/conditional.jpg" width=400><br/>

*Conditional GAN architecture (Tewari, N.d.)*

We observe that there is an extra component `y`, which represents the extra information presented to both the Generator and the Discriminator. This extra information is usually a class label, giving one control over the output, but it can be extended to other modalities, even text (in which case some sort of text processor is needed).

One may think of Conditional GAN as a *"supervised" version* of Vanilla GANs. 
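
As a rough illustration of how `y` can reach the Generator, one common approach is to concatenate a one-hot encoding of the class label onto the noise vector. The sketch below assumes that approach; other schemes, such as learned label embeddings, exist. `LATENT_DIM` and both helper functions are hypothetical and not taken from the diagrams above.

```python
import random

NUM_CLASSES = 10  # CIFAR10 has 10 classes
LATENT_DIM = 4    # kept tiny for readability; an assumption, not a recommended value

def one_hot(label, num_classes=NUM_CLASSES):
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def conditional_generator_input(label, rng, latent_dim=LATENT_DIM):
    # Noise vector z concatenated with the one-hot encoding of y
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    return z + one_hot(label)

rng = random.Random(0)
g_input = conditional_generator_input(label=3, rng=rng)
```

The Discriminator would receive the same label alongside the image, so that it judges not only whether an image looks real but also whether it matches the requested class.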

### Uses of GANs 🧤

As a GAN is a generative model, it has a large number of applications (Brownlee, 2019). Personally, I find it interesting how the idea of GANs can be adapted to any modality of data, as long as the architectures of the generator and discriminator are adjusted accordingly. Here are a few application areas of GANs I believe are quite intriguing:
<ul>
	<li>Time Series</li>
	<li>Image Generation</li>
	<li>Music Generation</li>
	<li>Audio Generation</li>
	<li>Style Transfer (e.g. winter photo to summer photo, jazz to classical music)</li>
</ul>

Among them, Audio Generation stands out to me the most. The idea of using GANs in music composition sounds like a difficult challenge, but also an impressive feat if one could pull it off.

### The difficulty with GANs 🧩
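GANs are notoriously difficult to train, for a few reasons:
<ul>
	<li><strong>Convexity</strong>: the GAN objective is a minimax game that is not convex in the network parameters, so gradient-based training is not guaranteed to converge.</li>
	<li><strong>Mode collapse</strong>: the generator may learn to produce only a small variety of outputs that fool the discriminator, rather than covering the whole data distribution.</li>
	<li><strong>Stability</strong>: hyperparameter tuning is very important, as GAN training is far more sensitive to hyperparameters than the training of discriminative models.</li>
</ul>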

### Loss Functions 🏓
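Different loss functions have been proposed for GANs. Many of them can be understood as measuring a divergence between the real and generated data distributions, a central one being the Kullback-Leibler (KL) divergence. The KL divergence measures how much one probability distribution differs from another: it is the expected log-ratio of the two distributions, it is zero only when the distributions are identical, and it is not symmetric.

Some papers claim that novel loss functions improve training stability, while others (https://arxiv.org/abs/1811.09567) claim that the choice of loss function does not really matter. This is an area of research I will attempt to explore in this notebook as well.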
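
To give a feel for the KL divergence mentioned above, here is a small sketch for discrete distributions. The three distributions are made-up numbers, and `kl_divergence` is a hypothetical helper written for this illustration only.

```python
import math

def kl_divergence(p, q):
    # D_KL(P || Q) = sum_x P(x) * log(P(x) / Q(x))
    # Terms with P(x) = 0 contribute 0; Q(x) must be > 0 wherever P(x) > 0
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

real = [0.5, 0.3, 0.2]   # stand-in for the real data distribution
close = [0.4, 0.4, 0.2]  # a generator that is nearly right
far = [0.1, 0.1, 0.8]    # a generator that is badly off

print(kl_divergence(real, real))   # 0.0
print(kl_divergence(real, close))
print(kl_divergence(real, far))
print(kl_divergence(far, real))    # differs from the line above: KL is not symmetric
```

The closer the generated distribution is to the real one, the smaller the divergence, which is why a divergence of this kind is a natural thing for a generative model to minimise.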

## Developing GAN 💻
### Objectives 🖊️
We identify the tasks and objectives we want to meet, which will be used as a guide throughout the development of our GAN.
<ol>
	<li>Explore the CIFAR10 dataset</li>
	<li>Implement and evaluate models to find the best performing one</li>
	<li>Analyse the final model</li>
</ol>

### Importing Libraries
The necessary libraries are imported below.

In [None]:
import torch
import torchvision

import copy
import numpy as np
import matplotlib.pyplot as plt

import math

### Utility Functions
We define some utility functions below that will ease and help us with our analysis.

In [None]:
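def loc_data(data, loc):
	# Pull a single row out of the DataFrame and split it into label and pixels
	row = data.loc[loc]
	label = row['label']
	arr = np.array(row.drop('label'))  # copy, so the original data is untouched
	# Assumes a flattened, square, single-channel image
	root = int(len(arr) ** 0.5)
	return label, arr.reshape((root, root))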

def imshow(arr: list, label: list = None, figsize=None, shape = (32, 32, 3), is_int = None):
	# Note: `shape` is unused and kept only so that existing calls keep working
	# Guess whether pixels are in the 0-255 range (int) or the 0-1 range (float)
	if is_int is None:
		first = arr[0].detach().cpu().numpy() if isinstance(arr[0], torch.Tensor) else np.asarray(arr[0])
		is_int = bool((first > 1).any())
	if label is None:
		label = [''] * len(arr)

	# Lay the images out on a roughly square grid
	height = int(len(arr) ** 0.5)
	width = math.ceil(len(arr) / height)

	fig = plt.figure() if figsize is None else plt.figure(figsize=figsize)
	for i in range(height):
		for j in range(width):
			idx = i * width + j  # row-major index into arr
			if idx >= len(arr):
				break  # the last row may be only partially filled
			ax = fig.add_subplot(height, width, idx + 1)
			ax.grid(False)
			ax.set_xticks([])
			ax.set_yticks([])
			show = arr[idx]
			if not isinstance(show, torch.Tensor):
				show = torch.as_tensor(show)
			show = show.squeeze(0).detach().cpu().type(torch.uint8 if is_int else torch.float)
			if show.dim() == 2:
				# Single-channel image: (H, W)
				ax.imshow(show, cmap='gray')
			else:
				# Colour image: (C, H, W) -> (H, W, C) for matplotlib
				ax.imshow(show.permute(1, 2, 0))
			ax.set_title(label[idx])

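def df_to_tensor(df, shape = (28, 28)):
	# Reshape each flattened row into `shape` and stack the rows into a float tensor
	return torch.tensor(df.values.reshape((-1, *shape)), dtype=torch.float32)

def preprocess(df):
	# Scale pixel values from the 0-255 range to the 0-1 range
	return df.copy() / 255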