# Imports

In [1]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch

from abc import ABC, abstractmethod
from dataclasses import dataclass
from os import PathLike
import os

# Constants

In [2]:
# Absolute path of the directory the kernel is currently running in;
# every dataset path below is built relative to it.
BASE_DIR = os.path.abspath(os.curdir)

BASE_DIR

'D:\\AgriNet-Research\\agrinet\\datasets\\Agriculture-Vision-2021'

In [3]:
dataset_name = "_dataset"
dataset_path = os.path.join(BASE_DIR, dataset_name)

# One sub-directory per split, directly under the dataset root.
dataset_train_path, dataset_val_path, dataset_test_path = (
    os.path.join(dataset_path, split) for split in ("train", "val", "test")
)

# Directories of the training split that hold the "rgbn" image files.
IMAGE_DIRS = [
    os.path.join(dataset_train_path, folder, "rgbn")
    for folder in ("images", "aug_images")
]

# Label roots of the training split; each one contains a folder for every label.
LABEL_DIRS = [
    os.path.join(dataset_train_path, folder)
    for folder in ("labels", "aug_labels")
]

# Image classes

In [4]:
class ImageSource(ABC):
    @abstractmethod
    def load(self) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _load_image_from_numpy(arr: np.ndarray) -> torch.Tensor:
        return torch.from_numpy(arr).long()

    @classmethod
    def _load_image_from_path(cls, path: str | PathLike) -> torch.Tensor:
        arr = np.array(PILImage.open(path))
        return cls._load_image_from_numpy(arr)

class Image(ImageSource):
    """Base class for image sources (as opposed to masks); adds no behaviour.

    NOTE: defining this class rebinds the module-level name ``Image``, which
    shadows the earlier ``from PIL import Image``.
    """

class Mask(ImageSource):
    """Base class for mask sources (as opposed to images); adds no behaviour."""

@dataclass
class ImageData:
    """Bundle of one image ID with its image source and mask source.

    The ``image`` and ``mask`` fields hold source objects, not tensors;
    pixel data is only read when their ``load()`` method is called.
    """

    # Identifier of the sample.
    image_id: str
    # Source object for the image data.
    image: Image
    # Source object for the matching mask data.
    mask: Mask

In [5]:
class RGBImagePlusNIR(Image):
    """Image assembled from two files: an RGB image and a separate NIR image."""

    def __init__(self, rgb_path: str, nir_path: str):
        """Remember both file locations; nothing is read until :meth:`load`."""
        self.rgb_path, self.nir_path = rgb_path, nir_path

    def load(self) -> torch.Tensor:
        """Load both files and concatenate them along the channel axis.

        Returns:
            Float tensor with the RGB channels first and the NIR channel
            last; expected shape (4, H, W) for a 3-channel RGB file and a
            single-channel NIR file.
        """
        # Channels-last -> channels-first: expected (H, W, 3) -> (3, H, W).
        rgb_hwc = self._load_image_from_path(self.rgb_path)
        rgb_chw = rgb_hwc.float().permute(2, 0, 1)

        # Add a leading channel axis: expected (H, W) -> (1, H, W).
        nir_hw = self._load_image_from_path(self.nir_path)
        nir_chw = nir_hw.float().unsqueeze(0)

        return torch.cat((rgb_chw, nir_chw), dim=0)


class RGBNImage(Image):
    """Image stored in a single file that already contains all channels."""

    def __init__(self, rgbn_path: str):
        """Remember the file location; nothing is read until :meth:`load`."""
        self.rgbn_path = rgbn_path

    def load(self) -> torch.Tensor:
        """Load the file as a float tensor with the channel axis moved first."""
        channels_last = self._load_image_from_path(self.rgbn_path)
        as_float = channels_last.float()
        # Reorder axes (dim 2 -> dim 0), i.e. (H, W, C) -> (C, H, W).
        return as_float.permute(2, 0, 1)

