In [11]:
%matplotlib inline

# Fine-Tuning and Feature Extraction with PyTorch Models




In this tutorial, we will take a deeper look at how to fine-tune and perform feature extraction using `torchvision models <https://pytorch.org/docs/stable/torchvision/models.html>`, all of which have been pretrained on the 1000-class ImageNet dataset. This tutorial will provide an in-depth understanding of how to work with several modern Convolutional Neural Network (CNN) architectures and develop intuition for fine-tuning any PyTorch model. Since each model architecture differs, there is no universal fine-tuning code that applies to all cases. Instead, researchers must inspect the architecture and make custom adjustments for each model.

In this document, we will explore two types of transfer learning: **fine-tuning** and **feature extraction**.

- **Fine-tuning**: We start with a pretrained model and update *all* of the model’s parameters for the new task, essentially retraining the entire model.
- **Feature extraction**: We begin with a pretrained model and only update the final layer's weights to make predictions for the new task. It is called feature extraction because we use the pretrained CNN as a fixed feature extractor, modifying only the output layer.

For more technical details about transfer learning, you can refer to these resources: [Transfer Learning in CS231n](http://cs231n.github.io/transfer-learning/).

#### General Steps for Transfer Learning

Both transfer learning methods follow these common steps:

1. Initialize the pretrained model.
2. Reshape the final layer(s) to match the number of output classes in the new dataset.
3. Define which parameters will be updated during the optimization process.
4. Run the training phase.

In [12]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

PyTorch Version:  1.12.1+cu113
Torchvision Version:  0.13.1+cu113


Inputs
------

Here are all of the parameters to change for the run. We will use the
*hymenoptera_data* dataset which can be downloaded
`here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`__.
This dataset contains two classes, **bees** and **ants**, and is
structured such that we can use the
`ImageFolder <https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder>`__
dataset, rather than writing our own custom dataset. Download the data
and set the ``data_dir`` input to the root directory of the dataset. The
``model_name`` input is the name of the model you wish to use and must
be selected from this list:

::

   [resnet, alexnet, vgg, squeezenet, densenet, inception]

The other inputs are as follows: ``num_classes`` is the number of
classes in the dataset, ``batch_size`` is the batch size used for
training and may be adjusted according to the capability of your
machine, ``num_epochs`` is the number of training epochs we want to run,
and ``feature_extract`` is a boolean that defines if we are finetuning
or feature extracting. If ``feature_extract = False``, the model is
finetuned and all model parameters are updated. If
``feature_extract = True``, only the last layer parameters are updated,
the others remain fixed.

In [13]:
# Hyper-parameters

model_name = "resnet50"

# Number of classes in the dataset

# Batch size for training (change depending on how much memory you have)
batch_size = 8

# Number of epochs to train for
num_epochs = 15

device = "cuda"

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = True

## Dataloader

In [20]:
import xml.etree.ElementTree as ET
import torch
from torch.utils.data import Dataset, DataLoader
import glob
import pandas as pd
from PIL import Image 
from torchvision.transforms import InterpolationMode
from torchvision import datasets, transforms

# Custom dataset class to handle the XML annotations
class RetailProductDataset(Dataset):
    
    def _convert_image_to_rgb(self, image):
        return image.convert("RGB")
    
    def get_transform(self):
        if self.partition == "train":
            return transforms.Compose([
                self._convert_image_to_rgb,
                transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
                transforms.RandomCrop(224),
                # Randomly flip the image horizontally
                transforms.RandomHorizontalFlip(p=0.5),
                # Randomly flip the image horizontally
                # transforms.ColorJitter(
                #     brightness=0.4, contrast=0.2, saturation=0.2),
                # Randomly flip the image horizontallys
                transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.2),
                transforms.ToTensor(),
                # # Normalize with mean and std of imagenet
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                    (0.26862954, 0.26130258, 0.27577711)),
            ])
        else:
            return transforms.Compose([
                self._convert_image_to_rgb,
                transforms.Resize((224,224), interpolation=InterpolationMode.BICUBIC),
                # transforms.CenterCrop(224),
                transforms.ToTensor(),
                # Normalize with mean and std of imagenet
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                    (0.26862954, 0.26130258, 0.27577711)),
            ])

    def denormalize(self, image):
        mean, std = (0.48145466, 0.4578275,
                     0.40821073), (0.26862954, 0.26130258, 0.27577711)
        mean = torch.tensor(mean).view(1, 3, 1, 1)
        std = torch.tensor(std).view(1, 3, 1, 1)

        return image * std + mean
    
    def __init__(self, root_folder, partition, bbox = False):
        self.data = []
        self.root_folder = root_folder
        self.partition = partition
        self.bbox = bbox
        self.annotation_folder  = f"{self.root_folder}/annotations/{self.partition}/"
        self.image_folder  = f"{self.root_folder}/images/{self.partition}/"
        self.dataframe = self.generate_dataframe()
        self.list_cls = list(self.dataframe["cls"].unique())
        self.cls_dict = {name:idx for idx, name in enumerate(sorted(self.list_cls))}
        
    def generate_dataframe(self):
        # print(self.image_folder + "*.jpg")
        image_paths = [x for x in glob.glob(self.image_folder + "*.jpg")]
        annotation_paths = [x for x in glob.glob(self.annotation_folder + "*.xml")]
        image_paths = sorted(image_paths, key=lambda x: int(''.join(filter(str.isdigit, x))))
        annotation_paths = sorted(annotation_paths, key=lambda x: int(''.join(filter(str.isdigit, x))))
        assert np.all(i == a for i, a in zip(image_paths,annotation_paths))
        # print(image_paths)
        # print(annotation_paths)

        df = pd.DataFrame(columns = ["image_path", "width", "height", "cls", "bbox"])
        for i,a in zip(image_paths,annotation_paths):
            annotations = self.parse_xml(a)
            # print(i,a)
            df = df._append({"image_path": i, "width":annotations["width"], 
                            "height":annotations["height"], "cls":annotations["cls"], "bbox":annotations["bbox"]},ignore_index = True)
        return df
        
    def parse_xml(self, xml_file):
        # Parse XML content
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Extract the relevant fields from the XML
        # filename = root.find('filename').text
        # path = root.find('path').text
        size = root.find('size')
        width = int(size.find('width').text)
        height = int(size.find('height').text)
        channels = int(size.find('depth').text)
        
        ## Object Detection info
        obj = root.find('object')
        cls = obj.find('name').text
        bndbox = obj.find('bndbox')
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)

        # Store the parsed data
        return {
            'width': width,
            'height': height,
            'channels': channels,
            'cls': cls,
            'bbox': [xmin, ymin, xmax, ymax]
        }

    def __datasetclasses__(self):
        return len(self.list_cls)
    
    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        item = self.dataframe.loc[idx]
        image_path = item["image_path"]  # read path
        bbox_items = item["bbox"]
        image = Image.open(image_path)
       
        if self.bbox:
            ## crop bbox
            print("-----------------------------")
            print(image_path)
            print(os.path.basename(image_path))
            print(image.size)
            print((bbox_items[0],bbox_items[1],bbox_items[2],bbox_items[3]))
            print("-----------------------------")
            image  = image.crop((bbox_items[0],bbox_items[1],bbox_items[2],bbox_items[3]))

        # mask_image = Image.open(mask_name)
        try:
            transform = self.get_transform()
            image = transform(image)
        except Exception as e:
            # print(e)
            print("\n\n NAME", image_path)
        cls = self.cls_dict[item["cls"]]
        return {"image_path" : image_path, "image": image, "cls": cls}


