In [None]:
# Notebook Configuration
# Number of GPUs; its consumers are not visible in this part of the notebook.
# NOTE(review): presumably read by later training cells - confirm.
GPU_COUNT = 1
# Platform flag; its consumers are not visible in this part of the notebook.
# NOTE(review): presumably toggles Windows-specific settings further down - confirm.
WINDOWS = True

In [None]:
# Use the %pip magic rather than !pip: the magic installs into the environment
# of the running kernel, whereas !pip runs whichever pip is first on the shell PATH.
# A single invocation lets pip resolve all four packages together.
%pip install -q pandas seaborn matplotlib neptune-client

<img src="images/cifar10.PNG" width=700>

# Generative Adversarial Networks - CIFAR10
> Can we create a strong generative model for CIFAR10 using the GAN architecture?

## Background Research

### Introduction
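The ultimate goal of Deep Learning is to create a function that can <strong>effectively</strong> model any form of data distribution. History has time and time again demonstrated the impressive success of discriminative models: models that learn to partition the data distribution, i.e. map a high-dimensional input to a lower-dimensional output (Goodfellow et al., 2014). Generative models, by contrast, aim to learn the data distribution itself so that new samples can be drawn from it. Prominent families of generative models include: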

<ul>
	<li>Generative Adversarial Networks</li>
	<li>Diffusion Models</li>
	<li>Variational AutoEncoders</li>
</ul>

### The makings of GANs

### Types of GANs

### The difficulty with GANs
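<ul>
	<li>Mode collapse</li>
	<li>Stability - hyperparameter tuning is very important for GANs, in contrast to discriminative models, where hyperparameters more often than not determine only how well the model performs rather than whether it trains at all. <!-- TODO(author): the original sentence was left unfinished after "determine"; confirm this completion. --></li>
</ul>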

### Loss Functions
Different Loss functions, main one being KL Divergence.
Explain how KL Divergence works.
There are papers claiming that novel loss functions improve stability.
There are also papers (e.g. <https://arxiv.org/abs/1811.09567>) that claim that the choice of loss function does not really matter.
This is an area of research I will attempt to explore in the notebook as well.

## Developing the GAN
### Objectives
<ol>
	<li>Explore the CIFAR10 dataset</li>
	<li>Implement and evaluate to find the best performing model</li>
	<li>Analyse the final model</li>
</ol>

### Importing Libraries

### Utility Functions
We define some utility functions below to simplify our analysis.

In [None]:
def loc_data(data, loc):
	"""Return the label and a square pixel grid for one row of ``data``.

	Parameters
	----------
	data : pandas.DataFrame
		Frame containing a ``'label'`` column; every other column is treated
		as one value of the flattened image.
	loc : index label
		Row label, looked up with ``data.loc``.

	Returns
	-------
	tuple
		``(label, arr)`` where ``arr`` has shape ``(root, root)`` and
		``root = int(sqrt(number of non-label columns))``. Values that do not
		fit into the square are dropped from the end.

	Notes
	-----
	``data`` is never modified: ``np.array`` copies the selected row, so there
	is no need to deep-copy the whole frame for a single-row lookup.
	"""
	row = data.loc[loc]
	label = row['label']
	pixels = np.array(row.drop('label'))
	root = int(len(pixels) ** 0.5)
	# Slice + reshape keeps the truncating behaviour of an in-place resize
	# without its reference-count check, which can raise when the size changes.
	return label, pixels[:root * root].reshape((root, root))

def imshow(arr: list, label: list = None, figsize=None, shape = (32, 32, 3), is_int = None):
	"""Draw a sequence of images on a near-square grid of matplotlib subplots.

	Parameters
	----------
	arr : sequence of torch.Tensor or array-like
		Images to draw. Each image has a leading size-1 dimension squeezed
		away; a 2-D result is drawn with the ``'gray'`` colormap, anything
		else is permuted with ``(1, 2, 0)`` before being drawn.
	label : sequence, optional
		One title per image. Defaults to empty titles.
	figsize : tuple, optional
		Forwarded to ``plt.figure``.
	shape : tuple, optional
		Unused; kept so that existing callers passing it keep working.
	is_int : bool, optional
		When true, images are cast to ``torch.uint8`` before drawing,
		otherwise to ``torch.float``. When omitted it is inferred from the
		first image only: any value greater than 1 selects the integer path.

	Returns
	-------
	None
		The figure is created through ``plt.figure`` and not returned.
	"""
	count = len(arr)
	if count == 0:
		# Nothing to draw; also avoids indexing arr[0] and a zero-row grid.
		return

	if is_int is None:
		probe = arr[0]
		if isinstance(probe, torch.Tensor):
			probe = probe.detach().cpu().numpy()
		is_int = bool((probe > 1).sum() > 0)
	if label is None:
		label = [''] * count

	rows = int(count ** 0.5)
	cols = math.ceil(count / rows)
	dtype = torch.uint8 if is_int else torch.float

	fig = plt.figure(figsize=figsize)
	# Walk the images themselves so every image gets exactly one row-major
	# slot, even when the grid is not square or its last row is not full.
	for position, item in enumerate(arr):
		ax = fig.add_subplot(rows, cols, position + 1)
		ax.grid(False)
		ax.set_xticks([])
		ax.set_yticks([])
		show = item if isinstance(item, torch.Tensor) else torch.Tensor(item)
		show = show.squeeze(0).detach().cpu()
		if show.dim() == 2:
			ax.imshow(show.type(dtype), cmap='gray')
		else:
			# Move the leading (channel) axis last before drawing.
			ax.imshow(show.permute(1, 2, 0).type(dtype))
		ax.set_title(label[position])

def df_to_tensor(df, shape = (28, 28)):
	"""Convert a DataFrame of flattened rows into a float32 tensor.

	The frame's values are reshaped to ``(-1, *shape)``, so the leading
	dimension is inferred from the amount of data, and the result is returned
	as a ``torch.float32`` tensor.
	"""
	target_dims = (-1,) + tuple(shape)
	reshaped = df.values.reshape(target_dims)
	return torch.tensor(reshaped, dtype=torch.float32)

def preprocess(df):
	"""Return a new object holding the values of ``df`` divided by 255.

	``df`` itself is left untouched: the division is applied to a copy.
	"""
	duplicate = df.copy()
	scaled = duplicate / 255
	return scaled

def mse(t1, t2, shape=(28, 28)):
	"""Squared error between ``t1`` and ``t2``, summed over dims 2 and 3.

	The element-wise squared error (``nn.MSELoss`` with ``reduction='none'``)
	is summed over dimension 2 twice in a row, i.e. over the original
	dimensions 2 and 3, and the sum is divided by the product of ``shape``.
	The inputs therefore need at least four dimensions.

	NOTE(review): this is a true mean only when ``shape`` matches the two
	summed dimensions; the inputs look like (batch, channel, height, width)
	tensors - confirm against the callers.
	"""
	elementwise = nn.MSELoss(reduction='none')(t1, t2)
	summed = elementwise.sum(dim=2).sum(dim=2)
	divisor = np.prod(list(shape))
	return summed / divisor