# 1 Imports

In [1]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch

from abc import ABC, abstractmethod
from dataclasses import dataclass
from os import PathLike
import os

# Constants

In [2]:
# Absolute path of the directory the notebook is running from (current working directory).
BASE_DIR = os.path.abspath(os.curdir)

BASE_DIR

'D:\\AgriNet-Research\\agrinet\\datasets\\Agriculture-Vision-2021'

In [3]:
# Root of the (mini) dataset and its train/val/test split folders.
dataset_name = "_dataset_mini"
dataset_path = os.path.join(BASE_DIR, dataset_name)
dataset_train_path, dataset_val_path, dataset_test_path = (
    os.path.join(dataset_path, split) for split in ("train", "val", "test")
)

# Training image folders (original + augmented), RGBN variant.
IMAGE_DIRS = [
    os.path.join(dataset_train_path, image_dir, "rgbn")
    for image_dir in ("images", "aug_images")
]

# Training label folders (original + augmented); each holds one sub-folder per label.
LABEL_DIRS = [
    os.path.join(dataset_train_path, label_dir)
    for label_dir in ("labels", "aug_labels")
]

print(dataset_path)

D:\AgriNet-Research\agrinet\datasets\Agriculture-Vision-2021\_dataset_mini


# Image classes

In [4]:
class ImageSource(ABC):
    @abstractmethod
    def load(self) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _load_image_from_numpy(arr: np.ndarray) -> torch.Tensor:
        return torch.from_numpy(arr).long()

    @classmethod
    def _load_image_from_path(cls, path: str | PathLike) -> torch.Tensor:
        arr = np.array(PILImage.open(path))
        return cls._load_image_from_numpy(arr)
    
    def __repr__(self) -> str:
        return self.__class__.__name__

class Image(ImageSource):
    """Marker base class for model input images (as opposed to masks)."""

class Mask(ImageSource):
    """Marker base class for segmentation targets (as opposed to input images)."""
    
@dataclass
class ImageData:
    """One dataset sample: its id plus image and mask handles whose tensors come from ``.load()``."""

    image_id: str  # file-name stem the sample was resolved from
    image: Image   # input image handle
    mask: Mask     # target mask handle

In [5]:
class RGBImagePlusNIR(Image):
    """Four-band image assembled from a separate RGB file and a separate NIR file."""

    def __init__(self, rgb_path: str, nir_path: str):
        self.rgb_path = rgb_path
        self.nir_path = nir_path

    def load(self) -> torch.Tensor:
        """Return the RGB channels followed by the NIR channel as one float tensor."""
        # HWC -> CHW; presumes the RGB file decodes to (H, W, 3).
        rgb_chw = self._load_image_from_path(self.rgb_path).float().permute(2, 0, 1)
        # Prepend a channel axis; presumes the NIR file decodes to (H, W).
        nir_chw = self._load_image_from_path(self.nir_path).float()[None]
        return torch.cat((rgb_chw, nir_chw))


class RGBNImage(Image):
    """Four-band image stored in a single RGBN file."""

    def __init__(self, rgbn_path: str):
        self.rgbn_path = rgbn_path

    def load(self) -> torch.Tensor:
        """Return the file as a float tensor with channels moved to the front (HWC -> CHW)."""
        pixels = self._load_image_from_path(self.rgbn_path)
        return pixels.float().permute(2, 0, 1)

In [6]:
class OneHotMask(Mask):
    """Mask stored as one binary image file per label, stacked along a new leading axis."""

    def __init__(self, *mask_paths: str):
        self._mask_paths = mask_paths        # one file path per label, in channel order
        self._labels_count = len(mask_paths)

    def load(self) -> torch.Tensor:
        """Load every mask file and stack them; (C, H, W) if each file decodes to (H, W)."""
        layers = [self._load_image_from_path(mask_path) for mask_path in self._mask_paths]
        return torch.stack(layers, dim=0)

    def __str__(self) -> str:
        return f"{type(self).__name__}(labels={self._labels_count})"

