In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from pathlib import Path
from typing import List

from tqdm.notebook import tqdm
from PIL import Image
from skimage import io
import numpy as np
from matplotlib import pyplot as plt
import copy

from hierarchical_semantic_segmentation.model import U2Net_Hierarchical

# Dataset


In [None]:
class SegmentationDataset(Dataset):
    """Dataset that yields ``(image, mask)`` tuples for one split.

    * **image** – (3, H, W) float32 tensor, resized to ``resize_shape`` and
      normalised with the ImageNet mean/std (so values are *not* in [0, 1]).
    * **mask**  – (H, W) long tensor with class indices, resized with
      nearest-neighbour sampling so that no new label values are invented.

    Only the ids listed in ``<root>/<split>_id.txt`` for which BOTH a
    ``*.jpg`` image and a ``*.npy`` mask exist are kept, so every index in
    ``range(len(dataset))`` can be loaded.

    Args:
        root_path: dataset root directory.
        images_folder: sub-folder of ``root_path`` holding the ``*.jpg`` images.
        masks_folder: sub-folder of ``root_path`` holding the ``*.npy`` masks.
        split: ``'train'`` or ``'val'``.
        num_classes: stored on the instance as ``self.num_classes``; it is not
            used by the loading code itself.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    # split name -> file (inside the dataset root) listing the sample ids
    _SPLIT_FILES = {'train': 'train_id.txt', 'val': 'val_id.txt'}

    def __init__(self,
                root_path = 'hierarchical_semantic_segmentation/Pascal-part',
                images_folder = 'JPEGImages', 
                masks_folder='gt_masks', 
                split='train',
                num_classes=9):
        if split not in self._SPLIT_FILES:
            raise ValueError(
                f"error in split: expected one of {sorted(self._SPLIT_FILES)}, got {split!r}"
            )

        self.num_classes = num_classes
        self.resize_shape = (256, 256)
        self.root = Path(root_path)

        # Index images and masks by the part of the file name before the first dot.
        self.images_paths = list(self.root.glob(f"{images_folder}/*.jpg"))
        self.images_paths_dict = {path.name.split('.')[0]: path for path in self.images_paths}

        self.masks_paths = list(self.root.glob(f"{masks_folder}/*.npy"))
        self.masks_paths_dict = {path.name.split('.')[0]: path for path in self.masks_paths}

        self.image_transform = transforms.Compose([
            # Photos are resized with torchvision's default (bilinear) filter;
            # nearest-neighbour is reserved for the label mask.
            transforms.Resize(self.resize_shape),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        listed_ids = self._read_split_ids(self.root / self._SPLIT_FILES[split])
        # Keep only ids that have both an image and a mask; dict membership is O(1).
        self.img_names = [
            name for name in listed_ids
            if name in self.images_paths_dict and name in self.masks_paths_dict
        ]

    @staticmethod
    def _read_split_ids(ids_file):
        """Return the non-empty, whitespace-stripped lines of *ids_file*."""
        with open(ids_file) as file:
            # strip() instead of line[:-1]: the last line may lack a newline.
            return [line.strip() for line in file if line.strip()]

    @staticmethod
    def _resize_nearest(mask, out_shape):
        """Nearest-neighbour resize of a 2-D label array to ``out_shape`` (H, W).

        Each output pixel copies the input pixel under its centre, so only
        label values already present in *mask* can appear in the result.
        """
        in_h, in_w = mask.shape[:2]
        out_h, out_w = out_shape
        rows = np.minimum(((np.arange(out_h) + 0.5) * in_h / out_h).astype(np.int64), in_h - 1)
        cols = np.minimum(((np.arange(out_w) + 0.5) * in_w / out_w).astype(np.int64), in_w - 1)
        return mask[rows[:, None], cols[None, :]]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        name = self.img_names[idx]

        # Context manager so the image file handle is closed right after decoding.
        with Image.open(self.images_paths_dict[name]) as img:
            image = self.image_transform(img.convert("RGB"))

        mask = np.load(self.masks_paths_dict[name])
        # np.resize would repeat/truncate the flattened data instead of
        # resampling spatially, so resample explicitly.
        # assumes the stored mask is a 2-D (H, W) integer array — TODO confirm
        mask = self._resize_nearest(mask, self.resize_shape)
        mask = torch.as_tensor(mask, dtype=torch.long)

        return image, mask

In [None]:
# Scratch check: iterating a dict yields its keys, so this displays [1, 2].
list({1: 0, 2: 1})

In [None]:

# Path to the dataset root.
root_path = 'hierarchical_semantic_segmentation/Pascal-part'
# Directory where checkpoints are saved.
save = "checkpoints"

epochs = 100
batch = 4
lr = 1e-4
# Fall back to CPU so the cell also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Data
train_ds = SegmentationDataset(root_path, split="train")
val_ds = SegmentationDataset(root_path, split="val")

# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4,
                          pin_memory=(device == 'cuda'))

# Sanity-check a single batch. list(train_loader) would iterate the whole
# training set, keep every batch in memory and print all of it as cell output.
sample_imgs, sample_masks = next(iter(train_loader))
sample_imgs.shape, sample_masks.shape

In [None]:
root_path = 'hierarchical_semantic_segmentation/Pascal-part'
dataset = SegmentationDataset(root_path=root_path, split='train')
print(len(dataset))

# Preview the mask (second element of the sample tuple) of the first sample.
fig, ax = plt.subplots()
ax.imshow(dataset[0][1])
ax.set_title("Mask of sample 0");


In [None]:
# Display any matplotlib figure that is still open.
# NOTE(review): the only figure in this notebook is drawn in the previous cell; with an
# inline backend it is presumably rendered there already, making this call a no-op — confirm.
plt.show()

# Losses and metrics


In [None]:
def _prepare_level_target(gt: torch.Tensor, cls_ids: list[int]) -> torch.Tensor:
    """Return **remapped** GT where *cls_ids* become {0, …, K‑1}; others→IGNORE."""
    remapped = torch.full_like(gt, 255)
    for new_idx, cid in enumerate(cls_ids):
        remapped[gt == cid] = new_idx
    return remapped


def hierarchical_losses(pred: torch.Tensor, gt: torch.Tensor,
                        class_id_offset: int = 1) -> list[torch.Tensor]:
    """Compute CE loss at each hierarchy level.

    The levels are the class-id groups ``[1]``, ``[2, 3]`` and ``[4 .. 9]``.
    For every level the ground truth is remapped to ``{0 .. K-1}`` (pixels of
    any other class are ignored) and compared with that level's K logit
    channels.

    Args:
        pred: (B, 9, H, W) logits from U²‑Net (no background channel).
        gt:   (B, H, W) ground‑truth with class IDs incl. background.
        class_id_offset: class id ``c`` is read from channel
            ``c - class_id_offset``. The default of 1 matches the documented
            layout without a background channel (class 1 -> channel 0); pass 0
            if channel 0 of ``pred`` is a background channel.
    Returns:
        List with three scalar losses [loss0, loss1, loss2].
    Raises:
        IndexError: if a level needs a channel that ``pred`` does not have.
    """
    ignore_index = 255  # must equal the fill value used by _prepare_level_target
    levels = [[1], [2, 3], [4, 5, 6, 7, 8, 9]]
    num_channels = pred.size(1)

    losses = []
    for cls_ids in levels:
        # Using the class ids directly as channel indices would address
        # channel 9 of a 9-channel tensor, so translate ids to channels first.
        channels = [cid - class_id_offset for cid in cls_ids]
        if min(channels) < 0 or max(channels) >= num_channels:
            raise IndexError(
                f"classes {cls_ids} map to channels {channels}, "
                f"but pred only has {num_channels} channels"
            )

        tgt = _prepare_level_target(gt, cls_ids)            # (B, H, W)
        logits = pred[:, channels, :, :]                    # (B, K, H, W)

        if bool((tgt != ignore_index).any()):
            losses.append(F.cross_entropy(logits, tgt, ignore_index=ignore_index))
        else:
            # No pixel of this level in the batch: mean-reduced CE would be
            # 0/0 = NaN and poison the summed loss. Contribute an exact zero
            # that is still attached to the graph so backward() keeps working.
            losses.append(logits.sum() * 0.0)
        # NOTE(review): for the single-class level ([1]) softmax over one
        # channel is always 1, so this CE term is identically zero and gives
        # no gradient — confirm whether a binary (sigmoid) loss was intended.
    return losses


def compute_mIoU(pred: torch.Tensor, gt: torch.Tensor, cls_ids: list[int],
                 class_id_offset: int = 1) -> float:
    """Mean IoU for given *cls_ids*, ignoring background.

    The predicted class of a pixel is the argmax over ALL channels of
    ``pred``, converted from a channel index to a class id.

    Args:
        pred: (B, 9, H, W) logits.
        gt:   (B, H, W) labels.
        cls_ids: class ids to average over. Classes that occur neither in the
            prediction nor in the ground truth are left out of the mean.
        class_id_offset: channel ``k`` predicts class ``k + class_id_offset``.
            The default of 1 matches logits without a background channel
            (channel 0 -> class 1); pass 0 if channel 0 is background.
    Returns:
        Mean IoU as a Python float; 0.0 if none of *cls_ids* is present.
    """
    with torch.no_grad():
        # argmax yields channel indices 0..C-1; shift them into class-id space
        # before comparing with gt, otherwise every class is off by one and
        # the highest class id can never be predicted.
        pred_ids = pred.argmax(dim=1) + class_id_offset  # (B, H, W)
        ious = []
        for cid in cls_ids:
            pred_mask = pred_ids == cid
            true_mask = gt == cid
            union = (pred_mask | true_mask).sum().item()
            if union == 0:
                continue  # class absent in both → skip from mean
            intersection = (pred_mask & true_mask).sum().item()
            ious.append(intersection / union)
        return float(sum(ious) / max(len(ious), 1))

# Training

In [None]:
# Smoke test: build the network, push one random batch through it and print
# the shape of the main output.
model = U2Net_Hierarchical(num_classes=9)
x = torch.randn(2, 3, 512, 512)  # presumably (batch, channels, height, width) — verify against the model

# Gradients are not needed for a shape check.
with torch.no_grad():
    outputs = model(x)
y, sides = outputs

print(y.shape)

In [None]:
def train_one_epoch(model, loader, optim, device):
    """Run one optimisation pass over *loader*.

    For every batch the three hierarchy-level losses are summed and one
    optimiser step is taken on the result.

    Args:
        model: network called as ``model(imgs)``; the first element of its
            return value is used as the logits.
        loader: iterable of ``(images, ground_truth)`` batches that exposes
            ``loader.dataset``.
        optim: optimiser stepping the model parameters.
        device: device the batches are moved to.
    Returns:
        Sum of ``batch_loss * batch_size`` divided by ``len(loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0.0

    progress = tqdm(loader, desc="train", leave=False)
    for imgs, gts in progress:
        imgs = imgs.to(device, non_blocking=True)
        gts = gts.to(device, non_blocking=True)

        logits, _ = model(imgs)
        level0, level1, level2 = hierarchical_losses(logits, gts)
        loss = level0 + level1 + level2

        # Clear stale gradients, back-propagate, update the weights.
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        weighted_loss_sum += loss.item() * imgs.size(0)

    num_samples = len(loader.dataset)
    return weighted_loss_sum / num_samples


def evaluate(model: nn.Module, loader: DataLoader, device: torch.device):
    """Evaluate *model* on *loader* and return one mIoU value per hierarchy level.

    The mIoU is computed per batch for each level, weighted by the batch size,
    and the weighted sum is divided by ``len(loader.dataset)``.

    Returns:
        Dict with the keys ``"mIoU^0"``, ``"mIoU^1"`` and ``"mIoU^2"``.
    """
    levels = [[1], [2, 3], [4, 5, 6, 7, 8, 9]]
    weighted_sums = [0.0] * len(levels)

    model.eval()
    with torch.no_grad():
        for imgs, gts in tqdm(loader, desc="eval", leave=False):
            imgs = imgs.to(device, non_blocking=True)
            gts = gts.to(device, non_blocking=True)
            logits, _ = model(imgs)

            batch_size = imgs.size(0)
            for lvl, cls_ids in enumerate(levels):
                weighted_sums[lvl] += compute_mIoU(logits, gts, cls_ids) * batch_size

    return {
        f"mIoU^{lvl}": level_sum / len(loader.dataset)
        for lvl, level_sum in enumerate(weighted_sums)
    }


# Path to the dataset root.
root_path = 'hierarchical_semantic_segmentation/Pascal-part'
# Directory where checkpoints are saved.
save = "checkpoints"

epochs = 100
batch = 4
lr = 1e-4
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model
model = U2Net_Hierarchical(num_classes=9).to(device)

# Data
train_ds = SegmentationDataset(root_path, split="train")
val_ds = SegmentationDataset(root_path, split="val")

# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4,
                          pin_memory=(device == 'cuda'))
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4,
                        pin_memory=(device == 'cuda'))