In [75]:
# Create a dataset instance and wrap it in a DataLoader
root_folder = 'C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/datasets/retail_products/'  # path to your xml file
out_folder = 'C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/CNN_finetuning/'  # path to your xml file

train_dataset = RetailProductDataset(root_folder=root_folder,partition="train",bbox=True)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
print(len(train_dataloader))


test_dataset = RetailProductDataset(root_folder=root_folder,partition="test",bbox=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
print(len(test_dataloader))

# Display first item in dataloader
for idx, data in enumerate(test_dataloader):
    first_item = data
    image_tensor = data["image"]
    image_tensor = test_dataset.denormalize(image_tensor)
    image_tensor = image_tensor[0]
    to_pil = transforms.ToPILImage()
    image_pil = to_pil(image_tensor)
    # print(data["image_path"])
    path_base = os.path.basename(data["image_path"][0])
    print(path_base)
    image_pil.save(f'{out_folder}/test/{path_base}')
    # if idx > 10:
    #     break
    
dataloaders_dict = {"train":train_dataloader, "test":test_dataloader}

294
86
-----------------------------
C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/datasets/retail_products//images/test\aqua (1).jpg
aqua (1).jpg
(512, 512)
(191, 142, 257, 332)
-----------------------------
aqua (1).jpg
-----------------------------
C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/datasets/retail_products//images/test\chitato (1).jpg
chitato (1).jpg
(512, 512)
(86, 171, 345, 376)
-----------------------------
chitato (1).jpg
-----------------------------
C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/datasets/retail_products//images/test\indomie (1).jpg
indomie (1).jpg
(512, 512)
(91, 187, 312, 354)
-----------------------------
indomie (1).jpg
-----------------------------
C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/datasets/retail_products//images/test\mix (1).jpg
mi

In [34]:
image_path = "C:/Users/Chiqu/Documents/GitHub/Deep-Learning-and-Computer-Vision-for-Business/02-Pytorch and CV/datasets/retail_products//images/test\pepsodent (1).jpg"

# pepsodent (1).jpg
# (512, 512)
# (95, 252, 444, 359)
image = Image.open(image_path)
image.show()  # This will display the cropped image

cropped_image  = image.crop((60, 200, 444, 300))
cropped_image.show()

### Explanation of the `train_model` function

This function is designed to train a deep learning model using PyTorch. It handles the training and validation phases for multiple epochs and tracks the best-performing model based on validation accuracy. Below is an explanation of each part of the function:

1. **Function Arguments:**
   - `model`: The model to be trained.
   - `dataloaders`: A dictionary containing 'train' and 'val' dataloaders, which provide batches of data for training and validation.
   - `criterion`: The loss function to be optimized.
   - `optimizer`: The optimization algorithm used to update model weights (e.g., SGD, Adam).
   - `num_epochs`: The number of epochs (iterations over the entire dataset) to train the model (default: 25).
   - `is_inception`: A flag indicating whether the model is Inception, as it has special handling due to auxiliary outputs during training.

2. **Initial Setup:**
   - `since`: Records the starting time to measure how long training takes.
   - `val_acc_history`: A list to store the validation accuracy at the end of each epoch.
   - `best_model_wts`: Deep copy of the model's initial weights, used to save the best-performing model.
   - `best_acc`: Keeps track of the highest validation accuracy achieved.

3. **Training Loop:**
   - The outer loop runs over a number of epochs.
   - For each epoch, it prints the current epoch and separates the training and validation phases.
   - For each phase ('train' or 'val'), it either sets the model to training mode (`model.train()`) or evaluation mode (`model.eval()`).

4. **Batch Loop:**
   - The inner loop iterates over batches of input data (`inputs`) and their corresponding labels (`labels`).
   - The inputs and labels are moved to the device (e.g., GPU or CPU).
   - The gradients of the model are reset using `optimizer.zero_grad()`.

5. **Forward Pass:**
   - In the forward pass, the model computes predictions on the inputs.
   - For Inception models, the loss is calculated using both the primary output and the auxiliary output during training.
   - The loss is computed using the provided loss function (`criterion`).

6. **Backward Pass & Optimization:**
   - For the 'train' phase, backpropagation is performed (`loss.backward()`) to compute gradients, and the optimizer updates the model's parameters (`optimizer.step()`).

7. **Tracking Performance:**
   - After processing each batch, the loss and number of correct predictions are accumulated to calculate the loss and accuracy for the entire epoch.
   - The epoch's loss and accuracy are printed for both the training and validation phases.

8. **Model Checkpointing:**
   - If the validation accuracy for the current epoch exceeds the best accuracy observed so far, the model's weights are saved as the best model.

9. **Completion:**
   - After training completes, the total training time is printed.
   - The function returns the model with the best validation accuracy and a history of validation accuracies.


In [16]:
def train_model(model, dataloaders, criterion, optimizer, device, num_epochs=25):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for item in dataloaders[phase]:
                # print(item)
                inputs = item["image"]
                labels = item["cls"]

                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'test':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

### Initializing and Reshaping the ResNet50 Network

ResNet, introduced in the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), is a popular architecture for deep learning tasks. The network consists of several variants, including ResNet18, ResNet34, ResNet50, ResNet101, and ResNet152, all available in `torchvision.models`. In this example, we focus on **ResNet50**.

