Using pre-trained object detection models from pytorch  
Fine-tuned on the icub dataset

Reference: https://www.kaggle.com/code/yerramvarun/fine-tuning-faster-rcnn-using-pytorch

In [1]:
from pathlib import Path

images_dir = Path('dataset/icub/Images')
annotations_dir = Path('dataset/icub/Annotations_refined')
manual_subset_dir = Path('dataset/icub/Images_subset_manual')
manual_annotations_dir = Path('dataset/icub/Annotations_manual')

In [2]:
import os
os.listdir(annotations_dir)

['wallet',
 'sunglasses',
 'pencilcase',
 'remote',
 'squeezer',
 'sodabottle',
 'glass',
 'book',
 'mouse',
 'mug',
 'soapdispenser',
 'ringbinder',
 'README.md',
 'hairclip',
 'hairbrush',
 'sprayer',
 'perfume',
 'flower',
 'bodylotion',
 'ovenglove',
 'cellphone']

In [3]:
import torch
import os
from lxml import etree

from torchvision.io import read_image
from torchvision.transforms import v2

class ICubDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, annotations_dir, manual_subset_dir, manual_annotations_dir, transforms):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.manual_subset_dir = manual_subset_dir
        self.manual_annotations_dir = manual_annotations_dir
        self.transforms = transforms

        self.train_indices = []
        self.test_indices = []

        # get all categories = walk under the annotations directory
        self.categories = sorted([x.name for x in Path.iterdir(annotations_dir) if x.is_dir()])
        self.sub_categories = []
        for cat in self.categories:
            self.sub_categories.append(sorted([x.name for x in Path.iterdir(Path(annotations_dir, cat)) if x.is_dir()]))
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs_path = []
        self.annotations_path = []

        for cat, sub_cats in zip(self.categories, self.sub_categories):
            for sub_cat in sub_cats:
                days_dir = [x for x in Path.iterdir(Path(images_dir, cat, sub_cat, 'MIX')) if x.is_dir()]
                for day_dir in days_dir:
                    # get all images' paths and find all files end with .jpg
                    images = sorted(Path.iterdir(Path(day_dir, 'left')))
                    images = [x for x in images if x.name.endswith('.jpg')]

                    # get all annotations (xml file paths)
                    annotations = sorted(Path.iterdir(Path(annotations_dir, cat, sub_cat, 'MIX', day_dir.name, 'left')))

                    annotations_to_be_removed = []

                    # check every annotations to find bboxes with negative values
                    # remove the annotation and the corresponding image if there is a negative value
                    for i, annotation in enumerate(annotations):
                        tree = etree.parse(str(annotation))
                        root = tree.getroot()
                        for obj in root.iter('object'):
                            bndbox = obj.find('bndbox')

                            x_min = int(bndbox.find('xmin').text)
                            y_min = int(bndbox.find('ymin').text)
                            x_max = int(bndbox.find('xmax').text)
                            y_max = int(bndbox.find('ymax').text)

                            # check if the bbox is valid
                            if ((x_min < 0) or (y_min < 0) or (x_max < 0) or (y_max < 0)):
                                # remove the annotation from the list
                                annotations_to_be_removed.append(annotation)

                    # remove the annotations from the list
                    for annotation in annotations_to_be_removed:
                        print(f"Removing {annotation} from annotations")
                        annotations.remove(annotation)
                        # remove the corresponding image from the list
                        image_name = annotation.name[:-4] + '.jpg'
                        image_path = Path.joinpath(images[0].parent, image_name)
                        if image_path in images:
                            print(f"Removing {image_path} from images")
                            images.remove(image_path)


                    # if both have equal number of files -> normal
                    # else: compare the two lists and find the difference
                    if len(images) != len(annotations):
                        image_set = set([x.name[:-4] for x in images])
                        annotation_set = set([x.name[:-4] for x in annotations])
                        diff = image_set ^ annotation_set

                        # for every difference entry, check whether it is missin in image or annotation set
                        for entry in diff:
                            if entry in image_set:      # there is an image but no annotation -> remove the image
                                # remove the image from the list
                                print(f"Removing {Path.joinpath(images[0].parent, entry + '.jpg')} from images")
                                images.remove(Path.joinpath(images[0].parent, entry + '.jpg'))

                            else:                       # there is an annotation but no image -> remove the annotation
                                # remove the annotation from the list
                                print(f"Removing {Path.joinpath(annotations[0].parent, entry + '.xml')} from annotations")
                                annotations.remove(Path.joinpath(annotations[0].parent, entry + '.xml'))
                                
            
                    # for every images, check whether it's in the manual subset
                    # if yes, store the index of the image
                    for i, img_path in enumerate(images):
                        if self.manual_subset_dir.joinpath('/'.join(img_path.parts[3:])).exists():
                            self.test_indices.append(i + len(self.imgs_path))

                            # replace the annotation path to the manual annotation path
                            annotation_path = self.manual_annotations_dir.joinpath('/'.join(img_path.parts[3:])).with_suffix('.xml')
                            if annotation_path.exists():
                                annotations[i] = annotation_path
                        else:   
                            self.train_indices.append(i + len(self.imgs_path))



                    assert len(images) == len(annotations), f"Number of images and annotations do not match in {cat}/{sub_cat}/{day_dir.name}. " + \
                        f"Number of images: {len(images)}, Number of annotations: {len(annotations)}. Difference in set: {set([x.name[:-4] for x in images]) ^ set([x.name[:-4] for x in annotations])}"

                    self.imgs_path += images
                    self.annotations_path += annotations

        self.labels = ['background']        # first label is reserved for background
        for cats in self.categories:
            self.labels.append(cats)

    def __getitem__(self, idx):
        img_path = self.imgs_path[idx]
        annotation_path = self.annotations_path[idx]

        # load image
        image = read_image(str(img_path))
        # load annotation
        tree = etree.parse(str(annotation_path))
        root = tree.getroot()
        # get the bbox of the object
        bbox = []
        label = ''
        for obj in root.iter('object'):
            # get the label of the object
            label = obj.find('category').text

            bndbox = obj.find('bndbox')
            bbox.append([int(bndbox.find('xmin').text),
                         int(bndbox.find('ymin').text),
                         int(bndbox.find('xmax').text),
                         int(bndbox.find('ymax').text)])
            
        # pre-processing
        # reduce the size of the image to 224x224
        # the dimension of the image is represented as (channel, height, width)
        # image_resized = v2.Resize((224, 224))(image)
        # # convert the bbox to the resized image
        # bbox_resized = []
        # for box in bbox:
        #     xmin = int(box[0] * 224.0 / image.shape[2])
        #     ymin = int(box[1] * 224.0 / image.shape[1])
        #     xmax = int(box[2] * 224.0 / image.shape[2])
        #     ymax = int(box[3] * 224.0 / image.shape[1])
        #     bbox_resized.append([xmin, ymin, xmax, ymax])

        # print('Image path:', img_path)
        # print('Annotation path:', annotation_path)
        # print('BBox:', bbox)
        # print('BBox_resized:', bbox_resized)

        target = {}
        target['boxes'] = torch.tensor(bbox, dtype=torch.float32)
        target['labels'] = torch.tensor([self.labels.index(label)], dtype=torch.int64)
        target['area'] = torch.tensor([(box[2] - box[0]) * (box[3] - box[1]) for box in bbox], dtype=torch.float32)
        target['iscrowd'] = torch.tensor([0], dtype=torch.int64)
        target['image_id'] = torch.tensor([idx])

        if self.transforms:
            image_resized, target = self.transforms(image, target)
        else:
            image_resized = image
        
        return image_resized, target
    
    def get_details_from_id(self, idx):
        img_path = self.imgs_path[idx]
        annotation_path = self.annotations_path[idx]

        # load image
        image = read_image(str(img_path))
        # load annotation
        tree = etree.parse(str(annotation_path))
        root = tree.getroot()
        # get the bbox of the object
        bbox = []
        label = ''
        for obj in root.iter('object'):
            # get the label of the object
            label = obj.find('category').text

            bndbox = obj.find('bndbox')
            bbox.append([int(bndbox.find('xmin').text),
                         int(bndbox.find('ymin').text),
                         int(bndbox.find('xmax').text),
                         int(bndbox.find('ymax').text)])
        
        return img_path, annotation_path, image, bbox, label
    
    def __len__(self):
        return len(self.imgs_path)

