In [None]:
%load_ext lab_black
%matplotlib inline

# Transfer Learning <br>
Training a model from scratch can be time-consuming and computationally heavy. <br>
In this notebook we look at how we can take a network trained on one dataset and use the learned weights as a head start, allowing us to achieve good results with little effort.<br>
We will also look at techniques like data augmentation and learning rate decay to improve model performance.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.models as models
import time
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
batch_size = 64  # size of our mini batches
num_epochs = 20  # how many iterations (epochs) over our dataset
learning_rate = 1e-4  # optimizer learning rate
start_epoch = 0  # initialise what epoch we start from
best_valid_acc = 0  # initialise best valid accuracy
image_size = 96  # what to resize our images to (STL10 images are already 96x96)
seed = 42  # global random seed so that re-running the notebook is reproducible

# Seed every RNG the notebook relies on: Python's `random`, NumPy and PyTorch
# (weight initialisation, DataLoader shuffling and random augmentations)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Checkpoint switches: don't write checkpoints and don't resume from one (both off by default)
save_checkpoint, start_from_checkpoint = False, False
# Where checkpoints live: <save_dir>/<model_name>.pt
save_dir, model_name = "models", "Res_18_STL10"

In [None]:
# Use the (current) CUDA device when a GPU is available, otherwise fall back to the CPU.
# Keep the torch.device object itself: it is what `.to(device)` and `torch.load(map_location=...)` accept.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 4 DataLoader worker processes per visible GPU (0 on a CPU-only machine => load in the main process)
n_workers = 4 * torch.cuda.device_count()

# Some preprocessing of the dataset, e.g. converting the images to tensors

In [None]:
# Deterministic ("plain") transform: convert the image to a tensor, then normalise each channel.
# All models from the PyTorch model zoo were trained on images normalised with the per-channel
# mean and std of the whole ImageNet dataset, so the pretrained feature "detectors" of the model
# expect their input to be normalised in the same way.
# https://pytorch.org/docs/stable/torchvision/models.html
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform1 = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=imagenet_mean, std=imagenet_std)]
)

# Data Augmentation Transform<br>
After training ResNet with the above transform, record the results, then implement the transform with data augmentation below. <br>
With a small dataset, our large model will more than likely simply overfit to (or memorize) the training data, which will often lead to bad evaluation results.<br>
We can "create more" data from our limited dataset by applying random transformations as we sample images from it, instead of simply resizing them.<br>
By applying these transformations we also force our model to generalise better to unseen images.<br>
You can also apply random affine transformations (shifts, scaling, rotations, etc.) - see the PyTorch documentation. <br>
NOTE: you should only apply transforms that make sense, e.g. if at test time you'll never see an upside-down cat, don't flip your images vertically.


In [None]:
# Training-time augmentation transform (for the training set only!).
# Every time an image is sampled it may be mirrored left-right (p=0.5); after conversion and
# normalisation, RandomErasing (p=0.5) is applied to the resulting tensor. The model therefore
# rarely sees exactly the same input twice.
# https://pytorch.org/docs/stable/torchvision/transforms.html
augmentation_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5),  # placed after ToTensor/Normalize: works on the tensor
]
transform2 = transforms.Compose(augmentation_steps)

# Create the training, testing and validation data 

In [None]:
# Define our STL10 Datasets
# https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.STL10
# The dataset definition is a bit different to MNIST and CIFAR10:
# STL10 has 3 different splits: test, train and unlabeled
# http://ai.stanford.edu/~acoates/stl10/
# The training set only has 5000 images and the test set only 8000
# Images in this dataset are 96x96, larger than what we've been using
# Try using transform 1 or 2 for the training set!! Only use transform1 for validation and test!!
data_dir = "./data"  # where to load/save the dataset from

# The labelled "train" split is loaded twice over the same images: once with the training
# transform (which may augment) and once with the deterministic transform1. The validation
# subset is taken from the un-augmented view. (This keeps a second in-memory copy of the split.)
train_data_full = torchvision.datasets.STL10(
    root=data_dir, split="train", transform=transform2, download=True
)
train_data_eval_view = torchvision.datasets.STL10(
    root=data_dir, split="train", transform=transform1, download=True
)
test_data = torchvision.datasets.STL10(
    root=data_dir, split="test", transform=transform1, download=True
)