class IndexMask(Mask):
    """Mask stored as a single file whose pixel values are class indices."""

    def __init__(self, mask_path: str):
        self.mask_path = mask_path

    def load(self) -> torch.Tensor:
        """Return the decoded mask as-is; expected to be (H, W) for a single-band file."""
        path = self.mask_path
        return self._load_image_from_path(path)

# ImageIdsParser

In [7]:
class ImageIdsParser:
    @classmethod
    def get_ids_for_dirs(cls, dirs: list[str]) -> list[str]:
        ids = []
        [ids.extend(cls.get_ids_for_dir(d)) for d in dirs]
        return ids
    
    @classmethod
    def get_ids_for_dir(cls, path: str) -> list[str]:
        ids_with_nones = [cls._get_id_from_image_path(p) for p in cls._get_items_by_path(path)]
        ids = [i for i in ids_with_nones if i]
        return ids
    
    @classmethod
    def _get_id_from_image_path(cls, path: str) -> str | None:
        try:
            return path.split(".")[0]  # id-with-coords.png
        except IndexError:
            return None
    
    @classmethod
    def _get_items_by_path(cls, path: str) -> list[str]:
        try:
            return os.listdir(path)
        except FileNotFoundError:
            print(f"[WARNING] Path not found: {path}")
            return []

In [8]:
# Ids (file-name stems) of every training image found under IMAGE_DIRS.
IMAGE_IDS = ImageIdsParser.get_ids_for_dirs(dirs=IMAGE_DIRS)



# Dataset Abstraction

In [9]:
class SegmentationDataset(ABC, Dataset):
    """Map-style dataset over a list of image ids.

    Subclasses implement ``get_data`` to turn one id into an ``ImageData``.
    """

    def __init__(self, image_ids: list[str]) -> None:
        self._image_ids = image_ids

    def __len__(self) -> int:
        return len(self._image_ids)

    def __getitem__(self, idx: int) -> ImageData:
        return self.get_data(self._image_ids[idx])

    def first(self) -> ImageData | None:
        """Return the sample for the first id, or None if an IndexError occurs (e.g. no ids)."""
        try:
            return self.get_data(self._image_ids[0])
        except IndexError:
            return None

    @abstractmethod
    def get_data(self, image_id: str) -> ImageData:
        """Resolve *image_id* into its ``ImageData``."""
        raise NotImplementedError

# File Searcher

In [10]:
class FileSearcher:
    """Recursively find files by name under a fixed list of root folders."""

    def __init__(self, search_paths: list[str]):
        # Root folders to walk; os.walk yields nothing for a missing root.
        self.search_paths = search_paths

    def search(self, file_name: str) -> list[str]:
        """Return every path whose base name equals *file_name*, ignoring case.

        Roots are searched in the given order. Within each root, sub-folders
        and files are visited in sorted order, so the result order is the
        same on every platform. ``os.walk`` otherwise follows ``os.listdir``
        order, which is arbitrary. TrainDataset builds OneHotMask channels
        from this order, so it must be stable.
        """
        target = file_name.lower()
        found_paths: list[str] = []

        for root_folder in self.search_paths:
            for dirpath, dirnames, filenames in os.walk(root_folder):
                # Sorting in place makes os.walk descend into sub-folders deterministically.
                dirnames.sort()
                found_paths.extend(
                    os.path.join(dirpath, name)
                    for name in sorted(filenames)
                    if name.lower() == target
                )

        return found_paths

# Datasets