In [4]:
# Since v0.15.0 torchvision provides new Transforms API to easily write data augmentation pipelines for Object Detection and Segmentation tasks.
# Let’s write some helper functions for data augmentation / transformation:

from torchvision.transforms import v2 as T

def get_transform(train):
    transforms = []
    # if train:
    #     transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ToDtype(torch.float, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)

In [5]:
# check dataset
dataset = ICubDataset(images_dir, annotations_dir, manual_subset_dir, manual_annotations_dir, get_transform(train=False))
print('length of dataset = ', len(dataset), '\n')

# getting the image and target for a test index.
img, target = dataset[96339]
print(img.shape, '\n',target)

Removing dataset/icub/Annotations_refined/bodylotion/bodylotion1/MIX/day1/left/00002946.xml from annotations
Removing dataset/icub/Images/bodylotion/bodylotion1/MIX/day1/left/00002946.jpg from images
Removing dataset/icub/Annotations_refined/bodylotion/bodylotion1/MIX/day1/left/00003198.xml from annotations
Removing dataset/icub/Images/bodylotion/bodylotion1/MIX/day1/left/00003198.jpg from images
Removing dataset/icub/Annotations_refined/bodylotion/bodylotion1/MIX/day1/left/00003258.xml from annotations
Removing dataset/icub/Images/bodylotion/bodylotion1/MIX/day1/left/00003258.jpg from images
Removing dataset/icub/Annotations_refined/bodylotion/bodylotion1/MIX/day1/left/00003280.xml from annotations
Removing dataset/icub/Images/bodylotion/bodylotion1/MIX/day1/left/00003280.jpg from images
Removing dataset/icub/Annotations_refined/bodylotion/bodylotion1/MIX/day2/left/00007918.xml from annotations
Removing dataset/icub/Images/bodylotion/bodylotion1/MIX/day2/left/00007918.jpg from images


