## Introduction

This notebook is a very minimal training notebook that will show you how to fine-tune a `resnet50` model for Image segmentation task (given an image, we predict it's mask).

The point of this notebook is not to serve as a complete end-to-end training example, but to serve a reference on how to write training scripts for Image-based Lance datasets.

In [1]:
import cv2
import time
import lance
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision import datasets, tv_tensors
import torchvision.models.segmentation as models

from pycocotools import mask as maskUtils

import matplotlib.pyplot as plt

import warnings
warnings.simplefilter('ignore')

## Utility function

Below are some utility functions to make mask using the mask coordinates and to resize the mask and image together

In [2]:
def convert_dict_to_list(input_dict):
    keys = list(input_dict.keys())
    values = list(input_dict.values())
    
    result = [dict(zip(keys, sublist)) for sublist in zip(*values)]
    return result

def get_mask(img, segmentation, category):
    h, w = img.shape[:2]

    if type(segmentation) == list:
        rles = maskUtils.frPyObjects(segmentation, h, w)
        rle = maskUtils.merge(rles)
    else:
        raise ValueError(f"Unknown annotation type. Expected list, recieved '{type(segmentation)}'")

    m = maskUtils.decode(rle)
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[:, :] += (mask == 0) * (m * category)

    return torch.from_numpy(mask)

def resize_image_mask(image, mask, size):
    # Resize the image using bilinear interpolation
    resized_image = torch.nn.functional.interpolate(image.permute(2, 0, 1).unsqueeze(0), size=size, mode='bilinear', align_corners=False)[0]
    
    # Resize the mask using nearest-neighbor interpolation
    resized_mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=size, mode='nearest')[0, 0]
    
    return resized_image, resized_mask

## COCOLanceDataset

This is the Custom Dataset that allows to loads COCO2017 Lance dataset.

We pre-load the annotations to match what PyTorch's COCO dataloader (by loading in the annotations JSON file). This makes the notebook comparable to equivalent-PyTorch's COCO dataloader. However, in case you don't want this functionality, you can remove it completely and replace it with the logic to fetch the annotations from the Lance dataset (similar how we are fetching the images).

In this dataset we load the images and the segmentation coordinates from the lance dataset and then we generate a mask using the `pycocoutils` package and resize both the image and mask to be of same heigh and width (so they can be stacked by PyTorch's default data collator, if you don't want it, use a custom data collator).

In [4]:
class COCOLanceDataset(Dataset):
    def __init__(self, dataset_path, resize=(600, 600), transforms=None):
        self.ds = lance.dataset(dataset_path)
        self.resize = resize
        self.transforms = transforms
        self.annotations = self.preload_annotations()

    def preload_annotations(self, exclude=['image', 'width', 'height', 'image_path']):
        """
        This will load all the annotations in memory which speeds up operations 
        drastically during the training but adds some overhead cost during init
        """
        print("Preloading all annotations in the memory.")
        idxs = [x for x in range(self.ds.count_rows())]
        cols = list(self.ds.take([0]).to_pydict().keys())
        cols = [col for col in cols if col not in exclude]
        return self.ds.take(idxs, columns=cols).to_pylist()

    def __len__(self):
        return self.ds.count_rows()

    def _load_image(self, idx):
        raw_img = self.ds.take([idx], columns=['image']).to_pydict()
        raw_img = np.frombuffer(b''.join(raw_img['image']), dtype=np.uint8)
        img = cv2.imdecode(raw_img, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return torch.from_numpy(img)

    def _load_target(self, idx):
        return self.annotations[idx]

    def __getitem__(self, idx):
        image = self._load_image(idx)
        raw_anns = self._load_target(idx)

        anns = dict(
            image_id=[raw_anns['image_id']]*len(raw_anns['bbox']),
            segmentation=raw_anns['segmentation'], 
            bbox=raw_anns['bbox'], 
            area=raw_anns['area'], 
            iscrowd=raw_anns['is_crowd'],
            category_id=raw_anns['category_id']
        )
        
        anns = convert_dict_to_list(anns)

        # Get only the first mask for the image for the sake of demonstration
        mask = get_mask(image, anns[0]['segmentation'], anns[0]['category_id'])

        # Resize both image and mask accordingly
        image, mask = resize_image_mask(image, mask, size=self.resize)

        # transforms
        if self.transforms:
            image = self.transforms(image.float())

        return image, mask

## Training the Model

This training part of this notebook is fairly straightforward, we begin by defining our dataset, a very normalisation transformation and then defining the dataloader.

Once that is done and the annotations are pre-loaded in the memory (which takes a little while because there are 100K+ of them!), we define a model, loss function and optimizer and then train the model.

In the training logs below, you will see the reported per-batch time (to train) and total time (to train).

In [5]:
# Define the dataset and the dataloader
ds = COCOLanceDataset(
    dataset_path='coco2017_train_lance/coco2017_train_new.lance/', 
    transforms=transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
)

dl = DataLoader(
    ds,
    shuffle=False,
    batch_size=8,
    pin_memory=True,
)

Preloading all annotations in the memory.


In [7]:
def get_model(num_classes=134):
    model = models.deeplabv3_resnet50(pretrained=True)
    model.classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
    return model

In [8]:
model = get_model().to("cuda:0")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train_model(model, dataloader, criterion, optimizer, num_epochs=5, batch_to_train=128):
    model.train()
    total_start = time.time()
    loss = torch.tensor([0])
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        pbar = tqdm(enumerate(dataloader), total=batch_to_train)
        total_batch_start = time.time()
        
        for idx, (images, masks) in pbar:
            if idx+1 == batch_to_train:
                break

            optimizer.zero_grad()
            
            outputs = model(images.to("cuda:0"))['out']
            loss = criterion(outputs, masks.long().to("cuda:0"))
            
            pbar.set_description(f"loss: {loss.item():.4f}")
            
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        
        per_batch_time = (time.time() - total_batch_start) / batch_to_train

        print(f'Epoch {epoch+1} ({batch_to_train} batches) | Loss: {running_loss/batch_to_train} | Avg Per-batch time: {per_batch_time:.4f} seconds')

    total_time = time.time() - total_start
    print(f"Total time taken for {num_epochs} epochs and {batch_to_train} batches: {total_time/60:.4f} mins")

train_model(model, dl, criterion, optimizer)

loss: 1.1363:  99%|█████████▉| 127/128 [01:59<00:00,  1.07it/s]


Epoch 1 (128 batches) | Loss: 1.3862293751444668 | Avg Per-batch time: 0.9303 seconds


loss: 1.0304:  99%|█████████▉| 127/128 [01:58<00:00,  1.07it/s]


Epoch 2 (128 batches) | Loss: 0.7514677769504488 | Avg Per-batch time: 0.9258 seconds


loss: 0.9979:  99%|█████████▉| 127/128 [01:58<00:00,  1.07it/s]


Epoch 3 (128 batches) | Loss: 0.735734045621939 | Avg Per-batch time: 0.9240 seconds


loss: 1.0773:  99%|█████████▉| 127/128 [01:58<00:00,  1.07it/s]


Epoch 4 (128 batches) | Loss: 0.726858394453302 | Avg Per-batch time: 0.9255 seconds


loss: 0.9969:  99%|█████████▉| 127/128 [01:58<00:00,  1.07it/s]

Epoch 5 (128 batches) | Loss: 0.7257018785458058 | Avg Per-batch time: 0.9239 seconds
Total time taken for 5 epochs and 128 batches: 9.8766 mins



