In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from pathlib import Path
from typing import List

from tqdm.notebook import tqdm


# Model

In [None]:

# -----------------------------------------------
#               Basic blocks
# -----------------------------------------------
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch norm and ReLU.

    This is the basic building block of the RSU blocks in this file.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dilation: Dilation rate of the 3x3 convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, dilation: int = 1):
        super().__init__()
        # For a 3x3 kernel, padding equal to the dilation keeps H and W unchanged.
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            padding=dilation,
            dilation=dilation,
            bias=False,  # the following batch norm supplies the shift
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to `x`."""
        features = self.conv(x)
        normalised = self.bn(features)
        return self.relu(normalised)

# -----------------------------------------------
#          Residual U-blocks (RSU)
# -----------------------------------------------

def _upsample_like(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Resize `src` bilinearly so its spatial dims match those of `tgt`.

    Only the dimensions after the first two (batch, channel) are taken from
    `tgt`; the batch and channel sizes of `src` are left untouched.
    """
    target_size = tgt.shape[2:]
    return F.interpolate(
        src,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )


class _RSU_Base(nn.Module):
    """Residual U-block of configurable depth.

    Layout, for ``height`` = h:

    * ``rebn_in``  - input conv, ``in_ch -> out_ch``; its output is also the
      residual added at the very end.
    * ``enc``      - h-1 convs (first ``out_ch -> mid_ch``, rest
      ``mid_ch -> mid_ch``), each followed by a 2x ceil-mode max-pool.
    * ``btm``      - dilated (rate 2) conv at the coarsest resolution.
    * ``dec``      - h-1 convs; each one receives the upsampled running
      feature concatenated with the matching encoder feature.
    * ``rebn_out`` - fuses the decoder result with the input-conv output.
    """

    def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()
        self.height = height
        n_levels = height - 1

        # initial conv
        self.rebn_in = REBNCONV(in_ch, out_ch)

        # encoder: only the first level sees `out_ch` channels on its input
        self.enc = nn.ModuleList(
            [REBNCONV(mid_ch if level else out_ch, mid_ch) for level in range(n_levels)]
        )
        self.pool = nn.ModuleList(
            [nn.MaxPool2d(2, stride=2, ceil_mode=True) for _ in range(n_levels)]
        )

        # bottom conv (dilated)
        self.btm = REBNCONV(mid_ch, mid_ch, dilation=2)

        # decoder: input is [upsampled feature, skip] -> 2 * mid_ch channels
        self.dec = nn.ModuleList(
            [REBNCONV(mid_ch * 2, mid_ch) for _ in range(n_levels)]
        )

        # final conv
        self.rebn_out = REBNCONV(mid_ch + out_ch, out_ch)

    def forward(self, x):
        """Run the U-shaped encoder/decoder and add the residual."""
        residual = self.rebn_in(x)

        # Encoder: keep the pre-pool activation of every level as a skip.
        skips = []
        feat = residual
        for conv, down in zip(self.enc, self.pool):
            feat = conv(feat)
            skips.append(feat)
            feat = down(feat)

        feat = self.btm(feat)

        # Decoder: walk both the decoder convs and the skips from last to first.
        for conv, skip in zip(reversed(self.dec), reversed(skips)):
            upsampled = _upsample_like(feat, skip)
            feat = conv(torch.cat([upsampled, skip], dim=1))

        fused = self.rebn_out(torch.cat([feat, residual], dim=1))
        return fused + residual  # residual connection


class RSU7(_RSU_Base):
    """`_RSU_Base` fixed to ``height=7``."""

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__(7, in_ch, mid_ch, out_ch)


class RSU6(_RSU_Base):
    """`_RSU_Base` fixed to ``height=6``."""

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__(6, in_ch, mid_ch, out_ch)


class RSU5(_RSU_Base):
    """`_RSU_Base` fixed to ``height=5``."""

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__(5, in_ch, mid_ch, out_ch)


