In [None]:
import os
import pathlib
from typing import Any

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

In [None]:
# Root folder holding the dataset_XX subdirectories, relative to this notebook.
datasets_dir = pathlib.Path("..", "datasets")

In [None]:
def read_annotations(path: pathlib.Path) -> tuple[list[str], list[list[int]]]:
    """Reads class labels and bounding boxes from an annotations CSV file.

    The CSV must have a header row that contains at least the columns
    ``class``, ``xmin``, ``ymin``, ``xmax`` and ``ymax``. Any other columns are
    ignored. Rows are returned in file order, so ``labels[i]`` belongs to
    ``boxes[i]``.

    Args:
        path (pathlib.Path): Path to the CSV file.

    Returns:
        tuple[list[str], list[list[int]]]: ``(labels, boxes)``.
        ``labels`` holds the ``class`` value of every row. ``boxes`` holds one
        ``[xmin, ymin, xmax, ymax]`` list per row. Coordinates keep the dtype
        pandas infers, so they are floats if the CSV stores decimal values.
        A header-only CSV yields two empty lists.

    Raises:
        KeyError: If any of the required columns is missing from the CSV.
    """
    box_columns = ["xmin", "ymin", "xmax", "ymax"]
    df = pd.read_csv(path, header=0)

    # Fail with a message naming the file and every missing column, instead of
    # pandas' generic KeyError for the first missing one.
    missing = [col for col in ["class", *box_columns] if col not in df.columns]
    if missing:
        raise KeyError(f"{path}: missing required annotation column(s): {missing}")

    labels: list[str] = df["class"].tolist()
    boxes: list[list[int]] = df[box_columns].to_numpy().tolist()
    return labels, boxes

In [None]:
# Convert every dataset's annotations.csv into annotations.pt (a dict with
# "labels" and "boxes"), then reload it and check the round trip is lossless.
dataset_nums = range(1, 6)  # dataset_01 .. dataset_05

for dataset_num in dataset_nums:
    dataset_dir = datasets_dir / f"dataset_{dataset_num:02}"
    annotations_csv = dataset_dir / "annotations.csv"
    annotations_pt = dataset_dir / "annotations.pt"

    labels, boxes = read_annotations(annotations_csv)

    targets = {
        "labels": labels,
        "boxes": boxes
    }

    torch.save(targets, annotations_pt)
    # The saved dict only holds str/number lists, which the restricted
    # weights_only unpickler accepts. Passing it explicitly gives the same
    # behavior across torch versions (the default changed in 2.6) and avoids
    # the FutureWarning that earlier 2.x releases emit when it is unspecified.
    data = torch.load(annotations_pt, weights_only=True)
    assert targets == data, f"Targets and loaded data are not equal for dataset {dataset_num:02}."

In [None]:
class OrbTrackingDataset(Dataset):
    """Map-style dataset that pairs images on disk with labels from a CSV file.

    Each CSV row is one sample. Column 0 holds the image file name, relative to
    ``img_dir``. Column 1 holds the value returned as that sample's label.

    NOTE(review): columns are selected by position, not by name. Confirm that
    the CSV passed here really has the label in column 1. ``read_annotations``
    in this notebook reads labels from the ``"class"`` column by name instead.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None) -> None:
        """Load the annotations CSV and store the image directory and transforms.

        Args:
            annotations_file: Path to the CSV file, read eagerly with pandas.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to each decoded image.
            target_transform: Optional callable applied to each label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of rows in the annotations CSV."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx) -> Any:
        """Return the ``(image, label)`` pair for row ``idx``.

        The image path is ``img_dir`` joined with the row's column-0 value
        (converted to ``str``), decoded with ``read_image``. The label is the
        raw column-1 value. ``transform`` and ``target_transform`` are applied
        only when they are truthy.
        """
        frame = self.img_labels
        file_name = str(frame.iloc[idx, 0])
        image = read_image(os.path.join(self.img_dir, file_name))
        label = frame.iloc[idx, 1]

        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label