In [153]:
import os

import numpy as np
import torch
import torchvision
from lxml import objectify
from PIL import Image
from torchvision import tv_tensors
from torchvision.io import read_image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.v2 import functional as F

In [151]:
class HRSC2016(torch.utils.data.Dataset):
    """Single-class ("ship") detection dataset read from an HRSC2016-style folder.

    Directory layout read by this class, relative to ``root``:

    * ``ImageSets/<imageset>.txt`` - whitespace-separated sample ids
    * ``AllImages/<id>.<ext>``     - one image per sample
    * ``Annotations/<id>.<ext>``   - one XML file per sample containing
      ``object/bndbox/{xmin,ymin,xmax,ymax}`` elements

    Args:
        root: Dataset root directory.
        transforms: Optional callable ``(img, target) -> (img, target)`` applied
            to every sample.
        imageset: Name (without extension) of the split file in ``ImageSets``.

    Each item is ``(img, target)`` where ``img`` is a ``tv_tensors.Image`` and
    ``target`` is a dict with keys ``boxes``, ``labels``, ``image_id``, ``area``
    and ``iscrowd``.
    """

    def __init__(self, root, transforms=None, imageset='train'):
        self.root = root
        self.transforms = transforms

        with open(os.path.join(root, "ImageSets", f'{imageset}.txt'), 'r') as f:
            required_imgs = set(f.read().split())

        img_by_stem = self._index_by_stem(os.path.join(root, "AllImages"), required_imgs)
        annot_by_stem = self._index_by_stem(os.path.join(root, "Annotations"), required_imgs)

        # Keep only ids that have BOTH an image and an annotation. Filtering the
        # two directory listings independently would shift every later index as
        # soon as one file is missing, pairing images with the wrong XML.
        paired = [stem for stem in img_by_stem if stem in annot_by_stem]
        self.imgs = [img_by_stem[stem] for stem in paired]
        self.annots = [annot_by_stem[stem] for stem in paired]

    @staticmethod
    def _index_by_stem(directory, wanted):
        """Map file stem -> file name for files in ``directory`` whose stem is in ``wanted``.

        Entries are inserted in sorted file-name order; if several files share
        a stem, the first one in that order wins.
        """
        index = {}
        for name in sorted(os.listdir(directory)):
            # splitext keeps dots inside the id intact ("a.b.bmp" -> "a.b").
            stem = os.path.splitext(name)[0]
            if stem in wanted:
                index.setdefault(stem, name)
        return index

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, target)``."""
        img_path = os.path.join(self.root, "AllImages", self.imgs[idx])
        annot_path = os.path.join(self.root, "Annotations", self.annots[idx])

        # Close the file handle deterministically and normalise the mode so the
        # tensor always has 3 channels (palette / grayscale / RGBA inputs included).
        with Image.open(img_path) as pil_img:
            img = F.pil_to_tensor(pil_img.convert("RGB"))

        with open(annot_path, 'rb') as f:
            annot_root = objectify.fromstring(f.read())

        # objectify raises AttributeError when there is no <object> child;
        # treat that as an image without any boxes instead of crashing.
        objects = getattr(annot_root, "object", [])
        coords = [
            [
                float(obj.bndbox.xmin),
                float(obj.bndbox.ymin),
                float(obj.bndbox.xmax),
                float(obj.bndbox.ymax),
            ]
            for obj in objects
        ]

        # torchvision detection models document boxes as FloatTensor[N, 4];
        # reshape(-1, 4) also yields a well-formed (0, 4) tensor when empty.
        boxes = torch.tensor(coords, dtype=torch.float32).reshape(-1, 4)
        num_obj = boxes.shape[0]
        areas = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            'boxes': tv_tensors.BoundingBoxes(boxes, format='XYXY', canvas_size=F.get_size(img)),
            # Every annotated object gets label 1; no other class is produced here.
            'labels': torch.ones((num_obj,), dtype=torch.int64),
            'image_id': idx,
            'area': areas,
            'iscrowd': torch.zeros((num_obj,), dtype=torch.int64),
        }
        img = tv_tensors.Image(img)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of (image, annotation) pairs in the selected image set."""
        return len(self.imgs)

In [152]:
# Build the dataset for the ids listed in ImageSets/trainval.txt under the
# Kaggle input directory.
dataset = HRSC2016(
    root='/kaggle/input/hrsc2016-ms-dataset',
    imageset='trainval',
)

# Number of samples in the selected image set (shown as the cell output).
len(dataset)

1070

In [None]:
# Faster R-CNN (ResNet-50 + FPN backbone) initialised with torchvision's
# default pretrained weights. This needs the top-level `import torchvision`
# from the imports cell; the `from torchvision... import ...` forms alone do
# not bind the name `torchvision`.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

num_classes = 2  # ship + background

# Replace the pretrained box-predictor head with a fresh one sized for our
# classes, reusing the input feature width of the existing classifier layer.
model.roi_heads.box_predictor = FastRCNNPredictor(
    model.roi_heads.box_predictor.cls_score.in_features,
    num_classes,
)