# Exercise: MNIST
by Tobias Jülg

Disclaimer: This exercise is partly based on PyTorch's [Quickstart Tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html).

## Install PyTorch and PyTorch Lightning
Have a look at [PyTorch's installation page](https://pytorch.org/get-started/locally/). If the CPU version is sufficient for you, running the following cell should be enough. Note that PyTorch supports only specific Python versions; check the installation page for the versions that are currently supported.

In [None]:
# uncomment to install
#!pip install numpy
#!pip install torch
#!pip install torchvision
#!pip install pytorch_lightning

## Exercise 1: Plain PyTorch
### Imports

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

### Data
In this exercise we use the MNIST dataset, which consists of 28x28 pictures of handwritten digits. The task is to classify the pictures into the digits 0 to 9. PyTorch already ships a `Dataset` implementation for MNIST, so we don't have to write one ourselves. Conveniently, this class can also download the data for us. The following code does exactly that.

1. Look up what **transforms** are in the [PyTorch documentation](https://pytorch.org/vision/stable/transforms.html) and, in particular, check what the `ToTensor()` transform does. Do we have to normalize our data before we put it into the network?
2. Create two dataloaders for our train and validation datasets (TODO2)
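
As a rough illustration of what a `DataLoader` yields, the sketch below wraps random tensors instead of MNIST, so it runs without a download. The names `toy_data` and `toy_loader` are hypothetical stand-ins, and the batch size of 64 is only an assumption taken from the hyperparameters used later in this exercise. The tensor layout mimics what `ToTensor()` produces for a grayscale picture: a float tensor of shape (channels, height, width) with values in [0, 1].

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 256 random "images" with values in [0, 1] and
# shape (1, 28, 28), plus random integer class labels from 0 to 9.
images = torch.rand(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
toy_data = TensorDataset(images, labels)

# The DataLoader groups samples into batches and can shuffle them each epoch.
toy_loader = DataLoader(toy_data, batch_size=64, shuffle=True)

# Fetch a single batch and inspect its shapes and dtypes.
X, y = next(iter(toy_loader))
print(X.shape, X.dtype)  # batch of images: (batch, channel, height, width)
print(y.shape, y.dtype)  # batch of labels: one integer per image
```

The same inspection works on a loader built from `training_data`, which is one way to approach the shape question further below.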


In [None]:
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

val_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = None  # TODO 2: create a DataLoader for training_data
val_dataloader = None  # TODO 2: create a DataLoader for val_data

### Network
3. Check out the shapes of our training data. What must the input of your network look like to accept such data?

In [None]:
# TODO 3: check out shapes

4. Complete the network's code so that it implements the following topology with the [`nn.Sequential` module](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) (TODO 4):

$$W_3 \,\mathrm{ReLU}(W_2 \,\mathrm{ReLU}(W_1 x + b_1) + b_2) + b_3$$

where $W_1\in\mathbb{R}^{512\times(28\cdot28)}$, $b_1\in\mathbb{R}^{512}$,
$W_2\in\mathbb{R}^{512\times512}$, $b_2\in\mathbb{R}^{512}$,
$W_3\in\mathbb{R}^{10\times512}$, $b_3\in\mathbb{R}^{10}$.

What does the [`nn.Flatten()` layer](https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html) do?
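
To get a feeling for `nn.Flatten()` and `nn.Sequential`, here is a minimal sketch on a dummy batch. The layer sizes in `tiny_stack` are deliberately hypothetical and much smaller than the topology asked for above; they only illustrate how layers are chained.

```python
import torch
from torch import nn

# A dummy batch shaped like MNIST input: (batch, channel, height, width).
batch = torch.zeros(64, 1, 28, 28)

# nn.Flatten() keeps the batch dimension and merges all remaining dimensions.
flat = nn.Flatten()(batch)
print(flat.shape)

# A tiny stack with hypothetical sizes (NOT the topology of the exercise).
tiny_stack = nn.Sequential(
    nn.Linear(28 * 28, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)
logits = tiny_stack(flat)

# Count the trainable parameters: weights and biases of both linear layers.
n_params = sum(p.numel() for p in tiny_stack.parameters())
print(logits.shape, n_params)
```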

In [None]:
class ClassicNN(nn.Module):
    def __init__(self):
        super(ClassicNN, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            # TODO 4: add layers as described above
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

We now want to create a second network which uses convolutions instead of fully connected layers. Convolutions can be seen as learnable filters. They are much better at detecting local patterns, because the weight kernel "slides" over the image and reuses the same weights at every position. This weight sharing also results in fewer trainable parameters. Both properties make them well suited as feature extractors for images. [This article](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53) will give you a good overview of how convolutions work if you are new to the topic.

5. Look up how [`nn.Conv2d` layers](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) work. As the handwritten digits are grayscale, they have only one color channel, as you should know from question 3. So you should start with a channel dimension of 1, go up to 16 in the first convolution, and then to 32 in the second. Use a kernel size of 2, a stride of 2 and a padding of 1 for both convolutions. Use the `nn.ReLU` activation function in between the layers. Finally, use an `nn.Flatten` layer and a fully connected (linear) layer to get the output down to our 10 output neurons.

The following formula might come in handy to calculate the input size of the fully connected layer:

$$H_{out} = \left\lfloor\frac{H_{in} + 2\cdot\text{padding} - (\text{kernel}-1) - 1}{\text{stride}} + 1\right\rfloor$$

where $H_{in}$ is the spatial size (the height, and likewise the width) of the layer's input and $H_{out}$ is the spatial size of the layer's output. The number of channels is set separately by the layer's `out_channels` argument. What is the output size of the second convolution?
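
The formula can be written as a small helper to check hand calculations. This is a sketch under the assumption of a dilation of 1 (as in the formula above); the function name `conv_output_size` and the example layer settings are hypothetical and intentionally differ from the values in the exercise, so you can plug those in yourself.

```python
def conv_output_size(size_in, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one axis (dilation 1)."""
    # Integer floor division implements the floor in the formula.
    return (size_in + 2 * padding - (kernel - 1) - 1) // stride + 1


# Example settings that are NOT the ones used in the exercise:
same_size = conv_output_size(28, kernel=3, stride=1, padding=1)
halved = conv_output_size(28, kernel=3, stride=2, padding=1)
shrunk = conv_output_size(32, kernel=5)

print(same_size, halved, shrunk)
```

Because two convolutions are stacked in the exercise, the result of the first call becomes `size_in` of the second. The input size of the linear layer is then the number of output channels times the output height times the output width.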

In [None]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.cnn_relu_stack = nn.Sequential(
            # TODO 5: add layers as described above
        )

    def forward(self, x):
        logits = self.cnn_relu_stack(x)
        return logits

6. Which loss function should we use for this kind of problem?
7. Suppose we want to use `nn.CrossEntropyLoss()` as the loss. Do we need to add a softmax layer at the end? Look up the documentation of the loss function for your answer.
8. The class label comes in the form of a single number in {0, ..., 9}. Look up what a one-hot vector is. Given that we want to use `nn.CrossEntropyLoss()` as the loss, do we need to convert our ground truth labels to one-hot vectors? A: No, `nn.CrossEntropyLoss()` accepts class indices directly and does not need one-hot vectors.
9. Complete the loss function in the code snippet below (TODO 9).
10. Complete the `train_loop` function by adding the forward pass, the loss calculation and the optimizer code (TODO 10).
11. Why is the `val_loop` function inefficient? Add the appropriate code to make it more efficient (TODO 11).
12. Test your code. You might also want to switch between the `CNN` and `ClassicNN` models from above.
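
The following sketch, using made-up logits for a three-class toy problem, may help when checking the documentation for questions 7 and 8. It compares `nn.CrossEntropyLoss()` applied to raw logits and integer labels against two manual computations. The values and variable names are illustrative assumptions only.

```python
import torch
from torch import nn
import torch.nn.functional as F

# Made-up raw network outputs (logits) for 2 samples and 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
# Ground truth as plain class indices.
targets = torch.tensor([0, 2])

# Built-in loss on raw logits and integer labels.
loss = nn.CrossEntropyLoss()(logits, targets)

# Manual variant 1: log-softmax followed by negative log-likelihood.
log_probs = F.log_softmax(logits, dim=1)
manual = F.nll_loss(log_probs, targets)

# Manual variant 2: explicit one-hot vectors, averaged over the batch.
one_hot = F.one_hot(targets, num_classes=3).float()
manual_one_hot = -(one_hot * log_probs).sum(dim=1).mean()

print(loss.item(), manual.item(), manual_one_hot.item())
```

All three numbers agree up to floating point error, which shows what the built-in loss already does internally.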

In [None]:
model = CNN()
learning_rate = 1e-3
batch_size = 64
epochs = 5
loss_fn = None  # TODO 9: add the loss function


optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # TODO 10: Compute prediction and loss

        # TODO 10: Backpropagation and optimizer step

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def val_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    val_loss, correct = 0, 0

    # TODO 11: why is this code inefficient, what is missing here?
    for X, y in dataloader:
        pred = model(X)
        val_loss += loss_fn(pred, y).item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    val_loss /= num_batches
    correct /= size
    print(f"Val Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")
    
    
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    val_loop(val_dataloader, model, loss_fn)
print("Done!")
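
For reference, the sketch below shows one common pattern for a training step and for gradient-free evaluation on a toy linear classifier with random data. It is not the author's reference solution; the names `toy_model`, `toy_loss_fn` and `toy_optimizer`, the data, and the learning rate are all assumptions chosen so that the example runs in a fraction of a second.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy setup: 32 random samples with 4 features and 3 classes.
toy_model = nn.Linear(4, 3)
toy_loss_fn = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
X = torch.randn(32, 4)
y = torch.randint(0, 3, (32,))

losses = []
for _ in range(20):
    pred = toy_model(X)           # forward pass
    loss = toy_loss_fn(pred, y)   # loss calculation
    toy_optimizer.zero_grad()     # clear gradients of the previous step
    loss.backward()               # backpropagation
    toy_optimizer.step()          # parameter update
    losses.append(loss.item())

# Evaluation: switch to eval mode and skip building the autograd graph.
toy_model.eval()
with torch.no_grad():
    eval_pred = toy_model(X)
    eval_loss = toy_loss_fn(eval_pred, y).item()

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}, eval loss {eval_loss:.4f}")
```

Inside `torch.no_grad()` no computation graph is recorded, which saves memory and time whenever `backward()` is never called on the result.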

## Exercise 2: PyTorch Lightning
### Imports

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

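1. Complete the missing validation step: after the forward pass and the loss calculation, log the loss and the accuracy. You might need the `validation_epoch_end` method to calculate the validation accuracy after the epoch. Note that Lightning 2.0 and later replaced this hook with `on_validation_epoch_end`, which no longer receives the step outputs as an argument, so check the documentation of your installed version.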
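
Independent of the Lightning version, the epoch accuracy boils down to summing correct predictions and sample counts over all batches. The sketch below shows this aggregation in plain PyTorch with made-up logits. The helper `batch_stats` and the dictionary keys are hypothetical; in the module above, something similar could be returned from `validation_step` and combined in the epoch-end hook.

```python
import torch


def batch_stats(logits, targets):
    """Count correct predictions and samples for one batch (hypothetical helper)."""
    correct = (logits.argmax(dim=1) == targets).sum().item()
    return {"correct": correct, "total": targets.numel()}


# Two made-up validation batches with two samples and two classes each.
outputs = [
    batch_stats(torch.tensor([[2.0, 0.1], [0.3, 0.9]]), torch.tensor([0, 1])),
    batch_stats(torch.tensor([[0.2, 0.7], [1.5, 0.4]]), torch.tensor([0, 0])),
]

# Aggregate over the epoch: total correct divided by total samples.
accuracy = sum(o["correct"] for o in outputs) / sum(o["total"] for o in outputs)
print(accuracy)
```

Summing counts rather than averaging per-batch accuracies keeps the result exact even when the last batch is smaller than the others.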

In [None]:
class PLModule(pl.LightningModule):
    def __init__(self, model, data, hparams):
        super().__init__()
        self.model = model
        self.crit = nn.CrossEntropyLoss()
        self.hparams.update(hparams)
        self.data = data
        
    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        y_hat = self.model(x)
        loss = self.crit(y_hat, y)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        # TODO 1: forward pass, loss calculation and loss logging
    
    def validation_epoch_end(self, validation_step_outputs):
        # TODO 1: calculate accuracy from validation_step outputs
        pass
        
    def train_dataloader(self):
        return DataLoader(self.data[0], batch_size=self.hparams["batch_size"])

    def val_dataloader(self):
        return DataLoader(self.data[1], batch_size=self.hparams["batch_size"])


    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(),
                                self.hparams["learning_rate"])



2. Create the `hparams` dictionary containing our `learning_rate` (0.001), the `batch_size` (64) and the number of `epochs` that you want to train. This dictionary is passed to the PyTorch Lightning module.

In [None]:
hparams = None  # TODO 2: replace None with the hparams dict

In [None]:
model = ClassicNN()
pl_module = PLModule(model, (training_data, val_data), hparams)
trainer = pl.Trainer(
    #callbacks=callbacks,
    max_epochs=hparams["epochs"],
    deterministic=True,
    #gpus=[0],
    #profiler="simple",
)
trainer.fit(pl_module)

In [None]:
# Visualize the logged data in tensorboard:
%load_ext tensorboard
%tensorboard --logdir lightning_logs

Possible further steps to go down the rabbit hole:
- Train more epochs
- Play around with the two different models. Extend the models with your own ideas.
- Add early stopping and model checkpoints
- If you have a GPU, try out GPU training
- Look up the [BatchNorm2d](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html) layer and add it to your networks. How can batch normalization help to make training more stable?
- Research how we can use [TorchMetrics](https://torchmetrics.readthedocs.io/en/stable/pages/lightning.html) to simplify metric logging
- Replace SGD with the Adam optimizer and add weight decay to combat overfitting
- Profile the training to find out how much faster GPU training is. What takes the most time in your training?
- Load the data with `num_workers=6` to fully utilize your CPU
- Visualize the wrongly classified images. Why do you think they are wrongly classified?
- Use a different dataset, e.g. FashionMNIST. Do you achieve a better accuracy? Why could there be a difference?
- Look up transfer learning. How does PyTorch support transfer learning? Use a pretrained VGG net
- Look up ResNet. Explain the concept of skip connections and why they are useful when training very deep neural networks. Implement a skip connection layer yourself.
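
As a starting point for the last two layer-related ideas, here is a minimal sketch of a block that combines `nn.BatchNorm2d` with a skip connection. It is a simplified illustration, not the exact block from the ResNet paper; the class name `ResidualBlock`, the kernel size of 3 and the channel count are assumptions.

```python
import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Simplified residual block: output = ReLU(x + F(x))."""

    def __init__(self, channels):
        super().__init__()
        # kernel_size=3 with padding=1 keeps height and width unchanged,
        # so the input can be added to the output of the body.
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        # The skip connection adds the unchanged input to the body's output.
        return self.relu(x + self.body(x))


block = ResidualBlock(16)
out = block(torch.randn(8, 16, 14, 14))
print(out.shape)
```

Because the input is added back unchanged, gradients have a direct path around the convolutions, which is the property to look into when reading about very deep networks.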