# CIFAR10: Training a classifier with **PyTorch-lightning**

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lento234/ml-tutorials/blob/main/01-basics/03-CIFAR10_pl.ipynb)

**References**:
- https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
- https://pytorch-lightning.rtfd.io/en/latest/


**Runtime setup: GPU accelerator on Google Colab:**

1. On the main menu, click **Runtime** and select **Change runtime type**. 
2. Select **GPU** as the hardware accelerator.

![steps](../images/steps.png)

In [None]:
!nvidia-smi

**Table of contents**

1. [Load and pre-process the dataset](#load)
2. [Define the CNN model **+ training step + loss + optimizer**](#define)
3. [Setup the **trainer**](#trainer)
4. [Train **and validate** the model on **train** and **test** dataset](#train)
5. [Assess training with **tensorboard**](#tensorboard)
6. [Test the model](#validate)

**CIFAR10 Dataset**

The dataset consists of `3x32x32` images of 10 different classes:

    airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.

![cifar10](../images/cifar10.png)

## Setup

Lightning is easy to install: simply run `pip install pytorch-lightning`.

In [None]:
!pip install pytorch-lightning --quiet

### Environment

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets

import pytorch_lightning as pl

In [None]:
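# 'seaborn-poster' was renamed 'seaborn-v0_8-poster' in newer matplotlib releases; use whichever exists
poster_style = 'seaborn-v0_8-poster' if 'seaborn-v0_8-poster' in plt.style.available else 'seaborn-poster'
mpl.style.use(poster_style)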
mpl.rcParams['mathtext.fontset'] = 'cm'
mpl.rcParams['figure.figsize'] = 5 * np.array([1.618033988749895, 1])

In [None]:
pl.seed_everything(234)

### Hyper-parameters

In [None]:
batch_size = 32
num_workers = 4
num_epochs = 5
learning_rate = 0.001
momentum = 0.9

<a id='load'></a>
## 1. Load and pre-process data

- Define preprocessing algorithm
- Load training and test dataset

### 1.1 Define preprocessing algorithm

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(), # convert data to pytorch tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize dataset for each channel
])
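
To make the effect of `Normalize` concrete, here is a minimal sketch in plain Python. It assumes that `ToTensor` has already scaled the pixel values to `[0, 1]`; `normalize` and `unnormalize` are illustrative helper names, not torchvision functions.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel normalization: (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y, mean=0.5, std=0.5):
    """Inverse mapping: y * std + mean (same as y / 2 + 0.5 for std = mean = 0.5)."""
    return y * std + mean


# With mean = std = 0.5 the range [0, 1] is mapped onto [-1, 1]
for pixel in (0.0, 0.25, 0.5, 1.0):
    print(f"{pixel:.2f} -> {normalize(pixel):+.2f}")
```

The inverse mapping is what the `imshow` helper further down applies (`img / 2 + 0.5`) before plotting.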

### 1.2 Load training and test dataset

In [None]:
# Download train and test dataset
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, 
                                 download=True, transform=transform)

# Dataset sampler (shuffle, distributed loading)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 
                                           shuffle=True, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, 
                                          shuffle=False, num_workers=num_workers)

print(f"num. examples: train = {len(train_dataset)}, test = {len(test_dataset)}")

In [None]:
classes = np.array(['plane', 'car', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck'])

num_classes = len(classes)

In [None]:
def imshow(images, labels):    
    plt.figure(figsize=(10,10))
    for i in range(16):
        plt.subplot(4, 4,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        img = images[i] / 2 + 0.5 # unnormalize
        plt.imshow(np.transpose(img.numpy(), (1, 2, 0)), cmap=plt.cm.binary)
        plt.xlabel(classes[labels[i]])
    plt.show()
    
# get some random training images
images, labels = next(iter(train_loader))

# show images
imshow(images, labels)

<a id=define></a>
## 2. Define the CNN model **+ training step + loss + optimizer**

![network_architecture](../images/network_architecture.png)

**Architecture:**

- Input: An image with `num_channels=3` channels and `32x32` pixels.
- Two stacks, each consisting of a 2D convolutional layer (`Conv2d` with `kernel_size=5`) with rectified linear activation (`ReLU`), followed by 2D max pooling (`MaxPool2d` with `kernel_size=2` and `stride=2`).
- Three fully-connected layers (`Linear`); the first two are followed by ReLU activation.
- Output: 10-dimensional vector of raw scores (logits), one per class.
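
The input size `16 * 5 * 5` of the first `Linear` layer in the model below follows from the spatial size of the feature maps after the two convolution and pooling stacks. The sketch below traces that size using the standard output-size formula for convolution and pooling without dilation; `conv_out` is a hypothetical helper introduced only for this illustration.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 32                                        # CIFAR10 images are 32x32
size = conv_out(size, kernel_size=5)             # Conv2d(3, 6, 5)   -> 28
size = conv_out(size, kernel_size=2, stride=2)   # MaxPool2d(2, 2)   -> 14
size = conv_out(size, kernel_size=5)             # Conv2d(6, 16, 5)  -> 10
size = conv_out(size, kernel_size=2, stride=2)   # MaxPool2d(2, 2)   -> 5

flat_features = 16 * size * size                 # 16 channels of 5x5 -> 400
print(size, flat_features)
```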

In [None]:
class Net(pl.LightningModule):
    def __init__(self, **kwargs):
        super(Net, self).__init__()
        
        # save hyper-parameters
        self.save_hyperparameters()
        
        self.example_input_array = torch.ones(1, self.hparams.num_channels, 32, 32)
        
        # Define network
        self.layer1 = nn.Sequential(nn.Conv2d(self.hparams.num_channels, 6, kernel_size=5),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 =  nn.Sequential(nn.Conv2d(6, 16, kernel_size=5),
                                     nn.ReLU(),
                                     nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(nn.Flatten(),
                                    nn.Linear(16 * 5 * 5, 120),
                                    nn.ReLU())
        self.layer4 = nn.Sequential(nn.Linear(120, 84),
                                    nn.ReLU())
        self.layer5 = nn.Linear(84, self.hparams.num_classes)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x_train, y_train = batch
        y_pred = self(x_train)
        loss = F.cross_entropy(y_pred, y_train)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) # logging
        return loss
    
    def validation_step(self, batch, batch_idx):
        x_test, y_test = batch
        y_pred = self(x_test)
        loss = F.cross_entropy(y_pred, y_test)
        self.log('val_loss', loss)
        
    def test_step(self, batch, batch_idx):
        x_test, y_test = batch
        y_pred = self(x_test)
        loss = F.cross_entropy(y_pred, y_test)
        self.log('test_loss', loss)
    
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(),
                               lr=self.hparams.learning_rate,
                               momentum=self.hparams.momentum)
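
All three `*_step` methods above use `F.cross_entropy`, which takes the raw logits and integer class labels, applies a log-softmax internally and averages the negative log-likelihood over the batch. As a rough illustration for a single sample, written in plain Python rather than with tensors (the function below is a sketch, not the PyTorch implementation):

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]


# A model with no preference among the 10 classes gives loss = ln(10)
uniform_loss = cross_entropy([0.0] * 10, target=3)

# A confident, correct prediction gives a loss close to zero
confident_loss = cross_entropy([0.0] * 9 + [10.0], target=9)

print(f"uniform: {uniform_loss:.4f}, confident: {confident_loss:.6f}")
```

A useful sanity check follows from this: with 10 classes, a freshly initialized network should report a `train_loss` near `ln(10) ≈ 2.30`.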

In [None]:
# Construct model
model = Net(
    num_channels=3,
    num_classes=num_classes,
    learning_rate=learning_rate,
    momentum=momentum
)
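
To get a feeling for the size of this model, the number of trainable parameters can be worked out by hand. The sketch below assumes the layer sizes defined in `Net` above with `num_channels=3` and `num_classes=10`; the helper functions are hypothetical and only encode the usual "weights + biases" count.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights (out x in x k x k) plus one bias per output channel."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels


def linear_params(in_features, out_features):
    """Weights (out x in) plus one bias per output feature."""
    return in_features * out_features + out_features


layer_params = {
    "layer1: Conv2d(3, 6, 5)": conv2d_params(3, 6, 5),
    "layer2: Conv2d(6, 16, 5)": conv2d_params(6, 16, 5),
    "layer3: Linear(400, 120)": linear_params(16 * 5 * 5, 120),
    "layer4: Linear(120, 84)": linear_params(120, 84),
    "layer5: Linear(84, 10)": linear_params(84, 10),
}

for name, count in layer_params.items():
    print(f"{name:26s} {count:6d}")

total_params = sum(layer_params.values())
print(f"{'total':26s} {total_params:6d}")
```

Most of the parameters sit in the first fully-connected layer; the two convolutional layers together account for fewer than 3000.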

<a id=trainer></a>
## 3. Setup the **trainer**

In [None]:
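# GPU trainer
from pytorch_lightning.callbacks import TQDMProgressBar

trainer = pl.Trainer(
    accelerator='gpu',  # older Lightning releases used `gpus=1` instead
    devices=1,
    max_epochs=num_epochs,
    callbacks=[TQDMProgressBar(refresh_rate=50)],  # replaces the older `progress_bar_refresh_rate=50`
)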

**Additional flags:**
```python
    log_gpu_memory='all',  # gpu stats
    profiler=True,  # profiling stats
    precision=16,  # half-precision
    deterministic=True,  # reproducibility
    accelerator='ddp',  # distributed data parallelism
    benchmark=True,  # cudnn benchmark and optimizing
    callbacks=[custom_callback_one(), custom_callback_two()],
    fast_dev_run=True,  # dev run for debugging all the hooks
```
More info: https://pytorch-lightning.readthedocs.io/en/stable/trainer.html

<a id=train></a>
## 4. Train **and validate** the model on **train** and **test** dataset

In [None]:
trainer.fit(model, train_loader, test_loader)
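
Note that the second loader passed to `trainer.fit` is used for validation, so here the test set doubles as the validation set. To estimate how many optimizer steps this call performs, the sketch below assumes the standard CIFAR10 split of 50,000 training and 10,000 test images and the hyper-parameters set earlier (`batch_size = 32`, `num_epochs = 5`); `steps_per_epoch` is an illustrative helper.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches per epoch when the last, smaller batch is kept."""
    return math.ceil(num_examples / batch_size)


train_steps = steps_per_epoch(50_000, 32)   # training batches per epoch
val_steps = steps_per_epoch(10_000, 32)     # validation batches per epoch
total_train_steps = train_steps * 5         # optimizer steps over 5 epochs

print(train_steps, val_steps, total_train_steps)
```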

<a id=tensorboard></a>
## 5. Assess training with **tensorboard**

In [None]:
# Start tensorboard
%reload_ext tensorboard
%tensorboard --logdir lightning_logs

<a id=validate></a>
## 6. Test the model on **test** dataset

In [None]:
trainer.test(model, test_loader)

<a id=bonus></a>
## 7. **Bonus**: Exercises

### 7.1. Add logging for training / validation `accuracy`


See previous example (**section 5.2**)

$$ \mathrm{Accuracy} = \frac{\sum \mathrm{True\ positive} + \sum\mathrm{True\ negative}}{\sum \mathrm{Total\ samples}} $$

**Hint:**
```python
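    def __init__(self):
        ...
        # `pl.metrics` exists only in older Lightning releases;
        # newer ones provide the same metric as `torchmetrics.Accuracy`
        self.accuracy = pl.metrics.Accuracy()

    def validation_step(...):
        ...
        self.log("val_acc", self.accuracy(y_hat, y))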
```
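
For multi-class classification, the formula reduces to the fraction of samples whose highest-scoring class matches the label. A minimal sketch in plain Python, using made-up logits for a 3-class problem (the helper names are hypothetical):

```python
def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(logits_batch, targets):
    """Fraction of samples whose predicted class equals the target."""
    predictions = [argmax(logits) for logits in logits_batch]
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


logits_batch = [
    [2.0, 0.1, -1.0],   # predicts class 0
    [0.3, 0.2, 1.5],    # predicts class 2
    [0.0, 3.0, 0.5],    # predicts class 1
    [1.0, 0.9, 0.8],    # predicts class 0
]
targets = [0, 2, 0, 0]

acc = accuracy(logits_batch, targets)
print(acc)  # 3 of 4 predictions are correct -> 0.75
```

With tensors, the same quantity can be computed inside a step method as `(y_pred.argmax(dim=1) == y_test).float().mean()`.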

**Reference:** 
- https://en.wikipedia.org/wiki/Accuracy_and_precision

### 7.2. Add learning rate scheduler

**Hint:**
```python
    from torch.optim import Adam
    from torch.optim.lr_scheduler import LambdaLR

    # Adam + LR scheduler
    def configure_optimizers(self):
        optimizer = Adam(...)
        scheduler = LambdaLR(optimizer, ...)
        return [optimizer], [scheduler]
```
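
`LambdaLR` sets the learning rate to the optimizer's initial learning rate multiplied by `lr_lambda(epoch)`. The sketch below reproduces that rule in plain Python to preview a schedule before training; the halving rule is an arbitrary example, not a recommendation, and `lambda_lr_schedule` is a hypothetical helper.

```python
def lambda_lr_schedule(base_lr, lr_lambda, num_epochs):
    """Learning rate per epoch under the rule lr = base_lr * lr_lambda(epoch)."""
    return [base_lr * lr_lambda(epoch) for epoch in range(num_epochs)]


def halve_each_epoch(epoch):
    """Example multiplier: 1, 1/2, 1/4, ..."""
    return 0.5 ** epoch


schedule = lambda_lr_schedule(base_lr=0.001, lr_lambda=halve_each_epoch, num_epochs=5)
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.7f}")
```

The same function `halve_each_epoch` could then be passed as the `lr_lambda` argument of `LambdaLR`.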

**Reference:**
- https://pytorch-lightning.readthedocs.io/en/stable/optimizers.html?highlight=scheduler#learning-rate-scheduling
- https://pytorch.org/docs/stable/optim.html

### 7.3. Add early stopping

**Hint:**
```python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping
    early_stopping = EarlyStopping('val_loss')
    trainer = Trainer(callbacks=[early_stopping])
```
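
The idea behind early stopping is to end training once the monitored value has not improved for a given number of validation checks. The following is a simplified illustration of that logic in plain Python, not Lightning's implementation; it assumes a `patience` of 3 and a `min_delta` of 0, and the loss values are made up.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    wait = 0  # validation checks since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_losses = [1.9, 1.6, 1.5, 1.52, 1.51, 1.55, 1.4]

stop_epoch = early_stop_epoch(val_losses)               # no improvement in epochs 3, 4, 5
never_stops = early_stop_epoch(val_losses, patience=5)  # epoch 6 improves before patience runs out
print(stop_epoch, never_stops)
```

The example also shows the trade-off: with a patience of 3, training stops before the improvement at the last epoch is ever seen.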

**References:**
- https://pytorch-lightning.readthedocs.io/en/stable/generated/pytorch_lightning.callbacks.EarlyStopping.html#pytorch_lightning.callbacks.EarlyStopping

### 7.4. Log figures into tensorboard

**Hint:**
```python
    def validation_epoch_end(...):
        ...
        self.logger.experiment.add_figure(
            'val_acc', fig, self.current_epoch)
```

**References:**
- https://pytorch.org/docs/stable/tensorboard.html
- https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.loggers.tensorboard.html?highlight=tensorboard