# Transfer Learning And Data Augmentation

In this tutorial, we will learn how to apply fine-tuning and feature extraction with pre-trained models, and how to augment data when the training dataset is small.

This tutorial is adapted from [this](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) and [this](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html).

When the training dataset for a task is small, transfer learning usually performs better than training from scratch. There are two main approaches to transfer learning:

__Fine-tuning__ vs. __Feature extraction__

- Similarity: both initialize the model for the new task with the parameters of a pre-trained model

- Difference: fine-tuning updates the whole model, whereas feature extraction only updates the final predictive layer(s)


For more details about transfer learning, see [here](https://cs231n.github.io/transfer-learning/).

In [None]:
import os
from pathlib import Path

# GPU selection via environment variables. This cell is placed before the
# cell that imports torch.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # so the IDs match nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"] = "3"       # eg. "0, 1, 2" for multiple

# Experiment configuration shared by the cells below.
DATA_ROOT = '/data1/cifar/'  # root directory passed to datasets.CIFAR10
DEVICE = 'cuda:0'            # used when CUDA is available, otherwise "cpu" is used
BATCH_SIZE = 16              # batch size of the training DataLoader
VAL_BATCH_SIZE = 128         # batch size of the validation DataLoader
TRAINSET_SIZE = 250          # number of training images kept by random_split
NUM_EPOCHS = 15              # epochs for every call to train()
MODEL_NAME = "resnet"     # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]

In [None]:
# Human-readable labels for the ten classes, indexed by class id.
# (Alternative when names are not needed: tuple(range(10)).)
class_names = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Derived from the tuple so the two can never drift apart.
num_classes = len(class_names)

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import time
import os
import copy

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms

# Switch matplotlib to interactive mode.
plt.ion()
# Use the configured device when CUDA is available; otherwise fall back to the CPU.
device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")

## Initialize A Model

- All of the torchvision models have been pretrained on the 1000-class Imagenet dataset.
- Since each model architecture is different, we must make custom adjustments for each model.
> Notice that inception_v3 requires the input size to be (299,299), whereas all of the other models expect (224,224).

In [None]:
# Build a ResNet-18 with default arguments; as the cell's last expression its
# layer structure is displayed (note the final `fc` layer replaced later on).
models.resnet18()

In [None]:
# Build an AlexNet with default arguments; as the cell's last expression its
# layer structure is displayed (note `classifier[6]` replaced later on).
models.alexnet()

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when doing feature extraction.

    Args:
        model: Module whose parameters may be frozen.
        feature_extracting: If truthy, ``requires_grad`` is set to ``False``
            on all parameters; otherwise the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract=True, use_pretrained=True):
    """Build a torchvision model whose classification head has ``num_classes`` outputs.

    The backbone is created first, optionally frozen, and only then is the
    head replaced, so the new head always keeps ``requires_grad=True``.

    Args:
        model_name: One of "resnet", "alexnet", "vgg", "squeezenet",
            "densenet" or "inception".
        num_classes: Number of outputs of the new classification head.
        feature_extract: If True, freeze all backbone parameters so only the
            new head is trainable; if False, the whole model stays trainable.
        use_pretrained: Forwarded to torchvision as ``pretrained``.

    Returns:
        tuple: ``(model, input_size)`` where ``input_size`` is the side length
        of the square input the model expects (299 for inception, else 224).

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    if model_name == "resnet":
        # ResNet-18: the head is the single `fc` layer.
        model = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet: the head is the last layer of `classifier`.
        model = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11 with batch norm: like AlexNet, the output layer is
        # `classifier[6]`. This branch was previously an unfinished
        # placeholder that returned input_size 0 and a 1000-class head.
        model = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0: the head is a 1x1 convolution, not a linear layer.
        model = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121: the head is the single `classifier` layer.
        model = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects (299, 299) images and has an auxiliary
        # output, so both heads must be replaced.
        model = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model.AuxLogits.fc.in_features
        model.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        # Raise instead of calling exit(), which would terminate the whole
        # interpreter/kernel on a simple typo.
        raise ValueError(
            "Invalid model name {!r}; choose from "
            "[resnet, alexnet, vgg, squeezenet, densenet, inception]".format(model_name)
        )

    return model, input_size

In [None]:

# Initialize a fine-tuning model (feature_extract=False keeps every
# parameter trainable); input_size is reused by the transforms below.
model_ft, input_size = initialize_model(MODEL_NAME, num_classes, feature_extract=False, use_pretrained=True)

# Print the model we just instantiated
print(model_ft)

## Load Data

In Pytorch,

>All pre-trained models expect input images normalized in the same way... The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. 

In [None]:
# Per-channel normalization statistics the pre-trained models expect.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _build_eval_transform():
    """Return the deterministic resize / center-crop / normalize pipeline."""
    return transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        # transforms.Lambda(lambda t: t.expand(3, -1, -1))  # for grayscale dataset, eg. MNIST
    ])


# Both splits start with the same pipeline, each with its own Compose object.
data_transforms = {split: _build_eval_transform() for split in ('train', 'val')}

# Create training and validation datasets
image_datasets = {
    'train': datasets.CIFAR10(DATA_ROOT, train=True, transform=data_transforms['train']),
    'val': datasets.CIFAR10(DATA_ROOT, train=False, transform=data_transforms['val']),
}

# Keep a random subset of TRAINSET_SIZE training images; the rest is discarded.
num_unused = len(image_datasets['train']) - TRAINSET_SIZE
train_dataset, _ = torch.utils.data.random_split(
    image_datasets['train'], [TRAINSET_SIZE, num_unused])

train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=4)
val_loader = torch.utils.data.DataLoader(
    image_datasets['val'], shuffle=True, batch_size=VAL_BATCH_SIZE, num_workers=4)
dataloaders_dict = {'train': train_loader, 'val': val_loader}

# Number of samples per split, used to average the epoch statistics.
dataset_sizes = {
    'train': TRAINSET_SIZE,
    'val': len(image_datasets['val']),
}


In [None]:
def imshow(inp, title=None):
    """Show a normalized image tensor with matplotlib.

    Args:
        inp: Tensor whose first axis is moved last before plotting
            (presumably a channel-first image or image grid).
        title: Optional title for the plot; skipped when ``None``.
    """
    # Move the first axis to the end so the array suits plt.imshow.
    image = inp.numpy().transpose((1, 2, 0))
    # Undo the normalization, then clamp to the displayable [0, 1] range.
    image = np.clip(std * image + mean, 0, 1)
    plt.imshow(image)
    plt.axis('off')
    if title is None:
        return
    plt.title(title)

# Get a batch of training data
inputs, classes = next(iter(dataloaders_dict['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# Show the grid, titled with the class name of every image in the batch.
plt.figure(figsize=(15, 5))
imshow(out, title=[class_names[x] for x in classes])

## Train A Model

In [None]:
def train(model, dataloaders, criterion, optimizer, scheduler=None, num_epochs=25, is_inception=False):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Every epoch runs a 'train' phase followed by a 'val' phase. Per-phase loss
    and accuracy are averaged over the number of samples actually yielded by
    the loader and printed.

    Args:
        model: Module to train. Batches are moved to the device of its first
            parameter, so move the model to the target device beforehand.
        dataloaders: Mapping with 'train' and 'val' iterables that yield
            ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        scheduler: Optional LR scheduler, stepped once per epoch after the
            training phase.
        num_epochs: Number of epochs to run.
        is_inception: If True, the model is expected to return
            ``(outputs, aux_outputs)`` in training mode and the auxiliary
            loss is added with weight 0.4.

    Returns:
        tuple: ``(model, val_acc_history)`` - the model loaded with the best
        validation weights, and the list of per-epoch validation accuracies.
    """
    since = time.time()

    # Follow the model's own device instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else None

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # training mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0
            num_samples = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                if run_device is not None:
                    inputs = inputs.to(run_device)
                    labels = labels.to(run_device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    if is_inception and is_train:
                        # Inception has an auxiliary output in training mode; the
                        # loss sums the final and the (down-weighted) auxiliary loss.
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics, weighted by the batch size
                batch_size = inputs.size(0)
                running_loss += float(loss) * batch_size
                running_corrects += int((preds == labels).sum())
                num_samples += batch_size

            # Step the scheduler only after this epoch's optimizer updates, so
            # the initial learning rate is actually used for the first epoch.
            if is_train and scheduler is not None:
                scheduler.step()

            # Guard against an empty loader to avoid a division by zero.
            denominator = max(num_samples, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / denominator

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                # deep copy the model whenever validation accuracy improves
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

## Update Necessary Parameters

In [None]:
def get_params_to_update(model, feature_extract, print_params=True):
    """Gather the parameters to be optimized/updated in this run.

    When fine-tuning, all parameters of the model are returned. With the
    feature-extraction method, only the parameters whose ``requires_grad`` is
    True (i.e. the freshly initialized ones) are returned.

    Args:
        model: Module whose parameters are collected.
        feature_extract: If True, return only parameters that require
            gradients; if False, return every parameter.
        print_params: If True, print the names of all parameters that
            require gradients.

    Returns:
        list: The parameters to hand to the optimizer. A list is returned in
        both modes so the result can be iterated more than once.
    """
    trainable = [(name, param)
                 for name, param in model.named_parameters()
                 if param.requires_grad]

    if print_params:
        print("Params to learn:")
        for name, _ in trainable:
            print("\t", name)

    if feature_extract:
        return [param for _, param in trainable]
    return list(model.parameters())

In [None]:
# Move the fine-tuning model to the target device before collecting its
# parameters for the optimizer.
model_ft = model_ft.to(device)
params_to_update_ft = get_params_to_update(model_ft, feature_extract=False)

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update_ft, lr=0.001, momentum=0.9)
# optimizer_ft = optim.Adam(params_to_update_ft, lr=0.001)
# Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler_ft = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Loss shared by all training runs below.
criterion = nn.CrossEntropyLoss()

In [None]:
# Fine-tune the whole model; hist_ft holds the per-epoch validation accuracy.
# No scheduler is passed (the scheduler argument is commented out).
model_ft, hist_ft = train(model_ft, dataloaders_dict, criterion, optimizer_ft, # exp_lr_scheduler_ft,
                          num_epochs=NUM_EPOCHS, is_inception=(MODEL_NAME=="inception"))

In [None]:
# Initialize a feature extraction model: the pre-trained parameters are
# frozen and only the newly created head stays trainable.

model_conv, _ = initialize_model(MODEL_NAME, num_classes, feature_extract=True, use_pretrained=True)
model_conv = model_conv.to(device)

params_to_update_conv = get_params_to_update(model_conv, feature_extract=True)

# Observe that only the parameters that still require gradients (the new
# head) are being optimized, as opposed to the fine-tuning run above.
optimizer_conv = optim.SGD(params_to_update_conv, lr=0.001, momentum=0.9)
# optimizer_conv = optim.Adam(params_to_update_conv)

# Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler_conv = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

In [None]:
# Train only the new head; hist_conv holds the per-epoch validation accuracy.
# No scheduler is passed (the scheduler argument is commented out).
model_conv, hist_conv = train(model_conv, dataloaders_dict, criterion, optimizer_conv, # exp_lr_scheduler_conv,
                       num_epochs=NUM_EPOCHS, is_inception=(MODEL_NAME=="inception"))

## Comparison With Model Trained From Scratch

In [None]:
# Initialize the non-pretrained version of the model
# (use_pretrained=False; nothing is frozen).
model_scratch, _ = initialize_model(MODEL_NAME, num_classes, feature_extract=False, use_pretrained=False)
model_scratch = model_scratch.to(device)
# Same optimizer settings as the two transfer-learning runs.
optimizer_scratch = optim.SGD(model_scratch.parameters(), lr=0.001, momentum=0.9)
# optimizer_scratch = optim.Adam(model_scratch.parameters())

# hist_scratch holds the per-epoch validation accuracy of this baseline.
model_scratch, hist_scratch = train(model_scratch, dataloaders_dict, criterion, optimizer_scratch, 
                                    num_epochs=NUM_EPOCHS, is_inception=(MODEL_NAME=="inception"))



In [None]:
# Compare the validation accuracy per epoch of the two transfer learning
# methods against the model trained from scratch.
epochs = range(1, NUM_EPOCHS+1)
curves = (
    (hist_ft, "Fine-tuning"),
    (hist_conv, "Feature Extraction"),
    (hist_scratch, "Scratch"),
)

plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
for history, curve_label in curves:
    plt.plot(epochs, history, label=curve_label)
plt.ylim((0,1.))
plt.xticks(np.arange(1, NUM_EPOCHS+1, 1.0))
plt.legend()
plt.show()

## Data Augmentation

We can augment the training data by using the random transforms provided by torchvision.

In [None]:
# Display the transform currently attached to the dataset underlying the
# training subset (before it is replaced below).
train_dataset.dataset.transform

In [None]:
# Augmentation pipeline for training: random transforms first, then the same
# tensor conversion and normalization as the deterministic pipeline.
da_transform = transforms.Compose([
        # Small random rotation (degrees argument 3; see the torchvision docs for the sampled range).
        transforms.RandomRotation(3),
        # Random crop covering 90-100% of the image area with aspect ratio in
        # [4/5, 5/4], resized to input_size.
        transforms.RandomResizedCrop(input_size, scale=(0.9, 1), ratio=(4/5, 5/4)),
        # Random left-right flip with the default probability.
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

In [None]:
# Swap the transform on the dataset that train_dataset is a subset of, so
# every training sample is now augmented; then display it to confirm.
train_dataset.dataset.transform = da_transform
train_dataset.dataset.transform

In [None]:
# Recreate the training loader over the (now augmented) training subset.
dataloaders_dict['train'] = torch.utils.data.DataLoader(train_dataset, shuffle=True, 
                                                        batch_size=BATCH_SIZE, num_workers=4)


# Preview one augmented batch as an image grid titled with the class names.
inputs, classes = next(iter(dataloaders_dict['train']))
out = torchvision.utils.make_grid(inputs)

plt.figure(figsize=(15, 5))
imshow(out, title=[class_names[x] for x in classes])

In [None]:
# Repeat the three experiments with the augmented training loader. Each model
# is re-initialized so the runs do not start from previously trained weights.

# 1) Fine-tuning: all parameters are trainable.
model_ft, input_size = initialize_model(MODEL_NAME, num_classes, feature_extract=False, use_pretrained=True)
model_ft = model_ft.to(device)
params_to_update_ft = get_params_to_update(model_ft, feature_extract=False, print_params=False)
optimizer_ft = optim.SGD(params_to_update_ft, lr=0.001, momentum=0.9)
model_ft, hist_ft = train(model_ft, dataloaders_dict, criterion, optimizer_ft,
                          num_epochs=NUM_EPOCHS, is_inception=(MODEL_NAME=="inception"))

# 2) Feature extraction: only the parameters that require gradients are optimized.
model_conv, _ = initialize_model(MODEL_NAME, num_classes, feature_extract=True, use_pretrained=True)
model_conv = model_conv.to(device)
params_to_update_conv = get_params_to_update(model_conv, feature_extract=True, print_params=False)
optimizer_conv = optim.SGD(params_to_update_conv, lr=0.001, momentum=0.9)
model_conv, hist_conv = train(model_conv, dataloaders_dict, criterion, optimizer_conv,
                       num_epochs=NUM_EPOCHS, is_inception=(MODEL_NAME=="inception"))

# 3) Baseline trained from scratch (use_pretrained=False).
model_scratch, _ = initialize_model(MODEL_NAME, num_classes, feature_extract=False, use_pretrained=False)
model_scratch = model_scratch.to(device)
optimizer_scratch = optim.SGD(model_scratch.parameters(), lr=0.001, momentum=0.9)
model_scratch, hist_scratch = train(model_scratch, dataloaders_dict, criterion, optimizer_scratch, 
                                    num_epochs=NUM_EPOCHS, is_inception=(MODEL_NAME=="inception"))

# Plot the per-epoch validation accuracy of the three runs on one figure.
plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(range(1, NUM_EPOCHS+1), hist_ft, label="Fine-tuning")
plt.plot(range(1, NUM_EPOCHS+1), hist_conv, label="Feature Extraction")
plt.plot(range(1, NUM_EPOCHS+1), hist_scratch, label="Scratch")
plt.ylim((0,1.))
plt.xticks(np.arange(1, NUM_EPOCHS+1, 1.0))
plt.legend()
plt.show()

## Visualize The Predictive Results

In [None]:
def visualize(model, num_images=16, column=4, figsize=(8, 10)):
    """Plot validation images with the class predicted by ``model``.

    Batches are drawn from ``dataloaders_dict['val']`` until ``num_images``
    images have been shown. The model is switched to eval mode for the
    predictions and its previous training/eval mode is always restored.

    Args:
        model: Module used for prediction; expected to live on ``device``.
        num_images: Number of images to plot. Nothing is plotted when it is
            not positive.
        column: Number of subplot columns (converted with ``int``).
        figsize: Size passed to ``plt.figure``.
    """
    column = int(column)
    if num_images <= 0:
        return
    # Round up so a partially filled last row still gets subplot slots;
    # floor division would make plt.subplot fail for e.g. 10 images / 4 columns.
    rows = -(-num_images // column)

    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure(figsize=figsize)

    try:
        with torch.no_grad():
            for inputs, _labels in dataloaders_dict['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(rows, column, images_so_far)
                    ax.axis('off')
                    ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                    imshow(inputs.cpu().data[j])

                    if images_so_far == num_images:
                        return
    finally:
        # Restore the caller's mode on every exit path, including errors.
        model.train(mode=was_training)

In [None]:
# Predictions of the fine-tuned model on validation images.
visualize(model_ft)

In [None]:
# Predictions of the feature-extraction model on validation images.
visualize(model_conv)

In [None]:
# Predictions of the model trained from scratch on validation images.
visualize(model_scratch)