This notebook was inspired by neural network & machine learning labs led by [GMUM](https://gmum.net/).

# Image classification
Your task today will be to implement a CNN for image classification.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.utils import make_grid

## CIFAR-10
The [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset consists of 60000 $32\times32$ colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

In [None]:
bs = 6

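# Normalize maps pixel values from [0, 1] to [-1, 1];
# imshow below undoes this with img / 2 + 0.5.
transform = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dataset = CIFAR10(root='.', train=True, transform=transform, download=True)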
loader = DataLoader(dataset, batch_size=bs, shuffle=True)

In [None]:
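def imshow(img, func=lambda x: x):
    """Display a (C, H, W) image tensor, optionally passing it through `func` first."""
    img = func(img)
    img = img / 2 + 0.5
    img_np = img.numpy()
    # matplotlib expects (H, W, C), PyTorch stores images as (C, H, W)
    plt.imshow(np.transpose(img_np, (1, 2, 0)))
    plt.axis('off')
    plt.show()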

images, labels = next(iter(loader))

imshow(make_grid(images))
print(' '.join(f'{dataset.classes[labels[j]]:6}' for j in range(bs)))

## Task 1 (1.25p)
Implement a convolutional neural network for multiclass classification on CIFAR-10 from scratch. You need to implement both the model and the training loop. Your code should report the loss (during both training and testing) and the accuracy on the test set (optionally also on the training set). You should reach at least 75% accuracy on the test set. You may use any functionality available in PyTorch.
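
One way to report the test loss and accuracy is a small evaluation helper. The sketch below is only an illustration: `evaluate`, `placeholder_model` and the random stand-in data are hypothetical names introduced here, and the linear model is a placeholder rather than a CNN.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradients are needed during evaluation
def evaluate(model, loader, device):
    """Return (mean loss per sample, accuracy) of `model` over `loader`."""
    model.eval()  # disable dropout, use running batch-norm statistics
    criterion = nn.CrossEntropyLoss(reduction='sum')
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        total_loss += criterion(logits, targets).item()
        # the predicted class is the index of the largest logit
        correct += (logits.argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)
    return total_loss / seen, correct / seen


# Random stand-in data and a placeholder model, so the sketch runs
# without downloading CIFAR-10.
device = torch.device('cpu')
fake_images = torch.randn(20, 3, 32, 32)
fake_labels = torch.randint(0, 10, (20,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=8)
placeholder_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

loss, acc = evaluate(placeholder_model, fake_loader, device)
print(f'loss: {loss:.3f}, accuracy: {acc:.1%}')
```

Summing the per-sample losses and dividing by the number of samples keeps the mean correct even when the last batch is smaller than the others.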

Some tips:
- Change the runtime to GPU in Google Colab. You need to use `torch.device('cuda')` to train on the GPU (don't forget to send your model and data to the device!).
- Your model should inherit from [`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
- To keep the feedback loop short, don't train for more than a couple of epochs while you are choosing the architecture or hyperparameters.
- Plot the loss curves to see when the loss begins to flatten and whether the model is overfitting (compare the test loss with the training loss).
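
The tips about the device, `nn.Module` and loss tracking can be sketched roughly as follows. This is a minimal illustration under stated assumptions: `TinyNet` is a hypothetical placeholder that is far too small to reach the target accuracy, the batch is random stand-in data, and the learning rate is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TinyNet(nn.Module):
    """Hypothetical placeholder model, not a suggested architecture."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 16 * 16, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv(x)), 2)  # (N, 8, 16, 16)
        return self.fc(torch.flatten(x, 1))        # (N, num_classes) logits


model = TinyNet().to(device)  # move the parameters to the device
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Random stand-in batch; real batches would come from a DataLoader
# and must be moved to the same device as the model.
inputs = torch.randn(4, 3, 32, 32).to(device)
targets = torch.randint(0, 10, (4,)).to(device)

train_losses = []  # one value per step, kept for plotting later
for step in range(3):
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())

print(train_losses)
```

A list such as `train_losses` can be passed directly to `plt.plot`; keeping a second list with the test loss per epoch makes it possible to draw both curves on the same axes.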

In [None]:
# implement your model here

In [None]:
# load in the data here

train_dataset = CIFAR10(root='.', 
                        train=True,
                        download=True,
                        transform=???)

test_dataset = CIFAR10(root='.', 
                       train=False,
                       download=True,
                       transform=???)


train_loader = DataLoader(train_dataset, batch_size=???, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=5000, shuffle=False, num_workers=2)

In [None]:
# implement your training loop with all the needed hyperparameters here

## Data augmentation
Data augmentation is a technique to synthetically increase the amount of data by adding slightly modified copies of already existing data. The `torchvision` package contains many transformations useful for this task.

In [None]:
from torchvision.transforms import RandomVerticalFlip

flip = RandomVerticalFlip(p=1) 

imshow(make_grid(images))
imshow(make_grid(images), flip)
print(' '.join(f'{dataset.classes[labels[j]]:6}' for j in range(bs)))

## Task 2 (0.25p)
Using the functionality from [`torchvision.transforms`](https://pytorch.org/docs/stable/torchvision/transforms.html), add augmentations to the training set and see how much they improve the accuracy of your model from the previous task. You can search the internet for typical augmentations for CIFAR-10.

In [None]:
# load in the data with the augmentations here

train_dataset = CIFAR10(root='.', 
                        train=True,
                        download=True,
                        transform=???)

train_loader = DataLoader(train_dataset, batch_size=???, shuffle=True, num_workers=2)

In [None]:
# rerun the training loop from the previous task here