# Split the training data into train and validation sets with a 90/10% training/validation split.
# NOTE: despite its name, `validation_split` is the fraction of images used for TRAINING.
validation_split = 0.9

n_train_examples = int(len(train_data_full) * validation_split)
n_valid_examples = len(train_data_full) - n_train_examples
train_data, valid_subset_aug = torch.utils.data.random_split(
    train_data_full,
    [n_train_examples, n_valid_examples],
    generator=torch.Generator().manual_seed(42),
)
# Reuse exactly the same validation indices, but on the un-augmented view, so validation
# accuracy is not measured on randomly flipped/erased images.
valid_data = torch.utils.data.Subset(train_data_eval_view, valid_subset_aug.indices)

# Check the lengths of all the datasets

In [None]:
# Sanity-check the size of every split
for split_name, split_ds in (
    ("training", train_data),
    ("validation", valid_data),
    ("testing", test_data),
):
    print(f"Number of {split_name} examples: {len(split_ds)}")

# Create the dataloader

In [None]:
# Wrap the training, validation and evaluation/test Datasets in DataLoaders.
# It is best practice to separate your data into these three Datasets, though depending on
# your task (and how much data you have) you may only need Training + Evaluation/Test,
# or maybe only a Training set.
# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
loader_kwargs = dict(batch_size=batch_size, num_workers=n_workers, pin_memory=False)

# Only the training loader shuffles. Validation/test metrics are aggregated over the whole
# split, so shuffling them adds nothing; a fixed order keeps the reported numbers reproducible.
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

In [None]:
# Rescale an image's pixel values into [0, 1] so it can be displayed with plt.imshow
def normalize_img(img):
    """Min-max rescale each channel of a channels-last image to the range [0, 1].

    This does not exactly invert the dataset's mean/std ``Normalize`` transform. It stretches
    each channel independently, so the result is suitable for display.

    Args:
        img: NumPy array such as (H, W, C); the min/max are taken over the first two axes,
            giving one min and one max per channel.

    Returns:
        Float array with the same shape as ``img``. Every non-constant channel spans [0, 1].
        A constant channel (max == min) maps to 0 instead of producing NaNs from 0/0.
    """
    mins = img.min(axis=(0, 1), keepdims=True)
    maxs = img.max(axis=(0, 1), keepdims=True)
    span = maxs - mins
    # For flat channels (img - mins) is already 0, so any non-zero divisor gives 0 without NaNs
    span = np.where(span == 0, 1, span)
    return (img - mins) / span

# Visualise the data <br>
It is always important to fully understand what you are training your network with

In [None]:
# Show the first 8 images of one (shuffled, possibly augmented) training mini-batch.
# NOTE: `images` is reused further down to sanity-check the network's output shape.
images, labels = next(iter(train_loader))
plt.figure(figsize=(20, 10))
out = torchvision.utils.make_grid(images[:8])
grid_hwc = np.transpose(out.numpy(), (1, 2, 0))  # channels-first -> channels-last for imshow
plt.imshow(normalize_img(grid_hwc))

# Create the pretrained network <br>
First train the ResNet from scratch and collect the results for the training and evaluation accuracy and training time<br>
Next set pretrained=True and train again collecting the results again <br>
Next, uncomment the commented-out lines of code; this will stop the optimiser from updating the pretrained parts of the network.
<br><br>
<b> Weight "Freezing"</b>
<br>
By "freezing" parts of the network like this we can speed up the training of the model as we will only be updating a single layer, this is especially useful if our pretrained model is very big (note we still have to do a full forward pass of the model which might take a while)<br>
We can "freeze" the early layers of the model like this because the ImageNet dataset that the model was trained on contains images similar to, and with similar features to, those in the STL10 dataset we are using. Because of this, the features that the network would need to learn to detect are similar between the two datasets.<br>
NOTE: if our dataset is very different from the ImageNet dataset, "freezing" parts of the model might not be effective. <br>
Once we have trained our single layer for a while we can then unfreeze the rest of our model and train the whole thing for a few epochs to refine the model for our dataset