# Optimiser: AdamW with a cosine learning-rate schedule spanning all epochs.
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs)

save_dir = Path(save)
save_dir.mkdir(parents=True, exist_ok=True)

for epoch in tqdm(range(1, epochs + 1)):
    train_loss = train_one_epoch(model, train_loader, optim, device)
    metrics = evaluate(model, val_loader, device)
    lr_sched.step()

    # Logging
    print(f"Epoch {epoch:03d} | loss={train_loss:.4f} | " + ", ".join(f"{k}={v:.3f}" for k, v in metrics.items()))

    # Save checkpoint. The scheduler state is stored as well so that a resumed
    # run continues the cosine schedule instead of restarting it.
    torch.save({
        "model": model.state_dict(),
        "optim": optim.state_dict(),
        "lr_sched": lr_sched.state_dict(),
        "epoch": epoch,
    }, save_dir / f"checkpoint_{epoch:03d}.pth")



In [None]:

# NOTE(review): this cell repeats the configuration and data setup that already
# appears earlier in the notebook; consider keeping a single copy.

# Path to the dataset root.
root_path = 'hierarchical_semantic_segmentation/Pascal-part'
# Directory where checkpoints are saved.
save = "checkpoints"

epochs = 100
batch = 4
lr = 1e-4
# Fall back to CPU so the cell also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Data
train_ds = SegmentationDataset(root_path, split="train")
val_ds = SegmentationDataset(root_path, split="val")

# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4,
                          pin_memory=(device == 'cuda'))

# Sanity-check a single batch. list(train_loader) would iterate the whole
# training set, keep every batch in memory and print all of it as cell output.
sample_imgs, sample_masks = next(iter(train_loader))
sample_imgs.shape, sample_masks.shape