In [6]:
class OneHotMask(Mask):
    """Mask built from several files, one file per output channel."""

    def __init__(self, *mask_paths: str):
        """Remember the mask file locations in the order given."""
        self.mask_paths = mask_paths

    def load(self) -> torch.Tensor:
        """Load every mask file and join them along a new leading axis.

        Returns:
            Tensor whose first dimension has one entry per path, in the
            same order as ``mask_paths``; expected shape (C, H, W).
        """
        channels = []
        for mask_path in self.mask_paths:
            plane = self._load_image_from_path(mask_path)
            # Give each plane a channel dimension so they can be concatenated.
            channels.append(plane.unsqueeze(0))
        return torch.cat(channels, dim=0)

class IndexMask(Mask):
    """Mask stored in a single file."""

    def __init__(self, mask_path: str):
        """Remember the file location; nothing is read until :meth:`load`."""
        self.mask_path = mask_path

    def load(self) -> torch.Tensor:
        """Load the mask file unchanged; expected shape (H, W)."""
        mask = self._load_image_from_path(self.mask_path)
        return mask

# ImageIdsParser

In [7]:
class ImageIdsParser:
    @classmethod
    def get_ids_for_dirs(cls, dirs: list[str]) -> list[str]:
        ids = []
        [ids.extend(cls.get_ids_for_dir(d)) for d in dirs]
        return ids
    
    @classmethod
    def get_ids_for_dir(cls, path: str) -> list[str]:
        ids_with_nones = [cls._get_id_from_image_path(p) for p in os.listdir(path)]
        ids = [i for i in ids_with_nones if i]
        return ids
    
    @classmethod
    def _get_id_from_image_path(cls, path: str) -> str | None:
        try:
            return path.split(".")[0]  # id-with-coords.png
        except Exception:
            return None

In [8]:
# IDs parsed from the entries of every directory in IMAGE_DIRS; the
# directories are listed from disk when this cell runs.
IMAGE_IDS = ImageIdsParser.get_ids_for_dirs(IMAGE_DIRS)

# Dataset Abstraction

In [9]:
class SegmentationDataset(ABC, Dataset):
    """Abstract dataset that maps integer indices to samples via image IDs.

    The dataset keeps an ordered collection of image IDs; subclasses decide
    how one ID is turned into an :class:`ImageData` by implementing
    :meth:`get_data`.
    """

    def __init__(self, image_ids: list[str]) -> None:
        """Store the image IDs that define the order and size of the dataset."""
        self._image_ids = image_ids

    def __len__(self) -> int:
        """Return the number of image IDs in the dataset."""
        return len(self._image_ids)

    def __getitem__(self, idx: int) -> ImageData:
        """Return the sample for the ID stored at position ``idx``."""
        return self.get_data(self._image_ids[idx])

    def first(self) -> ImageData | None:
        """Return the first sample, or ``None`` when the dataset is empty.

        Only the "no IDs" case maps to ``None``. Errors raised by
        :meth:`get_data` propagate to the caller instead of being silently
        turned into ``None``, so a broken loader is not mistaken for an
        empty dataset.
        """
        if len(self._image_ids) == 0:
            return None
        return self.get_data(self._image_ids[0])

    @abstractmethod
    def get_data(self, image_id: str) -> ImageData:
        """Build the sample that belongs to ``image_id``."""
        raise NotImplementedError

# Datasets

In [10]:
class TrainDataset(SegmentationDataset):
    """Dataset for the training data; image and mask lookup are still stubs."""

    def get_data(self, image_id: str) -> ImageData:
        """Assemble the sample for ``image_id`` from its image and mask."""
        image = self._get_image(image_id)
        mask = self._get_mask(image_id)
        return ImageData(image_id=image_id, image=image, mask=mask)

    def _get_image(self, image_id: str) -> Image:
        """Placeholder: prints a marker and returns None instead of an Image."""
        print("TODO")
        return None

    def _get_mask(self, image_id: str) -> Mask:
        """Placeholder: prints a marker and returns None instead of a Mask."""
        print("TODO")
        return None

def _test() -> None:
    """Smoke check: build the training dataset and print its size and first ID."""
    dataset = TrainDataset(image_ids=IMAGE_IDS)
    # Fetched before the size is printed, so any output from it comes first.
    sample = dataset.first()

    print(f"Loaded {len(dataset)} instances")

    if not sample:
        return
    print(sample.image_id)

_test()

TODO
TODO
Loaded 156397 instances
11IE4DKTR_11556-9586-12068-10098


### TODO: Should I add datasets for the test and val splits?