Since the ResNet models are pretrained on the ImageNet dataset (which has 1000 classes), the final fully connected layer (`fc`) has 1000 output features. When working with a new dataset, we need to reshape this layer to match the number of classes in the new task. 

For ResNet50, the final fully connected layer looks like this:

```python
(fc): Linear(in_features=2048, out_features=1000, bias=True)
```

To adapt ResNet50 for our new task, we need to reinitialize `model.fc` to be a `Linear` layer with 2048 input features and the desired number of output classes, `num_classes`. Here is the code to do that:

```python
model.fc = nn.Linear(2048, num_classes)
```

In [17]:
def initialize_model(num_classes, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = models.resnet50(pretrained=use_pretrained)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    input_size = 224
    return model_ft, input_size

# Initialize the model for this run
model_ft, input_size = initialize_model(train_dataset.__datasetclasses__(), use_pretrained=True)

# Print the model we just instantiated
print(model_ft)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

Create the Optimizer
--------------------

Now that the model structure is correct, the final step for finetuning
and feature extracting is to create an optimizer that only updates the
desired parameters. Recall that after loading the pretrained model, but
before reshaping, if ``feature_extract=True`` we manually set all of the
parameter’s ``.requires_grad`` attributes to False. Then the
reinitialized layer’s parameters have ``.requires_grad=True`` by
default. So now we know that *all parameters that have
.requires_grad=True should be optimized.* Next, we make a list of such
parameters and input this list to the SGD algorithm constructor.

To verify this, check out the printed parameters to learn. When
finetuning, this list should be long and include all of the model
parameters. However, when feature extracting this list should be short
and only include the weights and biases of the reshaped layers.




In [18]:
# Send the model to GPU
model_ft = model_ft.to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.
params_to_update = model_ft.parameters()
print("Params to learn:")

for name,param in model_ft.named_parameters():
    if "fc" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
    if param.requires_grad == True:
        print("\t",name)


# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(params_to_update, lr=0.001)

Params to learn:
	 fc.weight
	 fc.bias


Run Training and Validation Step
--------------------------------

Finally, the last step is to setup the loss for the model, then run the
training and validation function for the set number of epochs. Notice,
depending on the number of epochs this step may take a while on a CPU.
Also, the default learning rate is not optimal for all of the models, so
to achieve maximum accuracy it would be necessary to tune for each model
separately.




In [19]:
# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, device=device, num_epochs=200)

Epoch 0/199
----------
train Loss: 2.0547 Acc: 0.1463
test Loss: 3.5464 Acc: 0.2442

Epoch 1/199
----------
train Loss: 2.0358 Acc: 0.1565
test Loss: 3.3896 Acc: 0.2791

Epoch 2/199
----------
train Loss: 2.0173 Acc: 0.1633
test Loss: 3.3235 Acc: 0.3605

Epoch 3/199
----------
train Loss: 2.0042 Acc: 0.1599
test Loss: 3.3610 Acc: 0.3488

Epoch 4/199
----------
train Loss: 1.9902 Acc: 0.1735
test Loss: 3.1791 Acc: 0.3605

Epoch 5/199
----------
train Loss: 1.9737 Acc: 0.1803
test Loss: 3.2267 Acc: 0.3605

Epoch 6/199
----------
train Loss: 1.9620 Acc: 0.1803
test Loss: 3.1210 Acc: 0.3488

Epoch 7/199
----------
train Loss: 1.9444 Acc: 0.1871
test Loss: 2.9424 Acc: 0.3488

Epoch 8/199
----------
train Loss: 1.9307 Acc: 0.1871
test Loss: 2.9781 Acc: 0.3488

Epoch 9/199
----------
train Loss: 1.9127 Acc: 0.1871
test Loss: 2.8452 Acc: 0.3605

Epoch 10/199
----------
train Loss: 1.9024 Acc: 0.1871
test Loss: 2.9598 Acc: 0.3488

Epoch 11/199
----------
train Loss: 1.8880 Acc: 0.1905
test Loss