In [11]:
class TrainDataset(SegmentationDataset):
    """Training split resolved from ``IMAGE_DIRS`` (images) and ``LABEL_DIRS`` (masks).

    ``get_data`` returns an ``ImageData`` holding ``RGBNImage`` / ``OneHotMask``
    handles. Pixels are read only when their ``.load()`` is called.
    """

    # Number of mask files that must be found for each image id.
    EXPECTED_MASKS_COUNT = 9

    # Shared by all instances; created once when the class body executes.
    _image_searcher = FileSearcher(IMAGE_DIRS)
    _mask_searcher = FileSearcher(LABEL_DIRS)

    def __init__(self, image_ids: list[str]) -> None:
        super().__init__(image_ids=image_ids)

    def get_data(self, image_id: str) -> ImageData:
        """Resolve *image_id* to its image file and mask files.

        Raises:
            ValueError: if the image file is not found exactly once, or the
                number of mask files differs from ``EXPECTED_MASKS_COUNT``.
        """
        return ImageData(
            image_id=image_id,
            image=self._get_image(image_id),
            mask=self._get_mask(image_id),
        )

    def _get_image(self, image_id: str) -> Image:
        """Return an RGBNImage for the single image file named after *image_id*."""
        file_name = self._get_file_name(image_id)
        matches = self._image_searcher.search(file_name)
        # search() returns a list of paths, but RGBNImage takes exactly one path.
        if len(matches) != 1:
            raise ValueError(
                f"Failed to locate image {file_name!r}: expected 1 match, found {len(matches)}"
            )
        return RGBNImage(matches[0])

    def _get_mask(self, image_id: str) -> Mask:
        """Return a OneHotMask over all mask files named after *image_id*."""
        file_name = self._get_file_name(image_id)
        masks = self._mask_searcher.search(file_name)
        self._validate_masks(masks)
        # Mask channel order follows the order returned by search().
        return OneHotMask(*masks)

    @classmethod
    def _validate_masks(cls, masks: list[str]) -> None:
        """Raise ValueError unless exactly ``EXPECTED_MASKS_COUNT`` mask paths were found."""
        found = len(masks)
        expected = cls.EXPECTED_MASKS_COUNT
        if found != expected:
            # ValueError subclasses Exception, so existing `except Exception` handlers still work.
            raise ValueError(f"Failed to parse masks: {expected=}, {found=}")

    @staticmethod
    def _get_file_name(image_id: str) -> str:
        """Map an image id to its on-disk file name (``<id>.png``)."""
        return f"{image_id}.png"


def _test() -> None:
    """Smoke test: build the training dataset, print its size and its first sample."""
    dataset = TrainDataset(image_ids=IMAGE_IDS)
    sample = dataset.first()

    print(f"Loaded {len(dataset)} instances")

    if sample:
        print(sample)


_test()

Loaded 100 instances
ImageData(image_id='11IE4DKTR_11556-9586-12068-10098', image=RGBNImage, mask=OneHotMask)


# Build the model with segmentation_models_pytorch

In [12]:
from segmentation_models_pytorch.metrics import iou_score, get_stats
import segmentation_models_pytorch as smp

In [13]:
# U-Net with a ResNet-50 encoder, randomly initialised (no pretrained weights),
# and scSE attention in the decoder.
model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights=None,
    decoder_attention_type="scse",
    in_channels=4,  # R, G, B, NIR
    classes=9,      # output channels; equals TrainDataset.EXPECTED_MASKS_COUNT (9)
)

In [14]:
# Loss terms. DiceLoss in 'multiclass' mode presumably expects integer
# class-index targets (see smp docs).
ce_loss = torch.nn.CrossEntropyLoss()
dice_loss = smp.losses.DiceLoss(mode='multiclass')


def combined_loss(pred, target):
    """Return cross-entropy plus multiclass Dice loss for the same logits and targets."""
    ce = ce_loss(pred, target)
    dice = dice_loss(pred, target)
    return ce + dice

# Create DataLoader

In [15]:
from torch.utils.data import DataLoader

train_dataset = TrainDataset(image_ids=IMAGE_IDS)

# TODO: parse the real val/test splits; until then both alias the training set.
val_dataset = test_dataset = train_dataset