In [6]:
dataset.get_details_from_id(96339)

(PosixPath('dataset/icub/Images/ringbinder/ringbinder7/MIX/day3/left/00003509.jpg'),
 PosixPath('dataset/icub/Annotations_refined/ringbinder/ringbinder7/MIX/day3/left/00003509.xml'),
 tensor([[[ 34,  36,  39,  ...,  66,  66,  66],
          [ 35,  37,  41,  ...,  66,  66,  66],
          [ 35,  37,  40,  ...,  66,  66,  66],
          ...,
          [ 33,  33,  33,  ...,  81,  79,  78],
          [ 33,  33,  33,  ...,  82,  80,  79],
          [ 33,  33,  33,  ...,  82,  81,  80]],
 
         [[ 21,  23,  26,  ...,  65,  65,  65],
          [ 22,  24,  28,  ...,  65,  65,  65],
          [ 24,  26,  29,  ...,  65,  65,  65],
          ...,
          [ 35,  35,  35,  ...,  87,  85,  84],
          [ 35,  35,  35,  ...,  88,  86,  85],
          [ 35,  35,  35,  ...,  88,  87,  86]],
 
         [[ 15,  17,  20,  ...,  71,  71,  71],
          [ 16,  18,  22,  ...,  71,  71,  71],
          [ 20,  22,  23,  ...,  71,  71,  71],
          ...,
          [ 34,  34,  34,  ...,  99,  97,  96]

In [7]:
dataset.labels, len(dataset.labels)

(['background',
  'bodylotion',
  'book',
  'cellphone',
  'flower',
  'glass',
  'hairbrush',
  'hairclip',
  'mouse',
  'mug',
  'ovenglove',
  'pencilcase',
  'perfume',
  'remote',
  'ringbinder',
  'soapdispenser',
  'sodabottle',
  'sprayer',
  'squeezer',
  'sunglasses',
  'wallet'],
 21)

Load the pre-trained faster rcnn network

In [None]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, fasterrcnn_mobilenet_v3_large_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# using mobile net for faster training, while have decent accuracy
model = fasterrcnn_mobilenet_v3_large_fpn()

n_classes = len(dataset.labels)

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)

In [9]:
from torchinfo import summary

summary(model, input_size=(8, 3, 224, 224))

Layer (type:depth-idx)                                  Output Shape              Param #
FasterRCNN                                              [100, 4]                  --
├─GeneralizedRCNNTransform: 1-1                         [8, 3, 800, 800]          --
├─BackboneWithFPN: 1-2                                  [8, 256, 13, 13]          --
│    └─IntermediateLayerGetter: 2-1                     [8, 960, 25, 25]          --
│    │    └─Conv2dNormActivation: 3-1                   [8, 16, 400, 400]         (432)
│    │    └─InvertedResidual: 3-2                       [8, 16, 400, 400]         (400)
│    │    └─InvertedResidual: 3-3                       [8, 24, 200, 200]         (3,136)
│    │    └─InvertedResidual: 3-4                       [8, 24, 200, 200]         (4,104)
│    │    └─InvertedResidual: 3-5                       [8, 40, 100, 100]         (9,960)
│    │    └─InvertedResidual: 3-6                       [8, 40, 100, 100]         (20,432)
│    │    └─InvertedResidual: 3-7

Training

In [10]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from sklearn.model_selection import train_test_split

BATCH_SIZE = 8

def collate_fn(batch):
  return tuple(zip(*batch))

# create train and validation set
train_indices, validation_indices = train_test_split(dataset.train_indices, test_size=0.2, random_state=42)
dataset_train = torch.utils.data.Subset(dataset, train_indices)
dataset_valid = torch.utils.data.Subset(dataset, validation_indices)
dataset_test = torch.utils.data.Subset(dataset, dataset.test_indices)

# create data loaders
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)
data_loader_valid = torch.utils.data.DataLoader(
    dataset_valid,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)

In [12]:
model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): FrozenBatchNorm2d(16, eps=1e-05)
        (2): Hardswish()
      )
      (1): InvertedResidual(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
            (1): FrozenBatchNorm2d(16, eps=1e-05)
            (2): ReLU(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(16, eps=1e-05)
          )
        )
      )
      (2): InvertedResidual(
        (block):

In [13]:
N_EPOCHS = 3

# construct optimizer
params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.AdamW(params, lr=5e-3)
# following the tutorial first. Not sure if this is the best optimizer
optimizer = torch.optim.SGD(params, lr=0.005,
                        momentum=0.9,
                        weight_decay=0.0005)

# construct learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=N_EPOCHS, gamma=0.5)

In [14]:
from tqdm import tqdm
import matplotlib.pyplot as plt

'''
Function to train the model over one epoch.
'''
def train_one_epoch(model, optimizer, data_loader, device):
  train_loss_list = []

  tqdm_bar = tqdm(data_loader, total=len(data_loader))
  for idx, data in enumerate(tqdm_bar):
    optimizer.zero_grad()
    images, targets = data

    # print(targets)

    images = list(image.to(device) for image in images)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  # targets = {'boxes'=tensor, 'labels'=tensor}

    losses = model(images, targets)

    loss = sum(loss for loss in losses.values())
    loss_val = loss.item()
    train_loss_list.append(loss.detach().cpu().numpy())

    loss.backward()
    optimizer.step()

    tqdm_bar.set_description(desc=f"Training Loss: {loss_val:.3f}")

  return train_loss_list

'''
Function to validate the model
'''
def evaluate(model, data_loader_test, device):
    val_loss_list = []

    tqdm_bar = tqdm(data_loader_test, total=len(data_loader_test))

    for i, data in enumerate(tqdm_bar):
        images, targets = data

        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            losses = model(images, targets)

        loss = sum(loss for loss in losses.values())
        loss_val = loss.item()
        val_loss_list.append(loss_val)

        tqdm_bar.set_description(desc=f"Validation Loss: {loss:.4f}")
    return val_loss_list

'''
Function to plot training and valdiation losses and save them in `output_dir'
'''
def plot_loss(train_loss, valid_loss):
    figure_1, train_ax = plt.subplots()
    figure_2, valid_ax = plt.subplots()

    train_ax.plot(train_loss, color='blue')
    train_ax.set_xlabel('Iteration')
    train_ax.set_ylabel('Training Loss')

    valid_ax.plot(valid_loss, color='red')
    valid_ax.set_xlabel('Iteration')
    valid_ax.set_ylabel('Validation loss')

    # figure_1.savefig(f"{OUTPUT_DIR}/train_loss.png")
    # figure_2.savefig(f"{OUTPUT_DIR}/valid_loss.png")

In [15]:
from tqdm import tqdm

loss_dict = {'train_loss': [], 'valid_loss': []}
best_model = None

for epoch in range(N_EPOCHS):
    print("----------Epoch {}----------".format(epoch+1))

    # Train the model for one epoch
    train_loss_list = train_one_epoch(model, optimizer, data_loader_train, device)
    loss_dict['train_loss'].extend(train_loss_list)

    lr_scheduler.step()

    # Run evaluation
    valid_loss_list = evaluate(model, data_loader_valid, device)
    loss_dict['valid_loss'].extend(valid_loss_list)

    # store the best model
    if best_model is None or min(valid_loss_list) < min(loss_dict['valid_loss']):
        best_model = model.state_dict()

    # Svae the model ckpt after every epoch
    # ckpt_file_name = f"{OUTPUT_DIR}/epoch_{epoch+1}_model.pth"
    # torch.save({
    #     'epoch': epoch+1,
    #     'model_state_dict': model.state_dict(),
    #     'optimizer_state_dict': optimizer.state_dict(),
    #     'loss_dict': loss_dict
    # }, ckpt_file_name)


----------Epoch 1----------


Training Loss: 0.861: 100%|██████████| 13016/13016 [18:23<00:00, 11.80it/s]
Validation Loss: 0.6024: 100%|██████████| 3254/3254 [03:40<00:00, 14.73it/s]


----------Epoch 2----------


Training Loss: 0.548: 100%|██████████| 13016/13016 [18:11<00:00, 11.93it/s]
Validation Loss: 0.6554: 100%|██████████| 3254/3254 [03:46<00:00, 14.37it/s]


----------Epoch 3----------


Training Loss: 0.445: 100%|██████████| 13016/13016 [18:12<00:00, 11.92it/s]
Validation Loss: 0.4722: 100%|██████████| 3254/3254 [03:51<00:00, 14.03it/s]