class RSU4(_RSU_Base):
    """`_RSU_Base` fixed to ``height=4``."""

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__(4, in_ch, mid_ch, out_ch)


class RSU4F(nn.Module):
    """Residual U-block without any pooling or upsampling.

    Every layer works at the input resolution; the convolutions use dilation
    rates 1, 2 and 4 on the way "down" and 2 and 1 on the way "up".
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.rebn_in = REBNCONV(in_ch, out_ch)

        # "encoder" path: growing dilation
        self.enc1 = REBNCONV(out_ch, mid_ch)
        self.enc2 = REBNCONV(mid_ch, mid_ch, dilation=2)
        self.enc3 = REBNCONV(mid_ch, mid_ch, dilation=4)

        # "decoder" path: each conv takes the previous result plus a skip
        self.dec1 = REBNCONV(mid_ch * 2, mid_ch, dilation=2)
        self.dec2 = REBNCONV(mid_ch * 2, mid_ch, dilation=1)

        self.rebn_out = REBNCONV(mid_ch + out_ch, out_ch)

    def forward(self, x):
        """Run the dilated encoder/decoder and add the residual."""
        residual = self.rebn_in(x)

        e1 = self.enc1(residual)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)

        up = self.dec1(torch.cat([e3, e2], dim=1))
        up = self.dec2(torch.cat([up, e1], dim=1))

        fused = self.rebn_out(torch.cat([up, residual], dim=1))
        return fused + residual


# -----------------------------------------------
#                U²-Net Model
# -----------------------------------------------
class U2Net_Hierarchical(nn.Module):
    """
    U²-Net for hierarchical semantic segmentation.

    Args:
        num_classes (int): Number of output masks (channels) to predict.
        base_ch (int): Channel width used by every stage.
    Input:
        RGB image tensor of shape (B, 3, H, W)
    Output:
        A pair ``(fused, sides)``:
        * ``fused`` - tensor of shape (B, num_classes, H, W), produced by a
          1x1 conv over the concatenated side outputs;
        * ``sides`` - tuple of six tensors (d1, ..., d6), one per decoder
          stage (plus the deepest encoder stage), each resized to the input
          resolution with ``num_classes`` channels.
    """

    def __init__(self, num_classes: int = 9, base_ch: int = 64):
        super().__init__()
        wide = base_ch * 2  # decoder stages see [upsampled deeper feature, skip]

        # ---- encoder: RSU stage followed by a 2x ceil-mode max-pool ----
        self.stage1 = RSU7(3, base_ch, base_ch)          # RGB -> base_ch
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(base_ch, base_ch, base_ch)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(base_ch, base_ch, base_ch)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(base_ch, base_ch, base_ch)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(base_ch, base_ch, base_ch)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(base_ch, base_ch, base_ch)

        # ---- decoder ----
        self.stage5d = RSU4F(wide, base_ch, base_ch)
        self.stage4d = RSU4(wide, base_ch, base_ch)
        self.stage3d = RSU5(wide, base_ch, base_ch)
        self.stage2d = RSU6(wide, base_ch, base_ch)
        self.stage1d = RSU7(wide, base_ch, base_ch)

        # ---- side heads: one 3x3 conv per scale, registered as side1..side6 ----
        for stage_no in range(1, 7):
            setattr(
                self,
                f"side{stage_no}",
                nn.Conv2d(base_ch, num_classes, kernel_size=3, padding=1),
            )

        # 1x1 conv fusing the six side outputs into the final prediction
        self.out_conv = nn.Conv2d(num_classes * 6, num_classes, kernel_size=1)

    def forward(self, x):
        """Return ``(fused, (d1, d2, d3, d4, d5, d6))`` for input batch `x`."""

        def merge(deep, skip):
            # Bring the deeper feature to the skip's resolution, then stack them.
            return torch.cat([_upsample_like(deep, skip), skip], dim=1)

        # Encoder
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # Decoder
        hx5d = self.stage5d(merge(hx6, hx5))
        hx4d = self.stage4d(merge(hx5d, hx4))
        hx3d = self.stage3d(merge(hx4d, hx3))
        hx2d = self.stage2d(merge(hx3d, hx2))
        hx1d = self.stage1d(merge(hx2d, hx1))

        # Side outputs, each resized to the input resolution
        heads = (self.side1, self.side2, self.side3,
                 self.side4, self.side5, self.side6)
        features = (hx1d, hx2d, hx3d, hx4d, hx5d, hx6)
        sides = tuple(
            _upsample_like(head(feat), x) for head, feat in zip(heads, features)
        )

        # Fusion
        fused = self.out_conv(torch.cat(sides, dim=1))
        return fused, sides

In [3]:

# Smoke test: push a random batch through a freshly initialised model and
# check that the fused output has one channel per class at the input size.
model = U2Net_Hierarchical(num_classes=9)
x = torch.randn(2, 3, 256, 256)
with torch.no_grad():  # shape check only; no autograd graph needed
    y, sides = model(x)
print(y.shape)  # Expected: (2, 9, 256, 256)

torch.Size([2, 9, 256, 256])


# Dataset


In [None]:
class SegmentationDataset(Dataset):
    """Example dataset that yields (image, mask) tuples.

    * **image**  – (3, H, W) float32 tensor scaled to [0, 1].
    * **mask**   – (H, W) long tensor with class indices.

    The dataset must already encode background as a unique class index
    (e.g. 0). See `LEVELS` below for the mapping used here.
    """

    def __init__(self, root: str | Path, split: str = "train") -> None:
        self.root = Path(root)
        self.split = split
        # TODO: implement your image/mask listing logic here
        self.items: List[Path] = sorted((self.root / split).glob("*.png"))

        self.image_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int):
        img_path = self.items[idx]
        # Replace the following two lines with proper I/O according to your
        # dataset format (e.g. using PIL.Image or cv2).
        from PIL import Image
        image = self.image_tf(Image.open(img_path).convert("RGB"))
        mask = torch.as_tensor(Image.open(img_path.with_suffix(".label.png")), dtype=torch.long)
        return image, mask

###############################################################################
#                         Hierarchy definition                                #
###############################################################################

# Ground-truth label IDs, grouped by hierarchy level.
# NOTE: these are label IDs (1..9), not tensor indices. A 9-channel prediction
# has channel indices 0..8, so an ID must be translated before it is used to
# index the channel dimension.
# NOTE(review): a (H, W) mask stores one ID per pixel. A pixel labelled with a
# fine part (4..9) therefore does not match the level-0 / level-1 IDs unless
# the masks are expanded to their parent classes somewhere - confirm how the
# ground truth for levels 0 and 1 is meant to be derived.
LEVELS: list[list[int]] = [
    [1],                              # Level-0 - body (ignore background 0)
    [2, 3],                           # Level-1 - upper_body, lower_body
    [4, 5, 6, 7, 8, 9],               # Level-2 - fine parts
]

# Label ID of the background class (not referenced by the code visible here).
BACKGROUND_ID = 0
IGNORE_INDEX = 255  # value used in CrossEntropyLoss to ignore pixels

def _prepare_level_target(gt: torch.Tensor, cls_ids: list[int]) -> torch.Tensor:
    """Build a per-level target from a label map.

    Pixels whose label is ``cls_ids[k]`` are set to ``k``; every other pixel
    is set to ``IGNORE_INDEX``. The result has the same shape and dtype as
    `gt`, and `gt` itself is not modified.
    """
    target = torch.full_like(gt, IGNORE_INDEX)
    for level_idx, label_id in enumerate(cls_ids):
        target.masked_fill_(gt == label_id, level_idx)
    return target



# Losses and metrics


In [None]:


def hierarchical_losses(pred: torch.Tensor, gt: torch.Tensor) -> list[torch.Tensor]:
    """Compute one cross-entropy loss per hierarchy level.

    For every entry of `LEVELS` the logits of that level's classes are
    selected and compared, via softmax cross-entropy, against the ground
    truth remapped to ``0..K-1``. Pixels whose label is not part of the level
    (background included) are ignored.

    Args:
        pred: (B, 9, H, W) logits from U²-Net (no background channel), so the
            class with label ID ``c`` is predicted by channel ``c - 1``.
        gt:   (B, H, W) ground-truth with class IDs incl. background.
    Returns:
        List with one scalar loss per level, in `LEVELS` order
        ([loss0, loss1, loss2] for the three levels defined in this file).
        A level with no labelled pixel in the batch contributes a zero loss
        that is still attached to the autograd graph.

    NOTE(review): a level with a single class (level 0 here) is scored with a
    one-channel softmax, whose cross-entropy is identically zero. If level 0
    is supposed to learn body-vs-background it needs a different loss (or a
    background channel) - confirm the intended design.
    """
    losses = []
    for cls_ids in LEVELS:
        tgt = _prepare_level_target(gt, cls_ids)            # (B, H, W)
        # Label IDs start at 1 while channels start at 0. Indexing with the
        # raw IDs would pick the wrong channels and overflow on ID 9.
        channels = [cid - 1 for cid in cls_ids]
        logits = pred[:, channels, :, :]                    # (B, K, H, W)
        if (tgt != IGNORE_INDEX).any():
            loss = F.cross_entropy(logits, tgt, ignore_index=IGNORE_INDEX)
        else:
            # Mean-reduced CE over zero valid pixels is NaN; use a zero that
            # keeps the graph so `backward()` still works on the summed loss.
            loss = logits.sum() * 0.0
        losses.append(loss)
    return losses


def compute_mIoU(pred: torch.Tensor, gt: torch.Tensor, cls_ids: list[int]) -> float:
    """Mean IoU over *cls_ids*, ignoring background.

    The prediction for a level is the arg-max over that level's channels
    only, translated back to label IDs (label ID ``c`` is channel ``c - 1``).
    Pixels whose ground-truth label is not in *cls_ids* - background and the
    classes of other levels - are left out of both intersection and union,
    mirroring the pixels the training loss ignores.

    Args:
        pred: (B, 9, H, W) logits.
        gt:   (B, H, W) labels.
        cls_ids: Label IDs of the level to score.
    Returns:
        Mean of the per-class IoUs. Classes with an empty union are skipped;
        0.0 is returned when no class can be scored.
    """
    if not cls_ids:
        return 0.0
    with torch.no_grad():
        channels = [cid - 1 for cid in cls_ids]
        # Position of the winning class inside this level, in [0, K).
        level_pos = pred[:, channels, :, :].argmax(dim=1)          # (B, H, W)
        id_lut = torch.tensor(cls_ids, device=pred.device, dtype=torch.long)
        preds = id_lut[level_pos]                                  # label IDs

        true_masks = [gt == cid for cid in cls_ids]
        # Pixels that belong to this level at all.
        valid = torch.zeros_like(true_masks[0])
        for true_mask in true_masks:
            valid |= true_mask

        ious = []
        for cid, true_mask in zip(cls_ids, true_masks):
            pred_mask = (preds == cid) & valid
            intersection = (pred_mask & true_mask).sum().item()
            union = (pred_mask | true_mask).sum().item()
            if union == 0:
                continue  # class absent in both -> skip from mean
            ious.append(intersection / union)
        return float(sum(ious) / max(len(ious), 1))



# Training

In [None]:
def train_one_epoch(model: nn.Module, loader: DataLoader, optim: torch.optim.Optimizer, device: torch.device):
    """Train `model` for one pass over `loader`.

    Each batch is moved to `device`, the fused model output is scored with
    `hierarchical_losses`, and the three level losses are summed with equal
    weight before the optimiser step. The side outputs are not used.

    Returns:
        The summed loss averaged per sample: batch losses weighted by batch
        size, divided by ``len(loader.dataset)``.
    """
    model.train()
    running_total = 0.0
    progress = tqdm(loader, desc="train", leave=False)
    for images, targets in progress:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        fused_logits, _ = model(images)
        loss_lvl0, loss_lvl1, loss_lvl2 = hierarchical_losses(fused_logits, targets)
        loss = loss_lvl0 + loss_lvl1 + loss_lvl2

        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()

        batch_size = images.size(0)
        running_total += loss.item() * batch_size

    n_samples = len(loader.dataset)
    return running_total / n_samples


def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> dict[str, float]:
    """Evaluate `model` on `loader` and report one mIoU per hierarchy level.

    For every batch, `compute_mIoU` is called once per entry of `LEVELS` on
    the fused model output. The batch values are weighted by batch size and
    divided by ``len(loader.dataset)``, so the result is a per-sample average
    of batch-level mIoUs, not an mIoU accumulated over the whole dataset.

    Returns:
        Dict mapping ``"mIoU^<level>"`` to that level's averaged mIoU, with
        one entry per level in `LEVELS`.
    """
    model.eval()
    # Sized from LEVELS so that adding or removing a level needs no change here.
    totals = [0.0] * len(LEVELS)
    with torch.no_grad():
        for imgs, gts in tqdm(loader, desc="eval", leave=False):
            imgs, gts = imgs.to(device, non_blocking=True), gts.to(device, non_blocking=True)
            logits, _ = model(imgs)
            batch_size = imgs.size(0)
            for lvl, cls_ids in enumerate(LEVELS):
                totals[lvl] += compute_mIoU(logits, gts, cls_ids) * batch_size
    n_samples = len(loader.dataset)
    return {f"mIoU^{lvl}": total / n_samples for lvl, total in enumerate(totals)}


# -----------------------------------------------
#   Command-line training script
# -----------------------------------------------
# NOTE(review): `parse_args()` reads `sys.argv`, so this section is meant to be
# run as a script; inside a notebook kernel it will presumably reject the
# kernel's own arguments - confirm how this cell is supposed to be launched.
import argparse  # imported here because only this section needs it

# The parser was used without ever being created (NameError on first use).
parser = argparse.ArgumentParser(description="Train U2Net_Hierarchical")
parser.add_argument("--data", type=str, required=True, help="Path to dataset root")
parser.add_argument("--epochs", type=int, default=80)
parser.add_argument("--batch", type=int, default=4)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--save", type=str, default="checkpoints")
args = parser.parse_args()

# Fall back to the CPU only when a CUDA device was requested but is missing;
# any other explicitly requested device is honoured as given.
if args.device.startswith("cuda") and not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = torch.device(args.device)

# Model
model = U2Net_Hierarchical(num_classes=9).to(device)

# Data
train_ds = SegmentationDataset(args.data, split="train")
val_ds = SegmentationDataset(args.data, split="val")

# Pinned host memory only helps host-to-GPU copies.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=pin_memory)
val_loader = DataLoader(val_ds, batch_size=args.batch, shuffle=False, num_workers=4, pin_memory=pin_memory)

# Optimiser: AdamW with a cosine learning-rate schedule spanning all epochs
optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

save_dir = Path(args.save)
save_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(1, args.epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optim, device)
    metrics = evaluate(model, val_loader, device)
    lr_sched.step()  # stepped once per epoch, after the optimiser updates

    # Logging
    print(f"Epoch {epoch:03d} | loss={train_loss:.4f} | " + ", ".join(f"{k}={v:.3f}" for k, v in metrics.items()))

    # Save checkpoint (one file per epoch; the scheduler state is not stored)
    torch.save({
        "model": model.state_dict(),
        "optim": optim.state_dict(),
        "epoch": epoch,
    }, save_dir / f"checkpoint_{epoch:03d}.pth")

