In [None]:
import os
import pathlib
from enum import Enum
from typing import Any

import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.transforms import Compose, ToTensor

In [None]:
# Root folder holding the per-dataset `dataset_NN` directories (relative to the notebook's CWD).
datasets_dir: pathlib.Path = pathlib.Path("../datasets")

In [None]:
def read_annotations(path: pathlib.Path) -> tuple[list[str], list[list[int]]]:
    """Read the class labels and bounding boxes from an annotations CSV.

    The CSV must start with a header row naming at least the columns ``class``,
    ``xmin``, ``ymin``, ``xmax`` and ``ymax``. Every row contributes exactly one
    label and one box.

    Args:
        path (pathlib.Path): Path to the CSV file.

    Returns:
        tuple[list[str], list[list[int]]]: ``(labels, boxes)``, where ``labels[i]``
        is the ``class`` value of row ``i`` and ``boxes[i]`` is that row's
        ``[xmin, ymin, xmax, ymax]``. Coordinate element types follow the CSV
        column dtypes (Python ints for integer columns).

    Raises:
        KeyError: If any of the required columns is missing from the CSV.
    """
    box_columns = ["xmin", "ymin", "xmax", "ymax"]

    df = pd.read_csv(path)  # header row is inferred from the first line
    labels: list[str] = df["class"].tolist()
    boxes: list[list[int]] = df[box_columns].values.tolist()
    return labels, boxes

In [None]:
def map_class_to_int(labels: list[str], mapping: dict[int, str]) -> list[int]:
    """Translate class names into integer class ids.

    ``mapping`` goes from integer id to class name (for example
    ``{0: "background", 1: "orb"}``), so the lookup runs in the reverse
    direction. If several ids share the same name, the first id in
    ``mapping``'s iteration order is used.

    Args:
        labels (list[str]): Class names to translate.
        mapping (dict[int, str]): Dictionary mapping integer ids to class names.

    Returns:
        list[int]: The id of each entry of ``labels``, in the same order.

    Raises:
        ValueError: If a label does not appear among ``mapping``'s values.
    """
    # Build the inverse table once instead of scanning the values list for every
    # label; setdefault keeps the *first* id of a duplicated name, matching the
    # semantics of list.index().
    name_to_id: dict[str, int] = {}
    for class_id, class_name in mapping.items():
        name_to_id.setdefault(class_name, class_id)

    ids: list[int] = []
    for label in labels:
        if label not in name_to_id:
            raise ValueError(
                f"Unknown class label {label!r}; known labels: {list(name_to_id)}"
            )
        ids.append(name_to_id[label])
    return ids

In [None]:
# Convert each dataset's annotations.csv into annotations.pt and check that the
# saved file loads back to exactly the same data.
for dataset_num in range(1, 6):
    dataset_dir = datasets_dir / f"dataset_{dataset_num:02}"
    images_dir = dataset_dir / "images"
    annotations_csv = dataset_dir / "annotations.csv"
    annotations_pt = dataset_dir / "annotations.pt"

    labels, boxes = read_annotations(annotations_csv)

    targets = {
        "labels": labels,
        "boxes": boxes
    }

    # NOTE(review): `files` (the .jpg images on disk) is not used in this loop; it is
    # kept only because later cells may rely on the global — remove if unused.
    files = [file for file in images_dir.iterdir() if file.is_file()
             and file.suffix == ".jpg"]

    torch.save(targets, annotations_pt)
    # The payload is a plain dict of lists of str/numbers, so the restricted
    # weights-only unpickler is sufficient (and avoids arbitrary-code unpickling).
    data = torch.load(annotations_pt, weights_only=True)
    assert targets == data, f"Targets and loaded data are not equal for dataset {dataset_num:02}."

In [None]:
class Column(Enum):
    """Zero-based positional index of each column in an ``annotations.csv`` file."""

    # Tuple-unpacking range(8) assigns 0..7 in declaration order.
    FILENAME, WIDTH, HEIGHT, CLASS, XMIN, YMIN, XMAX, YMAX = range(8)