# NOTE(review): TrainDataset.__getitem__ returns an ImageData dataclass, which
# DataLoader's default collate_fn likely cannot batch. The training loop unpacks
# `imgs, masks` tensors, so a collate_fn that calls .load() is probably needed (verify).
# NOTE(review): num_workers > 0 with classes defined in a notebook may fail on
# Windows (spawn start method); confirm or set num_workers=0 there.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=9)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=4, num_workers=9)

# Train Model

In [16]:
def calculate_metrics(outputs, masks, num_classes=9):
    """Return micro-averaged ``(dice, iou)`` for one batch as Python floats.

    The predicted class is the argmax of ``outputs`` over dim 1. ``masks``
    are passed to smp's ``get_stats`` in 'multiclass' mode, which presumably
    means class-index targets.
    """
    predicted = outputs.argmax(dim=1)
    tp, fp, fn, tn = get_stats(predicted, masks, mode='multiclass', num_classes=num_classes)
    iou = iou_score(tp, fp, fn, tn, reduction='micro')

    # Micro Dice from the pooled confusion counts.
    tp_sum, fp_sum, fn_sum = tp.sum(), fp.sum(), fn.sum()
    dice = 2 * tp_sum / (2 * tp_sum + fp_sum + fn_sum)
    return dice.item(), iou.item()

In [17]:
def train_one_epoch(model, loader, optimizer, scaler=None):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss.

    With a ``scaler`` (GradScaler), the forward pass runs under ``autocast``
    and backward/step go through the scaler. Without one, a plain
    full-precision step is taken. Uses the notebook globals ``device``,
    ``device_type`` and ``combined_loss``.
    """
    model.train()
    running_loss = 0

    for imgs, masks in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()

        if not scaler:
            outputs = model(imgs)
            loss = combined_loss(outputs, masks)
            loss.backward()
            optimizer.step()
        else:
            with autocast(device_type):
                outputs = model(imgs)
                loss = combined_loss(outputs, masks)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Weight by batch size so a short final batch is averaged correctly.
        running_loss += loss.item() * imgs.size(0)

    return running_loss / len(loader.dataset)


def validate_one_epoch(model, loader, num_classes=9):
    """Evaluate on ``loader`` without gradients; return mean ``(loss, dice, iou)``.

    Each batch's scores are weighted by its size. NOTE(review): averaging
    per-batch micro scores approximates, but does not equal, dataset-level
    micro Dice/IoU.
    """
    model.eval()
    loss_sum = dice_sum = iou_sum = 0

    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            batch_size = imgs.size(0)

            loss_sum += combined_loss(outputs, masks).item() * batch_size

            batch_dice, batch_iou = calculate_metrics(outputs, masks, num_classes=num_classes)
            dice_sum += batch_dice * batch_size
            iou_sum += batch_iou * batch_size

    n_samples = len(loader.dataset)
    return (loss_sum / n_samples,
            dice_sum / n_samples,
            iou_sum / n_samples)

In [18]:
from torch.amp import GradScaler, autocast
from torch import nn
import torch

# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device_type = 'cuda'
else:
    device_type = 'cpu'
device = torch.device(device_type)

model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Mixed-precision loss scaling only on CUDA; None makes train_one_epoch take the plain path.
scaler = None
if device_type == 'cuda':
    scaler = GradScaler(device_type)

print(f"Using device: {device_type}")

Using device: cpu


In [19]:
epochs = 0  # number of passes; 0 skips training entirely

for epoch in range(epochs):
    epoch_no = epoch + 1  # 1-based, for logging and checkpoint names
    train_loss = train_one_epoch(model, train_loader, optimizer, scaler)
    val_loss, val_dice, val_iou = validate_one_epoch(model, val_loader, num_classes=9)

    print(
        f"Epoch [{epoch_no}/{epochs}] "
        f"| Train Loss: {train_loss:.4f} "
        f"| Val Loss: {val_loss:.4f} "
        f"| Dice: {val_dice:.4f} "
        f"| IoU: {val_iou:.4f}"
    )

    # Save a checkpoint after every epoch.
    torch.save(model.state_dict(), f"unet_epoch_{epoch_no}.pth")