# Transfer Learning in PyTorch
Written by Calden Wloka for CS 153

This notebook draws heavily on the [official PyTorch transfer learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) by Sasank Chilamkurthy.

Some other documentation you may find useful:
- [Saving and loading models](https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)
    - You often need an object to persist across training environments or instances. This allows you to work around XSEDE's timeout limitations, or run multiple experiments at different times with the same model.
- [Torchvision object detection finetuning tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
    - This is a more advanced tutorial looking at fine tuning a model to a new dataset. In particular, it shows you how to set up a custom data loader.
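
As a rough sketch of what persisting a model looks like, the snippet below saves a `state_dict` to disk and restores it into a fresh model. The tiny `nn.Linear` model and the file name `checkpoint.pt` are placeholders chosen for illustration, not part of this notebook's pipeline.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model; any nn.Module is saved the same way
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pt")  # hypothetical file name

    # Save only the learned parameters, keyed so more entries can be added later
    torch.save({"model_state_dict": model.state_dict()}, path)

    # map_location lets a checkpoint written on a GPU be opened on a CPU-only machine
    checkpoint = torch.load(path, map_location="cpu")

# Rebuild the same architecture, then copy the saved parameters into it
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model_state_dict"])

weights_match = torch.equal(model.weight, restored.weight)
print(weights_match)
```

Saving the `state_dict` rather than the whole model object keeps the checkpoint independent of how the model class is defined in a given session.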


# Imports:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

from PIL import Image
import itertools


from helpers import show_img

from prep_data import process_img

from preprocess import generate_normalizer

In [None]:
plt.rcParams['figure.figsize'] = [10, 5]

# Custom Parameters:

In [None]:
data_dir = 'data/lbp'
# data_dir = 'data/augmented'

# load_path = 'Models/pollen-model.pt'
load_path = None

# Load Data/Train:

In [None]:
print(torch.__version__)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

This tutorial trains a classifier to differentiate between *ants* and *bees*. The dataset is available [here](https://download.pytorch.org/tutorial/hymenoptera_data.zip).

If you download and extract the dataset, you will notice that it is rather small; there are about 120 training images each for ants and bees (along with 75 validation images). It would be hard to train a classifier from scratch on this amount of data, but transfer learning can leverage pre-training from prior data.

Furthermore, we will make use of some basic data augmentation for training.

In [None]:
os.makedirs('Models/', exist_ok=True)

if load_path is not None:
  checkpoint = torch.load(load_path, map_location=device)

  data_transforms = checkpoint['data_transforms']
  dataloaders = checkpoint['dataloaders']
  train_ds = dataloaders['train'].dataset
  # the split datasets may be Subset objects wrapping an ImageFolder; unwrap if needed
  class_names = getattr(train_ds, 'dataset', train_ds).classes

  dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'val']}
else:
  data_transforms = {
    'train': transforms.Compose([
      transforms.RandomResizedCrop(224),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
  }

  # Two views of the same folder, so that each split can carry its own transform.
  # random_split returns Subset wrappers around one shared dataset, so assigning
  # .transform through them would give both splits all of the data and the val transform.
  train_view = datasets.ImageFolder(data_dir, transform=data_transforms['train'])
  val_view = datasets.ImageFolder(data_dir, transform=data_transforms['val'])

  train_size = int(0.8 * len(train_view))

  # one random permutation of the indices, split 80/20 between train and val
  indices = torch.randperm(len(train_view)).tolist()

  image_datasets = {
    'train': torch.utils.data.Subset(train_view, indices[:train_size]),
    'val': torch.utils.data.Subset(val_view, indices[train_size:]),
  }

  dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                        shuffle=True, num_workers=4)
        for x in ['train', 'val']}
  dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

  class_names = train_view.classes

  print('The amount of data is:')
  print(dataset_sizes)
  print(' ')
  print('The dataset dictionary is:')
  print(image_datasets)

Let's explore our data setup a little bit; make sure you understand what each part is doing.

In [None]:
print('Classes are:')
print(class_names)

Another important part of data handling is to *look at your data*! Ideally, we want to look at our output *after* our dataloader has processed it, as that can help catch inappropriate augmentations and other transformation or data handling errors.

To look at data after the dataloader has processed it, we need to bring the tensors back from the device to the CPU and rearrange them into the shape that image plotting expects. Our `tensor_show` function assumes the tensors are already on the CPU, and takes care of the reshaping and un-normalizing. We will use a `torchvision.utils` function called `make_grid` to tile a batch of images into a single image. Later, when we use `tensor_show` to visualize our predictions, we will explicitly move the images to the CPU first.
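
To make the un-normalizing step concrete, here is a small plain-Python sketch for a single RGB pixel. The mean and standard deviation are the values passed to `transforms.Normalize` in this notebook; the pixel value itself is made up for illustration.

```python
# Per-channel statistics passed to transforms.Normalize in this notebook
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Made-up RGB pixel with values in [0, 1], the range produced by ToTensor
pixel = [0.2, 0.5, 0.9]

# Normalize: subtract the channel mean, divide by the channel std
normalized = [(p - m) / s for p, m, s in zip(pixel, mean, std)]

# Un-normalize for display: multiply by std, add the mean back, clip to [0, 1]
restored = [min(max(n * s + m, 0.0), 1.0) for n, m, s in zip(normalized, mean, std)]

print(normalized)
print(restored)
```

The normalized values fall outside `[0, 1]`, which is why plotting a normalized tensor directly gives washed-out or clipped colors.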

In [None]:
def tensor2numpy_img(inp):
  inp = inp.numpy().transpose((1, 2, 0))  # tensors are (channels, height, width); matplotlib expects (height, width, channels)
  # the images were normalized by the transforms, so undo that to get them ready to display
  mean = np.array([0.485, 0.456, 0.406])
  std = np.array([0.229, 0.224, 0.225])
  inp = std * inp + mean
  # clip them back
  inp = np.clip(inp, 0, 1)
  return inp


def tensor_show(inp, title=None):
  """Imshow for Tensor."""
  inp = tensor2numpy_img(inp)
  plt.imshow(inp)
  if title is not None:
    plt.title(title)

In [None]:
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

tensor_show(out, title=[class_names[x] for x in classes])

Now that we have data handling set up, we want to set up our training routines. 

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, save_path='Models/pollen-model.pt'):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_acc = 0.0

    # each epoch is a full run-through of training data
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode: layers such as dropout and batch norm use their training behavior
            else:
                model.eval()   # evaluation mode: dropout is off and batch norm uses its running statistics

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                # optimizer tells you how you are updating the weights
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                # scheduler updates your learning rate
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            # when you are in validation phase, look at accuracy over particular epoch, and compare with best accuracy thus far
            if phase == 'val' and epoch_acc > best_acc:
                best_epoch = epoch
                best_acc = epoch_acc
                # save weights (model.state_dict()) as best model weights
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    # XSEDE has timeout of 12 hrs or something
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f} (epoch #{best_epoch})')

    # load best model weights
    model.load_state_dict(best_model_wts)

    if save_path is not None:
        torch.save({
            'model_state_dict': model.state_dict(),
            'data_transforms': data_transforms,
            'dataloaders': dataloaders,
        }, save_path)

    return model
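
The `running_loss` bookkeeping in `train_model` multiplies each batch loss by the batch size before summing, which matters when the last batch is smaller than the rest. The following plain-Python sketch, using invented batch losses and sizes, suggests how this gives the per-sample average rather than an average of batch averages.

```python
# Invented (mean loss, batch size) pairs; the last batch is smaller than the others
batches = [(0.50, 4), (0.30, 4), (1.20, 1)]

# Same accumulation as in train_model: batch mean loss times batch size
running_loss = sum(loss * size for loss, size in batches)
dataset_size = sum(size for _, size in batches)
epoch_loss = running_loss / dataset_size

# Naive alternative: average the batch means, which over-weights the small batch
naive_loss = sum(loss for loss, _ in batches) / len(batches)

print(f"per-sample average: {epoch_loss:.4f}")
print(f"average of batch means: {naive_loss:.4f}")
```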

Just like looking at our data can be useful, it is also a very good idea to inspect your model predictions and not just rely on the validation accuracy. For that, we want a visualization function.

In [None]:
def visualize_model(model, num_images=6):
  was_training = model.training
  model.eval()
  images_so_far = 0
  fig = plt.figure()

  with torch.no_grad():
    for i, (inputs, labels) in enumerate(dataloaders['val']):
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)

      for j in range(inputs.size()[0]):
        images_so_far += 1
        ax = plt.subplot(num_images//2, 2, images_so_far)
        ax.axis('off')
        ax.set_title(f'{class_names[labels[j]]}\npredicted: {class_names[preds[j]]}')
        tensor_show(inputs.cpu().data[j])

        if images_so_far == num_images:
          model.train(mode=was_training)
          return
    model.train(mode=was_training)

Now that we have training routines defined, we can perform transfer learning! Note that the [original tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) demonstrates two different transfer learning protocols: finetuning the whole network (i.e. allowing weights throughout the network to change), and treating the network as a fixed feature extractor (i.e. only training the final classifier layer; sometimes this may be extended to multiple fully connected "readout" layers).

For this demo we will focus on the latter style, but both options can be useful.
We do this by setting the `requires_grad` attribute of the parameters in the model's feature layers to `False`, which prevents gradients from being computed for them and so keeps the training routine from updating them. Since the parameters of newly constructed modules have `requires_grad=True` by default, the new output layer we declare will be the only layer whose weights are trained.
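
A quick way to check that freezing worked is to count the parameters that still require gradients. The sketch below uses a small stand-in network rather than ResNet-18 so that it runs without downloading weights; the layer sizes are arbitrary.

```python
import torch.nn as nn

# Stand-in for a pre-trained network: a feature layer followed by a classifier head
model = nn.Sequential(
    nn.Linear(8, 16),  # plays the role of the frozen feature extractor
    nn.ReLU(),
    nn.Linear(16, 3),  # plays the role of the original output layer
)

# Freeze every existing parameter
for param in model.parameters():
    param.requires_grad = False

# Replace the head; parameters of a new module require gradients by default
model[2] = nn.Linear(16, 2)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable}, frozen: {frozen}")
```

Only the new head's weights and biases should show up as trainable; if the trainable count is larger than expected, some feature layer was left unfrozen.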

In [None]:
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    # tells the model not to change this part of network (don't compute gradient over weights)
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
# new output layer: its input size is the number of features feeding the old layer, its output size is the number of classes
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

# load model onto device
model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# SGD (stochastic gradient descent) is the classic optimizer
# lr is the learning rate (how big a step to take on each update)
# momentum (default 0) is an additional hyperparameter that blends each update with a running average of previous updates, which smooths out spiky steps
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
# step_size is the number of epochs between decays
# gamma is the factor the learning rate is multiplied by at each decay
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
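
For reference, the learning rate this schedule produces can be worked out by hand. The sketch below is plain Python reproducing the step-decay rule (initial rate times `gamma` raised to the number of completed decay steps) rather than calling the scheduler itself, and it assumes `scheduler.step()` is called once per epoch, as in `train_model`.

```python
# Values taken from the optimizer and scheduler configured above
base_lr = 0.001
step_size = 7
gamma = 0.1
num_epochs = 25


def step_lr(epoch):
    """Learning rate in effect during the given (0-indexed) epoch."""
    return base_lr * gamma ** (epoch // step_size)


schedule = [step_lr(epoch) for epoch in range(num_epochs)]

# Print the rate only at epochs where it changes
for epoch, lr in enumerate(schedule):
    if epoch % step_size == 0:
        print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```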


In [None]:
if load_path is not None:
  model_conv.load_state_dict(checkpoint['model_state_dict'])
  model_conv.eval()
  model_conv = model_conv.to(device)  # works on CPU-only machines too, unlike .cuda()
else:
  model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

# Tools for Visualizing Results/Interpretation:

In [None]:
# print(list(filter(lambda method: not '__' in method, dir(model_conv.state_dict()))))

visualize_model(model_conv)

In [None]:
def load_img_data(img_path):
  path, filename = os.path.split(img_path)
  label = os.path.split(path)[-1]

  image = Image.open(img_path)

  return label, image

In [None]:
def get_preds(phase='val', with_input=True, with_targs=True, with_preds=True, with_loss=True, with_data_idx=False):
  """ returns inputs, targs, preds, losses """
  res = []
  input_imgs, targets, predictions, losses, data_idxs = [], [], [], [], []

  with torch.no_grad():
    for i, (inputs, labels) in enumerate(dataloaders[phase]):
      inputs = inputs.to(device)
      if with_input: input_imgs.extend([tensor2numpy_img(im) for im in inputs.cpu()])

      labels = labels.to(device)
      if with_targs: targets.append(labels.cpu().numpy())

      batch_size = len(labels)

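      # running position in iteration order (adding the batch counter i alone would repeat indices across batches)
      if with_data_idx: data_idxs.extend(range(len(data_idxs), len(data_idxs) + batch_size))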

      outputs = model_conv(inputs)
      if with_loss:
        loss = [round(criterion(torch.unsqueeze(outputs[idx], dim=0), torch.unsqueeze(labels[idx], dim=0)).item(), 2) for idx in range(len(labels))]
        losses.extend(loss)

      _, preds = torch.max(outputs, 1)
      if with_preds: predictions.append(preds.cpu().numpy())

  if with_input: res.append(np.array(input_imgs))
  if with_targs: res.append(np.concatenate(targets))
  if with_preds: res.append(np.concatenate(predictions))
  if with_loss: res.append(np.array(losses))
  if with_data_idx: res.append(np.array(data_idxs))

  if len(res) > 1: return tuple(res)
  else: return res[0]


def predict(image, with_class=False, with_output=False):
  res = []

  transformed = data_transforms['val'](image).float()
  transformed = transformed.unsqueeze_(0)

  input_img = transformed.to(device)
  output = model_conv(input_img)

  idx = output.data.cpu().numpy().argmax()
  res.append(idx)
  if with_class: res.append(class_names[idx])
  if with_output: res.append(output)
  
  if len(res) > 1: return tuple(res)
  else: return res[0]


def get_loss(image, label):
  pred, output = predict(image, with_output=True)

  loss = round(criterion(output, torch.tensor([label], device=device)).item(), 2)

  return loss

In [None]:
# inspired by fastai (https://docs.fast.ai/interpret.html)
class Interpretation():
  def __init__(self, model):
    print('initialized')

    self.model = model
    self.class_names = class_names
    self.class_ct = len(self.class_names)

    self.cm = None

  
  def get_class_idx(self, class_name):
    return self.class_names.index(class_name)


  def get_class_data(self, class_idx, phase='val', with_data_idx=False):
    if with_data_idx: y_data, y_targs, y_preds, y_losses, y_idxs = get_preds(phase=phase, with_input=True, with_targs=True, with_preds=True, with_loss=True, with_data_idx=True)
    else: y_data, y_targs, y_preds, y_losses = get_preds(phase=phase, with_input=True, with_targs=True, with_preds=True, with_loss=True)

    class_targs_idxs = np.argwhere(y_targs == class_idx)

    res = [y_data[class_targs_idxs].squeeze(), y_targs[class_targs_idxs].squeeze(), y_preds[class_targs_idxs].squeeze(), y_losses[class_targs_idxs].squeeze()]
    if with_data_idx: res.append(y_idxs[class_targs_idxs].squeeze())

    return tuple(res)
  

  def get_top_losses(self, ct=None, phase='val', class_idx=None, with_data_idx=False, loss_thresh=None):
    if class_idx is None:
      info = get_preds(phase=phase, with_input=True, with_targs=True, with_preds=True, with_loss=True, with_data_idx=with_data_idx)

      if with_data_idx: y_data, y_targs, y_preds, y_losses, y_idxs = info
      else: y_data, y_targs, y_preds, y_losses = info
    else:
      info = self.get_class_data(class_idx, phase=phase, with_data_idx=with_data_idx)

      if with_data_idx: y_data, y_targs, y_preds, y_losses, y_idxs = info
      else: y_data, y_targs, y_preds, y_losses = info
    
    dists = abs(y_preds - y_targs)

    errors = (y_preds - y_targs != 0)

    err_ct = len(np.argwhere(errors==True))

    error_idxs = np.argwhere(errors==True)

    sorted_dists = np.argsort(dists, axis=0)

    print('number of errors:', len(np.argwhere(errors==True)), 'correct:', len(np.argwhere(errors==False)))

    sorted_losses = np.argsort(y_losses, axis=0)

    if ct is None or err_ct < ct: ct = err_ct

    if loss_thresh is None: top_errors = sorted_losses[len(sorted_losses) - ct:]  # an empty selection when ct == 0
    else: top_errors = np.argwhere(y_losses > loss_thresh)

    res = [y_data[top_errors].squeeze(), y_targs[top_errors].squeeze(), y_preds[top_errors].squeeze(), y_losses[top_errors].squeeze()]
    if with_data_idx: res.append(y_idxs[top_errors].squeeze())

    return tuple(res)


  def plot_top_losses(self, ct=None, phase='val', class_idx=None, loss_thresh=None):
    y_data, y_targs, y_preds, y_losses = self.get_top_losses(ct, phase=phase, class_idx=class_idx, loss_thresh=loss_thresh)

    img_targs = [self.class_names[targ] for targ in y_targs]
    img_preds = [self.class_names[pred] for pred in y_preds]

    img_titles = [f'Label: {targ}\nPrediction: {pred}\n(loss: {loss})' for targ, pred, loss in zip(img_targs, img_preds, y_losses)]

    show_img(list(y_data), title=img_titles, suptitle=(None if class_idx is None else class_names[class_idx]), targ=img_targs, pred=img_preds)


  def confusion_matrix(self, save=False):
    cm = torch.zeros(self.class_ct, self.class_ct)

    with torch.no_grad():
      for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = self.model(inputs)
        _, preds = torch.max(outputs, 1)

        for t, p in zip(classes.view(-1), preds.view(-1)):
          cm[t.long(), p.long()] += 1

    # option to store the confusion matrix so we don't have to recompute it every time we plot
    if save: self.cm = cm

    return cm


  def plot_confusion_matrix(self, plot_txt=True, cmap='Blues'):
    if self.cm is None: cm = self.confusion_matrix()
    else: cm = self.cm

    fig = plt.figure(figsize=(16, 10), dpi=100) # (12, 8)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title('Confusion Matrix')
    tick_marks = np.arange(self.class_ct)
    plt.xticks(tick_marks, self.class_names, rotation=90)
    plt.yticks(tick_marks, self.class_names, rotation=0)

    if plot_txt:
      thresh = cm.max() / 2.
      for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        coeff = f'{int(cm[i, j])}'
        plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center", color="white"
                  if cm[i, j] > thresh else "black")

    # axis setup runs once, and also when plot_txt is False
    ax = fig.gca()
    ax.set_ylim(self.class_ct-.5,-.5)

    # plt.tight_layout()
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.grid(False)


  def per_class_accuracy(self, sort=True):
    cm = self.confusion_matrix()

    class_accuracy = cm.diag()/cm.sum(1)

    if sort:
      class_accuracy = class_accuracy.cpu().numpy()
      acc_idxs = np.argsort(class_accuracy)[::-1]

      acc_sorted = class_accuracy[acc_idxs]
      classes_sorted = np.array(self.class_names)[acc_idxs]

      for class_name, acc in zip(classes_sorted, acc_sorted):
        print(f'class: {class_name} / accuracy: {round(acc.item(), 4)}')

      return acc_sorted
    else:
      for i, acc in enumerate(class_accuracy):
        print(f'class: {self.class_names[i]} / accuracy: {round(acc.item(), 4)}')

      return class_accuracy

  
  def least_accurate(self, ct=1):
    class_accuracy = self.per_class_accuracy(sort=False).cpu().numpy()

    acc_idxs = np.argsort(class_accuracy)

    acc_sorted = class_accuracy[acc_idxs]
    classes_sorted = np.array(self.class_names)[acc_idxs]

    for class_name, acc in zip(classes_sorted, acc_sorted):
      print(f'class: {class_name} / accuracy: {round(acc.item(), 4)}')

    return classes_sorted[:ct], acc_sorted[:ct]

In [None]:
torch.cuda.empty_cache()

interp = Interpretation(model_conv)


In [None]:
# get the confusion matrix

cm = interp.confusion_matrix(save=True)

In [None]:
# plot confusion matrix

interp.plot_confusion_matrix()

In [None]:
# per-class accuracy, sorted from highest to lowest score

per_class_acc = interp.per_class_accuracy(sort=True)

# mean accuracy over all classes except the two least accurate (assumes more than two classes)
new_score = per_class_acc[:-2].sum() / (len(class_names)-2)


new_score
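
Per-class accuracy is the diagonal of the confusion matrix divided by each row sum (the number of validation images whose true label is that class). A plain-Python sketch with an invented 3-class matrix and hypothetical class labels:

```python
# Invented confusion matrix: rows are true classes, columns are predicted classes
toy_classes = ["a", "b", "c"]  # hypothetical labels
toy_cm = [
    [8, 1, 1],
    [2, 6, 2],
    [0, 5, 5],
]

# Correct predictions sit on the diagonal; each row sums to that class's image count
toy_per_class_acc = {
    name: toy_cm[i][i] / sum(toy_cm[i])
    for i, name in enumerate(toy_classes)
}

# Overall accuracy weights every image equally, so it can hide a weak class
toy_overall_acc = sum(toy_cm[i][i] for i in range(len(toy_cm))) / sum(map(sum, toy_cm))

print(toy_per_class_acc)
print(f"overall: {toy_overall_acc:.3f}")
```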

In [None]:
# visualize predictions with highest losses

interp.plot_top_losses(ct=10)
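
The loss shown for each image is the cross-entropy of that single prediction. As a plain-Python sketch with invented logits, it is the negative log of the softmax probability assigned to the true class, so a confident wrong prediction is penalized far more than an unsure one:

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # Subtract the max logit before exponentiating for numerical stability
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]


# Invented two-class logits; the true class is index 0 in every case
confident_right = cross_entropy([4.0, -1.0], target=0)
unsure = cross_entropy([0.1, 0.0], target=0)
confident_wrong = cross_entropy([-1.0, 4.0], target=0)

print(round(confident_right, 2), round(unsure, 2), round(confident_wrong, 2))
```

This is why the highest-loss images are worth inspecting: they are the cases where the model was most sure of the wrong answer, which often points to mislabeled or unusual data.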