# <font style="color:blue">Transfer Learning and Fine-tuning </font>
In this chapter, we will learn how to fine-tune a pre-trained model for a different task than it was originally trained for.

When we train a network from scratch, we encounter the following two limitations:

- Huge data required - Since the network has millions of parameters, to get an optimal set of parameters, we need to have a lot of data.
- Huge computing power required - Even if we have a lot of data, training generally requires multiple iterations and it takes a toll on the computing resources.

The pre-trained models are trained on very large-scale image classification problems. The convolutional layers act as feature extractors and the fully connected layers act as classifiers.

Since these models are very large and have seen a huge number of images, they tend to learn very good, discriminative features. We can either use the convolutional layers merely as a feature extractor and change the last layer according to our problem or we can tweak the already trained convolutional layers to suit our problem at hand. The former approach is known as **Transfer Learning** and the latter as **Fine-tuning**.

The task of fine-tuning a network is to tweak the parameters of an already trained network so that it adapts to the new task at hand. The initial layers of a network learn very general features, and as we go higher up the network, the layers tend to learn patterns more specific to the task the network was trained on. Thus, for fine-tuning, we want to keep the initial layers intact (or freeze them) and retrain the later layers for our task.

Thus, fine-tuning avoids both the limitations discussed above:

- Less data is required, for two reasons. First, we are not training the entire network. Second, the part that is being trained is not trained from scratch.
- Less computing power is required. Since fewer parameters need to be updated, the training time is also shorter.

As a rule of thumb, when we have a small training set and our problem is similar to the task for which the pre-trained models were trained, we can use transfer learning. If we have enough data, we can try and tweak the convolutional layers so that they learn more robust features relevant to our problem. You can get a detailed overview of Fine-tuning and transfer learning [here](http://cs231n.github.io/transfer-learning/).
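
The difference between the two approaches comes down to which parameters are allowed to change. The following is a minimal sketch, assuming a small toy network (`toy_model`, a hypothetical stand-in, not a real pre-trained model) so that it runs without downloading any weights. It freezes every existing parameter, attaches a fresh 3-class head, and counts what is left to train.

```python
import torch.nn as nn

# Toy stand-in for a pre-trained network (hypothetical; a real model would be loaded with weights)
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),  # plays the role of the feature extractor
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),                # original classifier for 10 classes
)


def count_trainable(model):
    """Number of scalar parameters that will receive gradient updates."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


trainable_before = count_trainable(toy_model)

# Transfer learning: freeze everything, then replace the classifier for 3 classes
for param in toy_model.parameters():
    param.requires_grad = False
toy_model[4] = nn.Linear(8, 3)  # newly created layers are trainable by default

trainable_after = count_trainable(toy_model)
print(trainable_before, trainable_after)
```

Skipping the freezing loop and only replacing the last layer would correspond to retraining all layers from the pre-trained starting point.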

In [1]:
%matplotlib inline

In [2]:
import matplotlib.pyplot as plt  # one of the best graphics libraries for Python
plt.style.use('ggplot')

In [3]:
import os
import time

from typing import Iterable
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import datasets, transforms, models

from torch.optim import lr_scheduler

from torch.utils.tensorboard import SummaryWriter

# <font style="color:blue">Launch Tensorboard </font>

In [4]:
# %load_ext tensorboard
%reload_ext tensorboard

%tensorboard --logdir=log_resnet18/transfer_learning

Reusing TensorBoard on port 6006 (pid 10259), started 1:00:53 ago. (Use '!kill 10259' to kill it.)

# <font style="color:green">Data Processing Utils</font>

In [5]:
def image_preprocess_transforms():
    
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor()
        ])
    
    return preprocess

In [6]:
def image_common_transforms(mean, std):
    preprocess = image_preprocess_transforms()
    
    common_transforms = transforms.Compose([
        preprocess,
        transforms.Normalize(mean, std)
    ])
    
    return common_transforms
    

In [7]:
def data_augmentation_preprocess(mean, std):
    
    # apply exactly one of these geometric augmentations, chosen at random per image
    initial_transform = transforms.RandomChoice([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(90)
        ])
    
    common_transforms = image_common_transforms(mean, std)
                
    aug_transforms = transforms.Compose([
        initial_transform,
        transforms.RandomGrayscale(p=0.1),
        common_transforms
        ])
    
    return aug_transforms
    

In [8]:
def data_loader(data_root, transform, batch_size=16, shuffle=False, num_workers=2):
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    
    loader = torch.utils.data.DataLoader(dataset, 
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         shuffle=shuffle)
    
    return loader

In [9]:
def get_mean_std():
    
    mean = [0.485, 0.456, 0.406] 
    std = [0.229, 0.224, 0.225]
    
    return mean, std
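
These values are the per-channel ImageNet statistics commonly used with torchvision's pre-trained models. As a rough illustration of what `transforms.Normalize` does with them, and of the inverse operation needed before plotting an image, here is a plain-Python sketch on a single RGB pixel. The helper names `normalize` and `denormalize` are hypothetical and only serve this example.

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize(pixel):
    """Per-channel (value - mean) / std, as applied to tensors scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


def denormalize(pixel):
    """Inverse operation: value * std + mean, needed before displaying an image."""
    return [v * s + m for v, m, s in zip(pixel, mean, std)]


pixel = [0.5, 0.5, 0.5]          # a mid-gray RGB pixel
normalized = normalize(pixel)
restored = denormalize(normalized)
print(normalized, restored)
```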

In [10]:
def get_data(batch_size, data_root, tb_writer, num_workers=4, data_augmentation=True):
    
    train_data_path = os.path.join(data_root, 'training')
       
    mean, std = get_mean_std()
    
    common_transforms = image_common_transforms(mean, std)
        
   
    # if data_augmentation is true 
    # data augmentation implementation
    if data_augmentation:    
        train_transforms = data_augmentation_preprocess(mean, std)
    # else do common transforms
    else:
        train_transforms = common_transforms
        
        
    # train dataloader
    
    train_loader = data_loader(train_data_path, 
                               train_transforms, 
                               batch_size=batch_size, 
                               shuffle=True, 
                               num_workers=num_workers)
    
    # test dataloader (no augmentation: evaluation should use deterministic transforms)
    
    test_data_path = os.path.join(data_root, 'validation')
    
    test_loader = data_loader(test_data_path, 
                              common_transforms, 
                              batch_size=batch_size, 
                              shuffle=False, 
                              num_workers=num_workers)
    
    # test dataset (used only for the embedding projector)
    
    testdata = datasets.ImageFolder(root=test_data_path, transform=common_transforms)
    
    # add embedding / projector
    
    add_data_embedings(testdata, tb_writer, n=100)
    
    return train_loader, test_loader

# <font style="color:blue">Add data embeddings / Projector</font>

In [11]:
# class names indexed by label id; ImageFolder assigns ids in alphabetical folder order
fashion_mnist_classes = ['cat', 'dog', 'panda']
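
The order of this list matters because `datasets.ImageFolder` derives integer labels from the sorted names of the class subfolders. The following plain-Python sketch mimics that mapping (it is an imitation of the behavior, not torchvision's own code) for three hypothetical folder names listed in arbitrary order.

```python
# Subfolder names as they might be listed on disk (arbitrary order)
folder_names = ["panda", "cat", "dog"]

# Labels follow the sorted order of the folder names
classes = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(classes)
print(class_to_idx)
```

A list of class names used for plotting should therefore follow the same alphabetical order.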


In [12]:
def get_random_inputs_labels(inputs, targets, n=100):
    """
    get random inputs and labels
    """

    assert len(inputs) == len(targets)

    rand_indices = torch.randperm(len(targets))
    
    data = inputs[rand_indices][:n]
    
    labels = targets[rand_indices][:n]
    
    class_labels = [fashion_mnist_classes[lab] for lab in labels]
    
    return data, class_labels

In [13]:
def add_data_embedings(dataset, tb_writer, n=100):
    """
    Add a few inputs and labels to tensorboard. 
    """
    
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=n, num_workers=4, shuffle=True)
    
    images, labels = next(iter(dataloader))
    
    tb_writer.add_embedding(mat = images.view(-1, 3 * 224 * 224), 
                            metadata=labels, 
                            label_img=images)
    
    return

## <font style="color:green">System Configuration</font>

In [14]:
@dataclass
class SystemConfiguration:
    '''
    Describes the common system setting needed for reproducible training
    '''
    seed: int = 21  # seed number to set the state of all random number generators
    cudnn_benchmark_enabled: bool = True  # enable CuDNN benchmark for the sake of performance
    cudnn_deterministic: bool = True  # make cudnn deterministic (reproducible training)

## <font style="color:green">Training Configuration</font>

In [15]:
@dataclass
class TrainingConfiguration:
    '''
    Describes configuration of the training process
    '''
    batch_size: int = 32  
    epochs_count: int = 50 
    init_learning_rate: float = 0.001  # initial learning rate for lr scheduler
    decay_rate: float = 0.1  
    log_interval: int = 500  
    test_interval: int = 1  
    data_root: str = "./cat-dog-panda" 
    num_workers: int = 10  
    device: str = 'cuda'  
    


## <font style="color:green">System Setup</font>

In [16]:
def setup_system(system_config: SystemConfiguration) -> None:
    torch.manual_seed(system_config.seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = system_config.cudnn_benchmark_enabled
        torch.backends.cudnn.deterministic = system_config.cudnn_deterministic

In [17]:
def prediction(model, device, batch_input, max_prob=True):
    """
    get prediction for batch inputs
    """
    
    # send model to cpu/cuda according to your system configuration
    model.to(device)
    
    # it is important to do model.eval() before prediction
    model.eval()

    data = batch_input.to(device)

    # no gradients are needed for inference
    with torch.no_grad():
        output = model(data)

    # Score to probability using softmax
    prob = F.softmax(output, dim=1)
    
    if max_prob:
        # get the max probability
        pred_prob = prob.data.max(dim=1)[0]
    else:
        pred_prob = prob.data
    
    # get the index of the max probability
    pred_index = prob.data.max(dim=1)[1]
    
    return pred_index.cpu().numpy(), pred_prob.cpu().numpy()
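
To make the score-to-probability step concrete, here is a small standard-library sketch of softmax on one made-up row of class scores. It illustrates that the probabilities sum to one and that the predicted index is the same whether it is taken from the raw scores or from the probabilities.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of raw class scores."""
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, -1.0, 0.5]                 # made-up scores for three classes
probs = softmax(logits)
pred_index = probs.index(max(probs))      # index of the most probable class
print(probs, pred_index)
```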

In [18]:
def get_target_and_prob(model, dataloader, device):
    """
    get targets and prediction probabilities
    """
    
    pred_prob = []
    targets = []
    
    for _, (data, target) in enumerate(dataloader):
        
        _, prob = prediction(model, device, data, max_prob=False)
        
        pred_prob.append(prob)
        
        target = target.numpy()
        targets.append(target)
        
    targets = np.concatenate(targets)
    targets = targets.astype(int)
    pred_prob = np.concatenate(pred_prob, axis=0)
    
    return targets, pred_prob
    
    

# <font style="color:blue">Add PR Curves to Tensorboard</font>

In [19]:
def add_pr_curves_to_tensorboard(model, dataloader, device, tb_writer, epoch, num_classes=3):
    """
    Add precision-recall curves to TensorBoard.
    """
    
    targets, pred_prob = get_target_and_prob(model, dataloader, device)
    
    for cls_idx in range(num_classes):
        binary_target = targets == cls_idx
        true_prediction_prob = pred_prob[:, cls_idx]
        
        tb_writer.add_pr_curve(fashion_mnist_classes[cls_idx], 
                               binary_target, 
                               true_prediction_prob, 
                               global_step=epoch)
        
    return
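
The function above treats each class as its own binary problem (one-vs-rest). As a sketch of what a single point on such a curve means, the following plain-Python example computes precision and recall for one class at one threshold. The targets, probabilities, and the helper `precision_recall` are made up for illustration.

```python
# Made-up ground-truth labels and predicted class probabilities for five samples
targets = [0, 2, 1, 2, 0]
pred_prob = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
    [0.5, 0.1, 0.4],
    [0.4, 0.4, 0.2],
]


def precision_recall(cls_idx, threshold):
    """Precision and recall for one class treated as a binary problem."""
    binary_target = [t == cls_idx for t in targets]
    predicted = [row[cls_idx] >= threshold for row in pred_prob]
    tp = sum(p and t for p, t in zip(predicted, binary_target))
    fp = sum(p and not t for p, t in zip(predicted, binary_target))
    fn = sum(t and not p for p, t in zip(predicted, binary_target))
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


strict = precision_recall(cls_idx=2, threshold=0.5)
lenient = precision_recall(cls_idx=2, threshold=0.3)
print(strict, lenient)
```

Lowering the threshold trades precision for recall, which is the trade-off the curve traces out.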
    

# <font style="color:blue">Push Wrong Prediction to Tensorboard</font>

In [20]:
def add_wrong_prediction_to_tensorboard(model, dataloader, device, tb_writer, 
                                        epoch, tag='Wrong_Predictions', max_images='all'):
    """
    Add wrong predicted images to tensorboard.
    """
    #number of images in one row
    num_images_per_row = 8
    im_scale = 3
    
    plot_images = []
    wrong_labels = []
    pred_prob = []
    right_label = []
    
    mean, std = get_mean_std()
    
    for _, (data, target) in enumerate(dataloader):
        
        
        images = data.numpy()
        pred, prob = prediction(model, device, data)
        target = target.numpy()
        indices = pred.astype(int) != target.astype(int)
        
        plot_images.append(images[indices])
        wrong_labels.append(pred[indices])
        pred_prob.append(prob[indices])
        right_label.append(target[indices])
        
    # keep the batch dimension even when only one image is misclassified
    plot_images = np.concatenate(plot_images, axis=0)
    # undo the normalization: (N, C, H, W) -> (N, H, W, C), then x * std + mean
    plot_images = (np.moveaxis(plot_images, 1, -1) * std) + mean
    print('plot_images.shape: {}'.format(plot_images.shape))
    print(plot_images.min())
    print(plot_images.max())
    wrong_labels = np.concatenate(wrong_labels)
    wrong_labels = wrong_labels.astype(int)
    right_label = np.concatenate(right_label)
    right_label = right_label.astype(int)
    pred_prob = np.concatenate(pred_prob)
    
    
    # count the collected wrong predictions, not the size of the last batch
    if max_images == 'all':
        num_images = len(plot_images)
    else:
        num_images = min(len(plot_images), max_images)
        
    fig_width = num_images_per_row * im_scale
    
    # number of rows needed (ceiling division); plt.subplot requires an integer
    num_row = max(1, int(np.ceil(num_images / num_images_per_row)))
        
    fig_height = num_row * im_scale
        
    plt.style.use('default')
    plt.rcParams["figure.figsize"] = (fig_width, fig_height)
    fig = plt.figure()
    
    for i in range(num_images):
        plt.subplot(num_row, num_images_per_row, i+1, xticks=[], yticks=[])
        plt.imshow(plot_images[i])
        plt.gca().set_title('{0}({1:.2}), {2}'.format(fashion_mnist_classes[wrong_labels[i]], 
                                                          pred_prob[i], 
                                                          fashion_mnist_classes[right_label[i]]))
        
    tb_writer.add_figure(tag, fig, global_step=epoch)
    
    return


## <font style="color:green">Training Function</font>

We are familiar with the training pipeline used in PyTorch.

In [21]:
def train(
    train_config: TrainingConfiguration, model: nn.Module, optimizer: torch.optim.Optimizer,
    train_loader: torch.utils.data.DataLoader, epoch_idx: int, tb_writer: SummaryWriter
) -> tuple:
    
    # switch the model to training mode
    model.train()
    
    # to get batch loss
    batch_loss = np.array([])
    
    # to get batch accuracy
    batch_acc = np.array([])
        
    for batch_idx, (data, target) in enumerate(train_loader):
        
        # clone target
        indx_target = target.clone()
        # send data to device (it is mandatory if the GPU is to be used)
        data = data.to(train_config.device)
        # send target to device
        target = target.to(train_config.device)

        # reset parameters gradient to zero
        optimizer.zero_grad()
        
        # forward pass to the model
        output = model(data)
        
        # cross entropy loss
        loss = F.cross_entropy(output, target)
        
        # find gradients w.r.t training parameters
        loss.backward()
        # Update parameters using gradients
        optimizer.step()
        
        batch_loss = np.append(batch_loss, [loss.item()])
        
        # Score to probability using softmax
        prob = F.softmax(output, dim=1)
            
        # get the index of the max probability
        pred = prob.data.max(dim=1)[1]  
                        
        # correct prediction
        correct = pred.cpu().eq(indx_target).sum()
            
        # accuracy
        acc = float(correct) / float(len(data))
        
        batch_acc = np.append(batch_acc, [acc])

        if batch_idx % train_config.log_interval == 0 and batch_idx > 0:
            
            # global step: number of batches processed so far (an integer step for TensorBoard)
            total_batch = epoch_idx * len(train_loader) + batch_idx
            tb_writer.add_scalar('Loss/train-batch', loss.item(), total_batch)
            tb_writer.add_scalar('Accuracy/train-batch', acc, total_batch)
            
    epoch_loss = batch_loss.mean()
    epoch_acc = batch_acc.mean()
    return epoch_loss, epoch_acc
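
Whether any per-batch scalars are written depends on how `log_interval` compares with the number of batches in an epoch. The sketch below uses only the standard library and a hypothetical training set of 2,400 images (the real dataset size is not stated here). It lists the batch indices that would pass the logging condition and shows one common way to number batches across epochs. The helper names are illustrative.

```python
import math


def logged_batches(num_images, batch_size, log_interval):
    """Batch indices within one epoch that satisfy the logging condition."""
    batches_per_epoch = math.ceil(num_images / batch_size)
    return [b for b in range(batches_per_epoch)
            if b % log_interval == 0 and b > 0]


def global_step(epoch_idx, batches_per_epoch, batch_idx):
    """Running batch counter across epochs."""
    return epoch_idx * batches_per_epoch + batch_idx


# Hypothetical dataset of 2,400 images with a batch size of 32 (75 batches per epoch)
with_interval_500 = logged_batches(2400, 32, 500)
with_interval_25 = logged_batches(2400, 32, 25)
step = global_step(epoch_idx=2, batches_per_epoch=75, batch_idx=50)
print(with_interval_500, with_interval_25, step)
```

Under these assumptions an interval of 500 never triggers, so a smaller `log_interval` may be needed to see per-batch curves on a small dataset.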

## <font style="color:green">Validation Function</font>

In [22]:
def validate(
    train_config: TrainingConfiguration,
    model: nn.Module,
    test_loader: torch.utils.data.DataLoader
) -> tuple:
    # switch to evaluation mode (affects layers such as batch norm and dropout)
    model.eval()
    test_loss = 0
    count_correct_predictions = 0
    
    # gradients are not needed for validation
    with torch.no_grad():
        for data, target in test_loader:
            indx_target = target.clone()
            data = data.to(train_config.device)
            
            target = target.to(train_config.device)
            
            output = model(data)
            # add loss for each mini batch
            test_loss += F.cross_entropy(output, target).item()
            
            # Score to probability using softmax
            prob = F.softmax(output, dim=1)
            
            # get the index of the max probability
            pred = prob.data.max(dim=1)[1] 
            
            # add correct prediction count
            count_correct_predictions += pred.cpu().eq(indx_target).sum().item()

    # average over number of mini-batches
    test_loss = test_loss / len(test_loader)  
    
    # fraction of correctly classified samples, in [0, 1]
    accuracy = count_correct_predictions / len(test_loader.dataset)
    
    return test_loss, accuracy
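
Note that the loss is averaged over mini-batches while the accuracy is averaged over samples. When the last batch is smaller than the others, these two kinds of average can differ. The plain-Python sketch below uses made-up numbers to show the effect.

```python
# Made-up validation run: two full batches and one small final batch
batch_sizes = [32, 32, 8]
batch_mean_losses = [0.50, 0.70, 1.50]
batch_correct = [28, 24, 2]

# Average of per-batch means (every batch counts equally)
loss_per_batch_avg = sum(batch_mean_losses) / len(batch_mean_losses)

# Average over samples (every image counts equally)
loss_per_sample_avg = (
    sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
    / sum(batch_sizes)
)

accuracy = sum(batch_correct) / sum(batch_sizes)
print(loss_per_batch_avg, loss_per_sample_avg, accuracy)
```

The gap is usually small, but it is worth keeping in mind when comparing losses across different batch sizes.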

# <font style="color:blue">Add histogram of weights</font>

In [23]:
def add_model_weights_as_histogram(model, tb_writer, epoch):
    for name, param in model.named_parameters():
        tb_writer.add_histogram(name.replace('.', '/'), param.data.cpu().abs(), epoch)
    return

# <font style="color:blue">Add Network Graph</font>

In [24]:
def add_network_graph_tensorboard(model, inputs, tb_writer):
    tb_writer.add_graph(model, inputs)
    return

## <font style="color:green">Main Function for Training and Validation</font>

In this section of code, we use the configuration parameters defined above and start the training. Here are the important actions being taken in the code below:

1. Set up system parameters like CPU/GPU, number of workers, etc.
1. Load the data using dataloaders.
1. Set up variables to track loss and accuracy and start training.
1. For each epoch, call the train function, and for every test interval, call the validation function.
1. Call `scheduler.step()` to update the learning rate for the next epoch.
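
The control flow of these steps can be sketched without any deep learning code. In the following standard-library example, `run_epochs` and the stub callables are hypothetical stand-ins for the real train, validate, and scheduler calls. It only illustrates when validation and the scheduler step happen.

```python
def run_epochs(epochs_count, test_interval, train_one_epoch, validate_once,
               scheduler_step=None):
    """Skeleton of the loop: train every epoch, validate every `test_interval` epochs."""
    best_loss = float("inf")
    history = []
    for epoch in range(epochs_count):
        train_loss = train_one_epoch(epoch)
        if epoch % test_interval == 0:
            val_loss = validate_once(epoch)
            best_loss = min(best_loss, val_loss)
            history.append((epoch, train_loss, val_loss))
        if scheduler_step is not None:
            scheduler_step()  # update the learning rate for the next epoch
    return best_loss, history


# Stub callables that just return made-up, decreasing losses
scheduler_calls = []
best, history = run_epochs(
    epochs_count=4,
    test_interval=2,
    train_one_epoch=lambda epoch: 1.0 / (epoch + 1),
    validate_once=lambda epoch: 1.2 / (epoch + 1),
    scheduler_step=lambda: scheduler_calls.append(1),
)
print(best, history, len(scheduler_calls))
```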



In [25]:
def main(model, optimizer, tb_writer, scheduler=None, system_configuration=SystemConfiguration(), 
         training_configuration=TrainingConfiguration(), data_augmentation=False):
    
    # system configuration
    setup_system(system_configuration)

    # batch size
    batch_size_to_set = training_configuration.batch_size
    # num_workers
    num_workers_to_set = training_configuration.num_workers
    # epochs
    epoch_num_to_set = training_configuration.epochs_count

    # if GPU is available use training config, 
    # else lower batch_size and num_workers
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size_to_set = 16
        num_workers_to_set = 2

    # data loader
    train_loader, test_loader = get_data(
        batch_size=batch_size_to_set,
        data_root=training_configuration.data_root,
        tb_writer=tb_writer,
        num_workers=num_workers_to_set,
        data_augmentation=data_augmentation
    )
    
    
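    # Update training configuration with the values actually used,
    # keeping every other setting passed in by the caller
    training_configuration = TrainingConfiguration(
        batch_size=batch_size_to_set,
        epochs_count=epoch_num_to_set,
        init_learning_rate=training_configuration.init_learning_rate,
        decay_rate=training_configuration.decay_rate,
        log_interval=training_configuration.log_interval,
        test_interval=training_configuration.test_interval,
        data_root=training_configuration.data_root,
        num_workers=num_workers_to_set,
        device=device
    )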
        
    # send model to device (GPU/CPU)
    model.to(training_configuration.device)
    
    
    # add network graph with inputs info
    images, labels = next(iter(test_loader))
    images = images.to(training_configuration.device)
    add_network_graph_tensorboard(model, images, tb_writer)

    best_loss = torch.tensor(np.inf)
    
    # epoch train/test loss
    epoch_train_loss = np.array([])
    epoch_test_loss = np.array([])
    
    # epoch train/test accuracy
    epoch_train_acc = np.array([])
    epoch_test_acc = np.array([])
    
    add_wrong_prediction_to_tensorboard(model, test_loader, 
                                                training_configuration.device, 
                                                tb_writer, 0, max_images=300)
    
    
    # training time measurement
    t_begin = time.time()
    for epoch in range(training_configuration.epochs_count):
        
        # Train
        train_loss, train_acc = train(training_configuration, model, optimizer, train_loader, epoch, tb_writer)
        
        epoch_train_loss = np.append(epoch_train_loss, [train_loss])
        
        epoch_train_acc = np.append(epoch_train_acc, [train_acc])
        
        # add scalar (loss/accuracy) to tensorboard
        tb_writer.add_scalar('Loss/Train',train_loss, epoch)
        tb_writer.add_scalar('Accuracy/Train', train_acc, epoch)

        elapsed_time = time.time() - t_begin
        speed_epoch = elapsed_time / (epoch + 1)
        speed_batch = speed_epoch / len(train_loader)
        eta = speed_epoch * training_configuration.epochs_count - elapsed_time
        
        # add time metadata to tensorboard
        tb_writer.add_scalar('Time/elapsed_time', elapsed_time, epoch)
        tb_writer.add_scalar('Time/speed_epoch', speed_epoch, epoch)
        tb_writer.add_scalar('Time/speed_batch', speed_batch, epoch)
        tb_writer.add_scalar('Time/eta', eta, epoch)
        

        # Validate
        if epoch % training_configuration.test_interval == 0:
            current_loss, current_accuracy = validate(training_configuration, model, test_loader)
            
            epoch_test_loss = np.append(epoch_test_loss, [current_loss])
        
            epoch_test_acc = np.append(epoch_test_acc, [current_accuracy])
            
            # add scalar (loss/accuracy) to tensorboard
            tb_writer.add_scalar('Loss/Validation', current_loss, epoch)
            tb_writer.add_scalar('Accuracy/Validation', current_accuracy, epoch)
            
            # add scalars (loss/accuracy) to tensorboard
            tb_writer.add_scalars('Loss/train-val', {'train': train_loss, 
                                           'validation': current_loss}, epoch)
            tb_writer.add_scalars('Accuracy/train-val', {'train': train_acc, 
                                               'validation': current_accuracy}, epoch)
            
            if current_loss < best_loss:
                best_loss = current_loss
                
            # add wrong predicted image to tensorboard
            add_wrong_prediction_to_tensorboard(model, test_loader, 
                                                training_configuration.device, 
                                                tb_writer, epoch, max_images=300)
        
        # scheduler step/ update learning rate
        if scheduler is not None:
            scheduler.step()
            
        # adding model weights to tensorboard as histogram
        add_model_weights_as_histogram(model, tb_writer, epoch)
        
        # add pr curves to tensor board
        add_pr_curves_to_tensorboard(model, test_loader, 
                                     training_configuration.device, 
                                     tb_writer, epoch, num_classes=3)
        
                
    print("Total time: {:.2f}, Best Loss: {:.3f}".format(time.time() - t_begin, best_loss))
    
    
    
    return model, epoch_train_loss, epoch_train_acc, epoch_test_loss, epoch_test_acc

## <font style="color:green">Optimizer and Scheduler</font>

Let's wrap the optimizer and scheduler creation in a function, because we need them in every training experiment.

In [26]:
def get_optimizer_and_scheduler(model):
    train_config = TrainingConfiguration()

    init_learning_rate = train_config.init_learning_rate

    # optimizer
    optimizer = optim.SGD(
        model.parameters(),
        lr = init_learning_rate,
        momentum = 0.9
    )

    decay_rate = train_config.decay_rate

    lmbda = lambda epoch: 1/(1 + decay_rate * epoch)

    # Scheduler
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lmbda)
    
    return optimizer, scheduler
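
`LambdaLR` multiplies the initial learning rate by the value of `lr_lambda(epoch)`. As a quick check of what this schedule produces, the following plain-Python sketch evaluates the same formula for a few epochs, using the default values from `TrainingConfiguration`. The helper `lr_at` is introduced only for this illustration.

```python
init_learning_rate = 0.001
decay_rate = 0.1


def lr_at(epoch):
    """Learning rate after `epoch` scheduler steps: initial lr times the lambda factor."""
    return init_learning_rate * (1 / (1 + decay_rate * epoch))


schedule = {epoch: lr_at(epoch) for epoch in (0, 10, 49)}
for epoch, lr in schedule.items():
    print(epoch, round(lr, 6))
```

With these values the learning rate halves after 10 epochs and shrinks by a factor of about 5.9 by the last of the 50 epochs.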
    


# <font style="color:blue">ResNet Model</font>
We will load the resnet18 model with its pretrained weights.

If `transfer_learning=True`, all pretrained layers are frozen and only the newly added final fully connected layer is trained. Otherwise, all layers are trained, but starting from the pretrained weights rather than from scratch.

In [1]:
def pretrained_resnet18(transfer_learning=True, num_class=3):
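    # load ImageNet pre-trained weights
    # (newer torchvision releases deprecate `pretrained=True` in favor of the `weights=` argument)
    resnet = models.resnet18(pretrained=True)
    
    if transfer_learning:
        # freeze all pre-trained parameters
        for param in resnet.parameters():
            param.requires_grad = False
            
    # replace the classifier; the new layer is trainable by default
    last_layer_in = resnet.fc.in_features
    resnet.fc = nn.Linear(last_layer_in, num_class)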
    
    return resnet
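
Even though the optimizer is later given `model.parameters()` (frozen ones included), frozen parameters receive no gradient and are therefore left untouched by the update. The sketch below demonstrates this on a tiny toy model. `features` and `head` are hypothetical stand-ins for the frozen ResNet layers and the new `fc` layer, and the data is random.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

features = nn.Linear(4, 4)  # stands in for the frozen pre-trained layers
head = nn.Linear(4, 3)      # stands in for the new `fc` layer
for param in features.parameters():
    param.requires_grad = False

model = nn.Sequential(features, nn.ReLU(), head)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Snapshots of the parameters before one training step
frozen_before = features.weight.detach().clone()
head_before = [p.detach().clone() for p in head.parameters()]

# One training step on random data
inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))
optimizer.zero_grad()
loss = F.cross_entropy(model(inputs), labels)
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(frozen_before, features.weight)
head_changed = any(not torch.equal(before, after)
                   for before, after in zip(head_before, head.parameters()))
print(frozen_unchanged, head_changed)
```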

# <font style="color:blue">Transfer Learning</font>


In [None]:
model = pretrained_resnet18(transfer_learning=True)
print(model)
# get optimizer and scheduler
optimizer, scheduler = get_optimizer_and_scheduler(model)

# Tensorboard summary writer
transfer_learning_sw = SummaryWriter('log_resnet18/transfer_learning')   

# train and validate
model, train_loss_exp2, train_acc_exp2, val_loss_exp2, val_acc_exp2 = main(model, 
                                                                           optimizer,
                                                                           transfer_learning_sw,
                                                                           scheduler,
                                                                           data_augmentation=True)
transfer_learning_sw.close()

# <font style="color:blue">Fine-Tuning</font>


In [None]:
model = pretrained_resnet18(transfer_learning=False)
print(model)
# get optimizer and scheduler
optimizer, scheduler = get_optimizer_and_scheduler(model)

# Tensorboard summary writer
fine_tuning_sw = SummaryWriter('log_resnet18/fine_tuning')   

model, train_loss_exp9, train_acc_exp9, val_loss_exp9, val_acc_exp9 = main(model, 
                                                                           optimizer, 
                                                                           fine_tuning_sw,
                                                                           scheduler,
                                                                           data_augmentation=True)

fine_tuning_sw.close()