In [None]:
class OrbTrackingDataset(Dataset):
    """Map-style dataset over one ``annotations.csv``: one item per CSV row.

    Columns are addressed by position via ``Column`` (filename, width, height,
    class, xmin, ymin, xmax, ymax). Each item is a dict with:

    - ``"x"``: the image loaded with ``torchvision.io.read_image`` from
      ``img_dir / filename``, passed through ``transform`` when one is given;
    - ``"y"``: ``{"labels": int64 NumPy array of shape (1,),
      "boxes": float32 NumPy array of shape (1, 4)}`` holding the row's class id
      and its ``[xmin, ymin, xmax, ymax]`` box;
    - ``"x_name"``: the image path as a string;
    - ``"y_name"``: the class name read from the CSV.

    NOTE(review): an image with several annotated objects yields several items,
    each with a single box. ``transform`` is applied to the image only, so a
    resizing/geometric transform would leave ``boxes`` out of sync — confirm
    this is intended.
    """

    def __init__(
        self,
        annotations_file: pathlib.Path,
        img_dir: pathlib.Path,
        mapping: dict[int, str],
        transform=None
    ) -> None:
        """
        Args:
            annotations_file: CSV with a header row, laid out as in ``Column``.
            img_dir: Directory containing the image files named in the CSV.
            mapping: Integer class id -> class name.
            transform: Optional callable applied to each loaded image tensor.
        """
        self.annotations = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.mapping = mapping

    def __len__(self) -> int:
        """Number of annotation rows (one item per row)."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Load the image and detection target for annotation row ``idx``."""
        # Fetch the row once, then index it positionally.
        row = self.annotations.iloc[idx]

        filename = str(row.iloc[Column.FILENAME.value])
        img_path = os.path.join(self.img_dir, filename)
        image = read_image(img_path)

        classname = str(row.iloc[Column.CLASS.value])
        # torchvision detection models document labels as Int64Tensor and reject
        # other integer dtypes during training, so build them as int64.
        labels = torch.tensor(
            map_class_to_int([classname], self.mapping), dtype=torch.int64
        )

        # Columns XMIN..YMAX (inclusive) form the [xmin, ymin, xmax, ymax] box.
        box = row.iloc[Column.XMIN.value:Column.YMAX.value + 1].tolist()
        boxes = torch.tensor([box], dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        target = {
            "labels": labels,
            "boxes": boxes
        }
        # Targets are handed out as NumPy arrays rather than tensors.
        target = {key: value.numpy() for key, value in target.items()}

        return {
            "x": image,
            "y": target,
            "x_name": img_path,
            "y_name": classname
        }

In [None]:
# Integer class id -> class name; id 0 is the background class.
mapping = dict(enumerate(("background", "orb")))

In [None]:
# Dataset over the first dataset folder (dataset_01), with no image transform.
dataset = OrbTrackingDataset(
    mapping=mapping,
    img_dir=datasets_dir / "dataset_01" / "images",
    annotations_file=datasets_dir / "dataset_01" / "annotations.csv",
)

In [None]:
# Sanity check: shape of the first sample's image tensor (shown as the cell output).
dataset[0]["x"].shape

In [None]:
def collate_double(batch):
    """Collate ``OrbTrackingDataset`` samples into per-field lists for a ``DataLoader``.

    Unlike the default collate function, nothing is stacked: each field is gathered
    into a plain Python list (presumably so images of different sizes can share a
    batch).

    Credit: https://johschmidt42.medium.com/train-your-own-object-detector-with-faster-rcnn-pytorch-8d3c759cfc70

    Args:
        batch: Sequence of sample dicts with keys ``"x"``, ``"y"``, ``"x_name"``
            and ``"y_name"``.

    Returns:
        tuple[list, list, list, list]: ``(x, y, x_name, y_name)``, each holding one
        entry per sample in batch order.
    """
    fields = ("x", "y", "x_name", "y_name")
    return tuple([sample[field] for sample in batch] for field in fields)


# One sample per batch, gathered with the list-based collate function. The seeded
# generator makes the shuffle order reproducible across kernel restarts.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_double,
    generator=torch.Generator().manual_seed(42),  # fixed shuffle seed
)

# Faster R-CNN preprocessing: min_size == max_size == 1280 caps the resized image
# at 1280 px, then channels are normalised with the standard ImageNet mean/std.
# NOTE(review): `dataset` has no transform and read_image yields uint8 tensors;
# GeneralizedRCNNTransform expects float images in [0, 1], so a conversion is
# likely needed before applying this transform — confirm.
transform = GeneralizedRCNNTransform(min_size=1280,
                                     max_size=1280,
                                     image_mean=[0.485, 0.456, 0.406],
                                     image_std=[0.229, 0.224, 0.225])