# Module 6 - Fine-tuning ResNet toward plankton data

We have seen that a neural network that was trained on a completely plankton-unrelated dataset (like ImageNet) still produces features that allow the classification of plankton data.
Now, we can go a step further and *fine-tune* such a network to do plankton classification.
This is akin to teaching a person without prior oceanographic experience how to recognize different types of fish, assuming that they are able to recognize other kinds of objects.

In practice, CNNs are almost always fine-tuned from pre-trained weights rather than trained from scratch, because training converges faster and more reliably when it starts from useful features.

In [None]:
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.metrics import classification_report, confusion_matrix
from torch.optim import lr_scheduler
from torch.utils.data import RandomSampler
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tnrange, tqdm_notebook

from utilities.display_utils import imshow_tensor
from utilities.split import stratified_random_split

DATASET_PATH = "/data1/mschroeder/Datasets/19-05-11 ZooScanNet/ZooScanSet/imgs"

## Data loading and transformation

Image datasets can conveniently be loaded with [`torchvision.datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder).
It assumes one folder for each class where the images are located.

CNNs have a fixed input size. ResNets happen to be trained with 224x224 images.
Therefore, we need to make sure that each image has the correct dimensions.
`ImageFolder` has a `transform` parameter for that.
After resizing, the images need to be converted to a PyTorch [`Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor).

In [None]:
transform = Compose([
    # Resize every image to a 224x224 square
    Resize((224,224)),
    # Convert to a tensor that PyTorch can work with
    ToTensor()
])

# Images are located at {DATASET_PATH}/{class_name}/{objid}.jpg
dataset = ImageFolder(DATASET_PATH, transform)

Now let's look at the first example.

In [None]:
# Extract the tensor and the label of the first example
tensor, label = dataset[0]

print("Class: {:d} ({})".format(label, dataset.classes[label]))
imshow_tensor(tensor)

## Training / validation sets
We hold out 20% of the examples as a validation set, which lets us check how well the model performs on data it was not trained on.

In [None]:
dataset_train, dataset_val = stratified_random_split(dataset, test_size=0.2)
print("{:,d} training examples.".format(len(dataset_train)))
print("{:,d} validation examples.".format(len(dataset_val)))
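
`stratified_random_split` is a helper from this course's `utilities` package, and its implementation is not shown here. As an assumption about what "stratified" means in this context, the following sketch splits a list of labels so that each class keeps roughly the same proportion in both parts. The function name `stratified_indices` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_indices(labels, test_size=0.2, seed=0):
    """Return (train_idx, val_idx) with per-class proportions preserved."""
    rng = random.Random(seed)

    # Group example indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label, indices in by_class.items():
        rng.shuffle(indices)
        # Hold out the same fraction of every class
        n_val = round(len(indices) * test_size)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy labels: 50 examples of class 0, 10 examples of class 1
toy_labels = [0] * 50 + [1] * 10
train_idx, val_idx = stratified_indices(toy_labels, test_size=0.2)

print(len(train_idx), len(val_idx))
print(sum(toy_labels[i] for i in val_idx), "validation examples of class 1")
```

A purely random split could leave a rare class with very few or no validation examples; splitting per class avoids that.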

## Preparing the model

We start with a pre-trained ResNet18 model.
It was initially trained on ImageNet which happens to contain 1000 classes. However, our plankton dataset contains XXX classes. Therefore, we have to reset the classifier layer to the correct number of classes.

In [None]:
model = models.resnet18(pretrained=True)

# get the number of features that are input to the fully connected layer
num_ftrs = model.fc.in_features

# reset the fully connected layer
model.fc = nn.Linear(num_ftrs, len(dataset.classes))

# Transfer model to GPU
model = model.cuda()

## Preparing the optimizer

We will train the network using [Stochastic Gradient Descent (SGD)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent).
In each iteration, the network parameters are updated in order to minimize a training criterion, in our case the [Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy) Loss.
The better the predictions, the smaller the loss.
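
As a worked example of that last statement, the cross-entropy loss for a single example is the negative log of the softmax probability assigned to the true class. The logits below are made-up numbers for a three-class problem.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy loss of one example: -log(softmax(logits)[target])."""
    # Subtract the maximum for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    prob_target = exps[target] / sum(exps)
    return -math.log(prob_target)


target = 0
confident_and_right = cross_entropy([4.0, 0.0, 0.0], target)
undecided = cross_entropy([0.0, 0.0, 0.0], target)
confident_and_wrong = cross_entropy([0.0, 4.0, 0.0], target)

print("{:.3f} {:.3f} {:.3f}".format(confident_and_right, undecided, confident_and_wrong))
```

`nn.CrossEntropyLoss` computes the same quantity directly from the raw network outputs and, by default, averages it over the batch.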

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
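
To illustrate what a single optimizer step does, the sketch below applies the momentum update to a one-parameter toy problem in plain Python. It follows the rule described in the PyTorch documentation for `optim.SGD` without dampening or weight decay (`v = momentum * v + grad`, then `w = w - lr * v`). The quadratic objective, the learning rate and the step count are made up.

```python
def grad(w):
    """Gradient of the toy objective f(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)


lr, momentum = 0.1, 0.9
w, velocity = 0.0, 0.0

for step in range(300):
    # Accumulate a running direction from past gradients
    velocity = momentum * velocity + grad(w)
    # Move the parameter against that direction
    w = w - lr * velocity

# w ends up close to the minimum at 3
print(round(w, 4))
```

In the real training loop, `loss.backward()` supplies the gradient for every network parameter and `optimizer.step()` performs this update for all of them at once.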

## Train

In [None]:
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128,
                                           shuffle=True, num_workers=4)

In [None]:
# Train for 2 epochs
model.train()  # make sure layers such as batch norm are in training mode
for epoch in range(2):
    with tqdm_notebook(loader_train, desc="Training Epoch #{:d}".format(epoch + 1)) as t:
        for inputs, labels in t:
            # Copy data to GPU
            inputs = inputs.cuda()
            labels = labels.cuda()
            
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            t.set_postfix(loss=loss.item())

print('Finished Training')
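
The call to `optimizer.zero_grad()` matters because PyTorch adds new gradients to the ones already stored on each parameter. A minimal sketch with a single scalar parameter (the values are arbitrary):

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# First backward pass: d(2w)/dw = 2
(2 * w).backward()
grad_first = w.grad.item()

# Second backward pass without resetting: the gradients add up
(2 * w).backward()
grad_accumulated = w.grad.item()

# Reset the stored gradient, then compute it again
w.grad.zero_()
(2 * w).backward()
grad_after_reset = w.grad.item()

print(grad_first, grad_accumulated, grad_after_reset)
```

Without the reset, each training step would use the sum of the gradients of all previous batches instead of the gradient of the current one.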

## Evaluate

Let's see how well our model performs.

First, display some example images together with their ground-truth and predicted labels.

In [None]:
# A data loader for the validation set with a batch size of 4 for demonstration purposes
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=4, num_workers=4)

# Extract one batch
images, labels = next(iter(loader_val))

# Show images of the batch
imshow_tensor(torchvision.utils.make_grid(images))
print('Ground truth:', ', '.join('%5s' % dataset.classes[labels[j]] for j in range(4)))

# Switch to evaluation mode and run the batch through the model
model.eval()
with torch.no_grad():
    outputs = model(images.cuda())

# Collect the predicted classes
_, predicted = torch.max(outputs, 1)

print('Predicted:', ', '.join('%5s' % dataset.classes[predicted[j]]
                              for j in range(4)))
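
`torch.max(outputs, 1)` returns both the largest score in each row and its column index; the index is the predicted class. A small sketch with made-up scores for two images and three classes:

```python
import torch

# Rows are images, columns are class scores (logits)
demo_scores = torch.tensor([[0.1, 2.0, -1.0],
                            [3.0, 0.5, 0.2]])

# Maximum along dimension 1 (the class dimension)
best_scores, best_classes = torch.max(demo_scores, 1)
print(best_classes.tolist())

# Softmax turns the scores into probabilities without changing the winner
demo_probs = torch.softmax(demo_scores, dim=1)
print(demo_probs.argmax(dim=1).tolist())
```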

Now we do a more thorough evaluation. In order to do that, we need to run validation examples through the network and record the predictions. To keep the run time short, the cell below uses only the first 10,000 examples of the validation set.

In [None]:
dataset_val_small = torch.utils.data.Subset(dataset_val, range(10000))
loader_val_small = torch.utils.data.DataLoader(dataset_val_small, batch_size=128)

labels_true = []
labels_predicted = []

# Switch to evaluation mode; we also don't need to calculate gradients
model.eval()
with torch.no_grad():
    with tqdm_notebook(loader_val_small, desc="Evaluating") as t:
        for inputs, labels in t:
            # Copy data to GPU
            inputs = inputs.cuda()

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            labels_true.extend(labels.tolist())
            labels_predicted.extend(predicted.tolist())


print(classification_report(labels_true,
                            labels_predicted,
                            labels=np.arange(len(dataset.classes)),
                            target_names=dataset.classes))
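
The report lists precision, recall and F1 score for each class. As a reminder of how these are defined, the following plain-Python sketch computes them from made-up labels; the function name `precision_recall_f1` is hypothetical and is not part of scikit-learn.

```python
def precision_recall_f1(y_true, y_pred, cls):
    """Precision, recall and F1 score of class `cls`."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == cls and p == cls)  # correctly found
    fp = sum(1 for t, p in pairs if t != cls and p == cls)  # wrongly assigned
    fn = sum(1 for t, p in pairs if t == cls and p != cls)  # missed

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1


# Made-up labels for a three-class problem
toy_true = [0, 0, 0, 1, 1, 2, 2, 2]
toy_pred = [0, 0, 1, 1, 1, 2, 0, 2]

for cls in range(3):
    print(cls, precision_recall_f1(toy_true, toy_pred, cls))
```

Per-class scores are especially useful for imbalanced data, where a high overall accuracy can hide poor performance on rare classes.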

## Exercises

1. Compare the results to the previous classifiers.
2. Try different [transformations](https://pytorch.org/docs/stable/torchvision/transforms.html).
3. Try a different [model](https://pytorch.org/docs/stable/torchvision/models.html).
4. What do you need to change to use a different dataset?

## Conclusion

In this module, you learned how to use a folder of images to fine-tune a model in PyTorch.

***What else?***