In [None]:
# Create a ResNet18 from the PyTorch "models" module.
# This is a reasonably sized model at 18 layers deep.
# ResNet paper: https://arxiv.org/pdf/1512.03385.pdf
# https://pytorch.org/docs/stable/torchvision/models.html#torchvision.models.resnet18
# Set to True to start from the ImageNet-pretrained weights instead of a random initialisation.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
use_pretrained = False
res_net = models.resnet18(pretrained=use_pretrained)
res_net = res_net.to(device)

In [None]:
# Uncomment this when ready
# "Freeze" the network: loop through all the learnable parameter objects (from the layers)
# and mark them as not requiring gradients, so the optimiser does not update them.
# Run this BEFORE the cell that replaces `res_net.fc`, so only the newly created fc layer is trained.
# for param in res_net.parameters():
# #     Set to True to unfreeze layers
#     param.requires_grad = False

Let's see the structure of this network

In [None]:
# View the network: as the cell's last expression, the model's repr (its tree of sub-modules) is displayed
res_net

In [None]:
# Let's see how many parameters our model has!
num_params = sum(param.numel() for param in res_net.parameters())
# Parameters with requires_grad=False (e.g. after "freezing") are excluded from this count
num_trainable = sum(
    param.numel() for param in res_net.parameters() if param.requires_grad
)
print(
    f"This model has {num_params} (approximately {num_params / 1e6:.1f} Million) Parameters!"
)
print(f"{num_trainable} of them are trainable.")

The ImageNet challenge dataset that the ResNet model was trained on has 1000 classes, but the STL10 dataset only has 10. <br>
We can still use the pretrained model we just need to alter it a bit by simply replacing the last FC (linear) layer with a new one 

In [None]:
# Adapt the model to STL10: swap the 1000-way ImageNet classifier (the last fc layer) for a
# fresh 10-way linear layer. The new layer's weights are randomly initialised.
num_classes = 10  # STL10 has 10 classes
num_ftrs = res_net.fc.in_features  # number of features feeding into the last fc layer
new_head = nn.Linear(num_ftrs, num_classes)
res_net.fc = new_head.to(device)

In [None]:
# Sanity check: pass the sample batch through the network and inspect the output shape.
# Use eval() + no_grad(): in train mode this forward pass would update the BatchNorm running
# statistics, and with autograd enabled it would build a graph that is never used.
# The training function calls net.train() at the start of every epoch anyway.
res_net.eval()
with torch.no_grad():
    out = res_net(images.to(device))
# check output: expected (batch size, 10)
out.shape

# Set up the optimizer 

In [None]:
# Adam over all of the network's parameters, starting at `learning_rate`.
# (lr_linear_decay overwrites the learning rate of every param group at the start of each epoch.)
model_params = res_net.parameters()
optimizer = optim.Adam(model_params, lr=learning_rate)

In [None]:
# Cross-entropy loss on the raw network outputs (logits), averaged over the mini-batch
loss_fun = nn.CrossEntropyLoss(reduction="mean")

# Loading Checkpoints

In [None]:
# Create the save path from save_dir and model_name; we will save and load our checkpoint here
save_path = os.path.join(save_dir, model_name + ".pt")

# Create the save directory if it does not exist (no-op when it already does)
os.makedirs(save_dir, exist_ok=True)

if start_from_checkpoint:
    if not os.path.isfile(save_path):
        raise ValueError("Checkpoint Does not exist")
    # The checkpoint is a Python dictionary (https://www.w3schools.com/python/python_dictionaries.asp).
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine (and vice versa).
    check_point = torch.load(save_path, map_location=device)
    # Unpack the dictionary to restore our previous training state
    res_net.load_state_dict(check_point["model_state_dict"])
    optimizer.load_state_dict(check_point["optimizer_state_dict"])
    # "epoch" is the (0-based) epoch that had just FINISHED when the checkpoint was written,
    # so resume with the next epoch instead of repeating it
    start_epoch = check_point["epoch"] + 1
    best_valid_acc = check_point["valid_acc"]
    # +1 so the message matches the 1-based "| Epoch: NN |" numbering of the training loop
    print("Checkpoint loaded, starting from epoch:", start_epoch + 1)
