# Transfer Learning in PyTorch
Written by Calden Wloka for CS 153

This notebook draws heavily on the [official PyTorch transfer learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) by Sasank Chilamkurthy.

Some other documentation you may find useful:
- [Saving and loading models](https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)
    - You often need an object to persist across training environments or instances. This allows you to work around XSEDE's timeout limitations, or run multiple experiments at different times with the same model.
- [Torchvision object detection finetuning tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
    - This is a more advanced tutorial looking at fine tuning a model to a new dataset. In particular, it shows you how to set up a custom data loader.
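
As a rough illustration of the save-and-load idea in the first link, here is a minimal sketch of writing a checkpoint and restoring it later. The tiny `nn.Linear` model, the dictionary keys, the epoch value, and the temporary file path are all placeholders chosen for this example, not part of the notebook's pipeline.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in model; any nn.Module is handled the same way.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Bundle everything needed to resume training into one dictionary.
# The key names here are our own choice for this sketch.
checkpoint = {
    "epoch": 5,  # illustrative value
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(checkpoint, ckpt_path)

# Later, possibly in a new session: rebuild the same architecture, then restore.
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)

loaded = torch.load(ckpt_path, map_location="cpu")
restored_model.load_state_dict(loaded["model_state"])
restored_optimizer.load_state_dict(loaded["optimizer_state"])
start_epoch = loaded["epoch"] + 1  # resume from the next epoch
```

Saving the optimizer state alongside the model weights matters when you resume training with momentum; for inference only, the model weights alone are enough.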


The notebook is organized as follows:
1. Library Setup - Imports and device initialization
2. Data Handling - Setting up a data loader to feed training and validation data to our model
3. Model Handling - Setting up our model to import pre-trained weights and reconfigure for our new task
4. Training Setup - Create a training function to train our model
5. Execution - Perform transfer learning and investigate the results
6. Experiment Sandbox - Try a few extensions and alternative formulations

## Library Setup

We're using quite a large number of libraries for this one, including a whole bunch of `torch` libraries. As with other deep learning examples we've seen so far, we can execute our transfer learning with a CPU device, but it is much more efficient on a GPU device.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

%matplotlib inline
plt.rcParams['figure.figsize'] = [16, 10]

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

## Data Handling

The first step in most deep learning tasks, and particularly any tasks that involve training or finetuning of models, is setting up your data handling functionality. This typically means creating a set of tools to iteratively feed training data to your model, but you also usually want to include the option of validation and/or test data.

For this tutorial, we are going to train a classifier to differentiate between *ants* and *bees*. The dataset is available [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip).

If you download and extract the dataset, you will notice that it is rather small; there are about 120 training images each for ants and bees (along with about 75 validation images per class). It would be hard to train a classifier from scratch on this amount of data, but transfer learning can leverage pre-training from prior data.

Furthermore, we will make use of some basic data augmentation for training in the form of randomized resizing and cropping (which introduces a form of spatial jitter), and randomized horizontal flipping. We also need to normalize our images with the same per-channel mean and standard deviation that the network was originally trained with, so that we take full advantage of the pre-trained features.
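
To make the normalization step concrete, here is a minimal NumPy sketch of the per-channel arithmetic. The random 3x2x2 "image" is a made-up placeholder; the mean and standard deviation values are the ones used in the transforms in this notebook.

```python
import numpy as np

# Channel statistics used by the Normalize transform in this notebook.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Hypothetical 3x2x2 image in channels-first layout with values in [0, 1],
# which is the layout and range that ToTensor produces.
rng = np.random.default_rng(0)
image = rng.random((3, 2, 2))

# Normalization subtracts the channel mean and divides by the channel std.
# The [:, None, None] indexing broadcasts one value per channel over height and width.
normalized = (image - mean[:, None, None]) / std[:, None, None]

# Inverting the operation recovers the original pixel values,
# which is what a display helper has to do before plotting.
recovered = normalized * std[:, None, None] + mean[:, None, None]
```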

In [None]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

We'll now set up our dataloader to perform file handling. This step involves interacting with the file structure under which you have the images stored. For our purposes, we have the following folder organization:

         hymenoptera_data
            ______|______
            |           |
          train        val
         ___|___     ___|___
         |     |     |     |
       ants  bees   ants  bees
       
This allows the data handler (the `ImageFolder` dataset class used below) to know the type of data (training vs. validation) and class (ants vs. bees) based on its location, without needing to read from any annotation file.
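
The idea of reading labels off the folder layout can be sketched in plain Python. This is a simplified imitation for illustration, not torchvision's actual implementation; the helper name `scan_image_folder` and the placeholder file names are invented for this example.

```python
import os
import tempfile


def scan_image_folder(root):
    """Return (sorted class names, list of (path, class_index)) for root/<class>/<file>."""
    class_list = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: i for i, name in enumerate(class_list)}
    samples = []
    for name in class_list:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_list, samples


# Build a throwaway directory tree with empty placeholder files.
demo_root = tempfile.mkdtemp()
for cls, files in {"ants": ["a1.jpg", "a2.jpg"], "bees": ["b1.jpg"]}.items():
    os.makedirs(os.path.join(demo_root, cls))
    for fname in files:
        open(os.path.join(demo_root, cls, fname), "w").close()

demo_classes, demo_samples = scan_image_folder(demo_root)
demo_labels = [idx for _, idx in demo_samples]
print(demo_classes)  # ['ants', 'bees']
print(demo_labels)   # [0, 0, 1]
```

Because class indices come from the sorted folder names, `ants` maps to 0 and `bees` to 1 in this sketch.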

In [None]:
data_dir = 'hymenoptera_data' # set this to your local path to the data root directory

# here we build our list of images in each data category
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

# The DataLoader class is a torch utility that is designed for easily iterating through a dataset
# during training or testing.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}

# These lines are not critical, but are helpful for understanding our data and visualizing our results.
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

Let's explore our data setup a little bit; make sure you understand what each part is doing.

In [None]:
print('Classes are:')
print(class_names)
print(' ')
print('The amount of data is:')
print(dataset_sizes)
print(' ')
print('The dataset dictionary is:')
print(image_datasets)
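
To see what a `DataLoader` actually hands back, here is a minimal sketch on a made-up dataset. The `TensorDataset` of 10 toy samples is an assumption standing in for the image folders, and `num_workers=0` keeps loading in the main process for simplicity.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 samples with 3 features each, and binary labels.
toy_features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
toy_labels = torch.tensor([0, 1] * 5)
toy_dataset = TensorDataset(toy_features, toy_labels)

# Same batch size as the notebook's loaders; shuffling changes the order, not the sizes.
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=True, num_workers=0)

# Each iteration yields one (inputs, labels) pair of batched tensors.
toy_batch_sizes = [batch_inputs.size(0) for batch_inputs, _ in toy_loader]
print(toy_batch_sizes)  # [4, 4, 2]
```

Note that the last batch is smaller when the dataset size is not a multiple of the batch size, which is why per-epoch statistics should be weighted by batch size.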

Another important part of data handling is to **_look at your data_**! Ideally, we want to look at our output *after* our dataloader has processed it, as that can help catch inappropriate augmentations or other transformation or data handling errors.

To look at data after the dataloader has processed it, we need to bring the tensors back from the device into our regular workspace, and reshape them into the standard shape that we expect images to take. Our `tensor_show` function assumes the images have already been pulled into our workspace, and takes care of the reshaping and of undoing the normalization. We will use a `torchvision.utils` function called `make_grid` to turn a batch of images into one long image. Later, when we use `tensor_show` to visualize our predictions, we will explicitly send our images to the CPU first.

In [None]:
def tensor_show(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)

In [None]:
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

tensor_show(out, title=[class_names[x] for x in classes])

## Model Handling

Now that we have our data set up and ready for processing, we need a model to do that processing.

For this demonstration, we are going to work with the pretrained ResNet18 model. Note that the [original tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) demonstrates two different transfer learning protocols: finetuning the whole network (i.e. allowing weights throughout the network to change), and treating the network as a fixed feature extractor (i.e. only training the final classifier layer; sometimes this may be extended to multiple fully connected "readout" layers).

For this demo we will focus on the latter style, but both options can be useful.
We do this by setting the `requires_grad` attribute of the model's existing parameters to `False`, which prevents gradients from being computed for them and therefore keeps them fixed during training. Since newly constructed modules have `requires_grad=True` by default, when we declare our new output layer, it will be the only layer with a gradient calculation.
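
The freeze-then-replace pattern can be checked on a small stand-in network before applying it to ResNet18. In this sketch the three-layer `nn.Sequential` model and its layer sizes are invented for illustration; only the `requires_grad` mechanics carry over.

```python
import torch.nn as nn

# Hypothetical stand-in for a pretrained network: a small feature stack plus a head.
toy_model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 10),  # original head for a 10-class task
)

# Freeze every existing parameter.
for param in toy_model.parameters():
    param.requires_grad = False

# Swap in a new head for a 2-class task; its parameters default to requires_grad=True.
toy_model[2] = nn.Linear(16, 2)

trainable = [name for name, p in toy_model.named_parameters() if p.requires_grad]
frozen = [name for name, p in toy_model.named_parameters() if not p.requires_grad]
print("trainable:", trainable)  # trainable: ['2.weight', '2.bias']
print("frozen:", frozen)        # frozen: ['0.weight', '0.bias']
```

Listing the trainable parameter names like this is a quick sanity check that only the new layer will be updated.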

In [None]:
# We grab the whole pretrained model to start
model_tl = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
for param in model_tl.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
# We declare a new fully connected layer that has the same input dimensions as the original
# and now has the output dimensions of the number of classes in our target dataset
num_ftrs = model_tl.fc.in_features
model_tl.fc = nn.Linear(num_ftrs, len(class_names))

# this command sends the model to our device
model_tl = model_tl.to(device)

# here we set what type of loss we plan to use. Since this is a recognition task,
# cross entropy is a good loss function.
criterion = nn.CrossEntropyLoss()

# we also need to set up an optimizer. Our standard SGD works fine.
optimizer_tl = optim.SGD(model_tl.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_tl, step_size=7, gamma=0.1)
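
The step schedule configured above can also be worked out by hand. The sketch below assumes the scheduler is stepped once per epoch, as the training function in this notebook does; the helper `step_lr` is a hypothetical closed-form version written for this example, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate in effect during `epoch` (0-indexed) under step decay."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate for each of 25 epochs, starting from lr=0.001.
lr_by_epoch = [step_lr(0.001, epoch) for epoch in range(25)]

for epoch in (0, 6, 7, 14, 21):
    print(f"epoch {epoch:2d}: lr = {lr_by_epoch[epoch]:.0e}")
```

So epochs 0 to 6 run at 1e-3, epochs 7 to 13 at 1e-4, epochs 14 to 20 at 1e-5, and the remaining epochs at 1e-6.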


## Training Setup

So now that we have our dataset ready, and our model ready, it is time to define how we want that model to train using the data. This is typically done through an encapsulated training function which we will call `train_model`.
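
One detail of the training function is that it accumulates `loss.item() * inputs.size(0)` rather than the raw batch loss. Since `nn.CrossEntropyLoss` averages over the batch by default, weighting by batch size is what makes the epoch loss a true per-sample average. A small arithmetic sketch, with made-up loss values and batch sizes:

```python
# Hypothetical per-batch mean losses and batch sizes for a 10-sample dataset.
batch_mean_losses = [0.9, 0.6, 0.3]
demo_batch_sizes = [4, 4, 2]

# Weighting each batch mean by its size gives the true per-sample average.
running_loss = sum(loss * n for loss, n in zip(batch_mean_losses, demo_batch_sizes))
epoch_loss = running_loss / sum(demo_batch_sizes)

# Averaging the batch means directly over-weights the small final batch.
naive_loss = sum(batch_mean_losses) / len(batch_mean_losses)

print(f"weighted: {epoch_loss:.2f}, naive: {naive_loss:.2f}")  # weighted: 0.66, naive: 0.60
```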

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # We are going to run for a set number of epochs, but that doesn't mean our final epoch is our best.
    # Keep track of which version of the model worked the best.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # data logging
    losslog = [[],[]]
    acclog = [[],[]]
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # data logging
            if phase == 'train':
                losslog[0].append(epoch_loss)
                acclog[0].append(epoch_acc.to('cpu'))
            else:
                losslog[1].append(epoch_loss)
                acclog[1].append(epoch_acc.to('cpu'))

                
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, losslog, acclog
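
The `copy.deepcopy` around `model.state_dict()` is doing real work: the tensors returned by `state_dict()` share storage with the live parameters, so without a copy the "best" weights would silently follow later updates. A minimal sketch on a toy `nn.Linear` model (a placeholder, not the notebook's network):

```python
import copy

import torch
import torch.nn as nn

toy_net = nn.Linear(2, 1)

# state_dict() returns tensors that share storage with the live parameters.
live_view = toy_net.state_dict()
snapshot = copy.deepcopy(toy_net.state_dict())

# Simulate further training by modifying the weights in place.
with torch.no_grad():
    toy_net.weight.add_(1.0)

# The un-copied view followed the update; the deep copy did not.
view_followed_update = torch.equal(live_view["weight"], toy_net.weight)
snapshot_preserved = not torch.equal(snapshot["weight"], toy_net.weight)

# Restoring the snapshot brings back the earlier weights.
toy_net.load_state_dict(snapshot)
restored_ok = torch.equal(snapshot["weight"], toy_net.weight)
```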

Just like looking at our data can be useful, it is also a very good idea to inspect your model predictions and not just rely on the validation accuracy. For that, we want a visualization function.

In [None]:
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                tensor_show(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
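
Beyond the predicted label, it can help to look at how confident the model is in each prediction. The following is a minimal sketch of one way to do that, assuming you have the raw outputs (logits) of a two-class model; the logit values and the `demo_class_names` list are invented for illustration.

```python
import torch

demo_class_names = ['ants', 'bees']

# Hypothetical logits for a batch of three images.
demo_logits = torch.tensor([[2.0, 0.0], [0.5, 1.5], [-1.0, 0.0]])

# Softmax turns each row of logits into probabilities that sum to 1.
demo_probs = torch.softmax(demo_logits, dim=1)
demo_confidence, demo_preds = torch.max(demo_probs, dim=1)

for p, c in zip(demo_preds.tolist(), demo_confidence.tolist()):
    print(f'predicted: {demo_class_names[p]} ({c:.2f})')
```

The predicted class is the same whether you take the maximum over logits or over probabilities, since softmax preserves the ordering within each row.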

## Execution

Okay... after all that, it's time to run things! One of the nice things about deep learning is that, although the setup can be convoluted, once all your ducks are in a row you just kind of have to hit go.

In [None]:
model_tl, losslog, acclog = train_model(model_tl, criterion, optimizer_tl,
                         exp_lr_scheduler, num_epochs=25)

In [None]:
visualize_model(model_tl)

In [None]:
# one x-axis point per logged epoch, so the plot still works if num_epochs changes
epochs = range(1, len(losslog[0]) + 1)

plt.subplot(1,2,1)
plt.plot(epochs, losslog[0], 'r-')
plt.plot(epochs, losslog[1], 'b-')
plt.title('Loss')
plt.legend(['training', 'validation'])

plt.subplot(1,2,2)
plt.plot(epochs, acclog[0], 'r-')
plt.plot(epochs, acclog[1], 'b-')
plt.title('Accuracy')
plt.legend(['training', 'validation'])

## Experiment Sandbox

Our initial setup does not leave the model much to learn: the hymenoptera data is taken from ImageNet, so the pretrained model is already quite familiar with it, and we seem to converge on decent performance _very_ fast. Depending on our application, we may need to squeeze out a bit more performance, in which case we can think of ways to do that, perhaps by adjusting our data handling.

Alternatively, maybe we want to apply this to data that is somewhat different, like [bears](https://www.kaggle.com/datasets/anirudhg15/bears-fastai-2021) or [cats and dogs](https://www.kaggle.com/datasets/alvarole/asirra-cats-vs-dogs-object-detection-dataset?resource=download).

Another option would be to explore transfer learning on a different network (e.g. VGG-16).

Finally, we could try manipulating our network architecture more than simply swapping out the fully connected layers; what happens if we instead try and learn from only a sub-portion of the feature layers from ResNet18?