elif save_checkpoint and os.path.isfile(save_path):
    # Only refuse to start when this run could actually overwrite the existing checkpoint
    raise ValueError("Warning Checkpoint exists")
else:
    print("Starting from scratch")

# Define the training process

In [None]:
# This function performs a single training epoch using our training data
def train(net, device, loader, optimizer, loss_fun):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        net: model to train; it is switched to training mode here.
        device: device the batches are moved to (should match the model's device).
        loader: iterable of ``(inputs, labels)`` mini-batches, e.g. a DataLoader.
        optimizer: optimiser updating ``net``'s parameters.
        loss_fun: loss returning the MEAN loss over a batch (e.g. ``nn.CrossEntropyLoss()``).

    Returns:
        float: mean loss over all samples seen this epoch. Each batch's mean loss is weighted
        by its batch size, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    total_loss = 0.0
    n_samples = 0

    # Set the network in train mode
    net.train()

    for x, y in loader:
        # load images (x) and their labels (y) to the device
        x = x.to(device)
        y = y.to(device)

        # Forward pass of the images through the network, then the loss
        fx = net(x)
        loss = loss_fun(fx, y)

        # Zero old gradients, backpropagate, and take a single optimisation step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean: scale it by the batch size before accumulating
        batch_n = y.size(0)
        total_loss += loss.item() * batch_n
        n_samples += batch_n

    if n_samples == 0:
        raise ValueError("loader yielded no samples")

    return total_loss / n_samples

# Define the testing process

In [None]:
# This function performs a single evaluation epoch and will be passed our validation or evaluation/test data
# it WILL NOT be used to train our model
def evaluate(net, device, loader, loss_fun):
    """Evaluate ``net`` on ``loader`` without updating it.

    Args:
        net: model to evaluate; it is switched to evaluation mode here.
        device: device the batches are moved to (should match the model's device).
        loader: iterable of ``(inputs, labels)`` mini-batches, e.g. a DataLoader.
        loss_fun: loss returning the MEAN loss over a batch (e.g. ``nn.CrossEntropyLoss()``).

    Returns:
        tuple: ``(mean per-sample loss, accuracy)``. The accuracy is the fraction of samples
        whose argmax prediction equals the label. Both are averaged over every sample
        yielded by ``loader``, so the loss and accuracy use the same denominator.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    total_loss = 0.0
    n_correct = 0
    n_samples = 0

    # Set the network in evaluation mode:
    # layers like Dropout are disabled, and layers like BatchNorm stop updating their running
    # mean/std and use the currently stored values instead
    net.eval()

    with torch.no_grad():
        for x, y in loader:
            # load images and labels to the device
            x = x.to(device)
            y = y.to(device)

            # Forward pass of the images through the network, then the loss
            fx = net(x)
            loss = loss_fun(fx, y)

            # `loss` is a batch mean: scale it by the batch size before accumulating
            batch_n = y.size(0)
            total_loss += loss.item() * batch_n
            n_correct += (fx.argmax(1) == y).sum().item()
            n_samples += batch_n

    if n_samples == 0:
        raise ValueError("loader yielded no samples")

    return total_loss / n_samples, n_correct / n_samples

<h3> Learning rate scheduler </h3>
It can be useful to start with a high learning rate and then decrease it after some time, allowing the optimiser to "fine tune" the model.<br>
There are many different ideas about how to change the learning rate over epochs; here we will create a simple "linear decay" scheduler manually.<br>
PyTorch also has built-in learning rate scheduling.

[Learning rate scheduling](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate)

In [None]:
# Linearly decay the learning rate every epoch
def lr_linear_decay(epoch_max, epoch, lr, opt=None):
    """Set every param group's learning rate to ``lr * (epoch_max - epoch) / epoch_max``.

    The rate falls linearly from ``lr`` at epoch 0 towards 0 at ``epoch_max``.

    Args:
        epoch_max: total number of epochs (must be non-zero).
        epoch: current 0-based epoch.
        lr: initial (maximum) learning rate.
        opt: optimiser whose ``param_groups`` are updated. Defaults to the notebook-level
            ``optimizer`` so existing calls keep working; pass it explicitly to avoid relying
            on that global.

    Returns:
        float: the learning rate that was applied.
    """
    if opt is None:
        opt = optimizer  # fall back to the global optimiser created earlier in the notebook
    lr_adj = ((epoch_max - epoch) / epoch_max) * lr
    # update the learning rate parameter of the optimizer
    for param_group in opt.param_groups:
        param_group["lr"] = lr_adj
    return lr_adj

# The training process <br>
You should record the training and evaluation accuracy as well as the training time after every experiment!

In [None]:
# Per-epoch histories of the training and validation metrics, used for the plots below
training_loss_logger, validation_loss_logger = [], []  # mean losses
training_acc_logger, validation_acc_logger = [], []  # accuracies (fractions in [0, 1])

In [None]:
# Training loop: per epoch, one optimisation pass over the training data, then the accuracy on the
# training data plus the loss/accuracy on the validation data. Whenever the validation accuracy
# improves, it is remembered and (if save_checkpoint is set) a checkpoint dictionary is written.

# Record the start time
Start_time = time.time()

for epoch in range(start_epoch, num_epochs):
    # Linearly decay the learning rate before this epoch's updates
    lr_linear_decay(num_epochs, epoch, learning_rate)

    # One optimisation pass over the training data
    train_loss = train(res_net, device, train_loader, optimizer, loss_fun)

    # Training-set accuracy (its loss is not needed), then validation loss and accuracy
    train_acc = evaluate(res_net, device, train_loader, loss_fun)[1]
    valid_loss, valid_acc = evaluate(res_net, device, valid_loader, loss_fun)

    # Record this epoch's metrics for plotting
    for logger, value in (
        (training_loss_logger, train_loss),
        (validation_loss_logger, valid_loss),
        (training_acc_logger, train_acc),
        (validation_acc_logger, valid_acc),
    ):
        logger.append(value)

    # Best validation accuracy so far -> remember it and optionally save a checkpoint
    is_best = valid_acc > best_valid_acc
    if is_best:
        best_valid_acc = valid_acc
    if is_best and save_checkpoint:
        print("Saving Model")
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": res_net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_acc": train_acc,
            "valid_acc": valid_acc,
        }
        torch.save(checkpoint, save_path)

    print(
        f"| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:05.2f}% |"
    )

End_time = time.time()

In [None]:
print(f"The highest validation accuracy was {best_valid_acc * 100:.2f}%")

In [None]:
print(f"Training time {End_time - Start_time:.2f} seconds")

In [None]:
# Plot the per-epoch training and validation loss.
# The x values are the 1-based epoch numbers printed by the training loop, offset by start_epoch
# when training resumed from a checkpoint, so every point sits on a whole epoch.
epochs_run = np.arange(start_epoch + 1, start_epoch + 1 + len(training_loss_logger))
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(epochs_run, training_loss_logger, c="y", label="Training Loss")
ax.plot(epochs_run, validation_loss_logger, c="k", label="Validation Loss")
ax.set(title="ResNet Loss", xlabel="Epoch", ylabel="Loss")
ax.legend();

In [None]:
# Plot the per-epoch training and validation accuracy.
# The x values are the 1-based epoch numbers printed by the training loop, offset by start_epoch
# when training resumed from a checkpoint, so every point sits on a whole epoch.
epochs_run = np.arange(start_epoch + 1, start_epoch + 1 + len(training_acc_logger))
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(epochs_run, training_acc_logger, c="y", label="Training Acc")
ax.plot(epochs_run, validation_acc_logger, c="k", label="Validation Acc")
ax.set(title="ResNet Acc", xlabel="Epoch", ylabel="Accuracy (fraction)")
ax.legend();

# Evaluate

In [None]:
# Final one-off evaluation on the held-out test split
test_loss, test_acc = evaluate(res_net, device, test_loader, loss_fun)
print(f"Testing: | Loss {test_loss:.2f} | Accuracy {100 * test_acc:.2f}% |")