# Road Line Segmentation Training

This notebook mirrors `simple_unet_training.py` and splits the code into logical sections for interactive execution.

In [25]:
import argparse
import copy
import math
import random
import sys
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as TF

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}



In [26]:
def set_seed(seed: int) -> None:
    """Seed every RNG used by the pipeline with the same value.

    Covers Python's ``random`` module, NumPy's global generator, the PyTorch
    CPU generator and the generators of all visible CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def resolve_device(device_str: str) -> torch.device:
    """Turn a user-supplied device string into a ``torch.device``.

    ``"auto"`` (case-insensitive, surrounding whitespace ignored) selects CUDA
    when available and falls back to CPU. Any other string is handed to
    ``torch.device``.

    Raises:
        ValueError: if the string is not a valid device, or if a CUDA device
            is requested on a machine without CUDA.
    """
    requested = device_str.strip()
    if requested.lower() == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        chosen = torch.device(requested)
    except (RuntimeError, ValueError) as exc:
        raise ValueError(f"Could not interpret device '{requested}': {exc}") from exc
    cuda_unavailable = chosen.type == "cuda" and not torch.cuda.is_available()
    if cuda_unavailable:
        raise ValueError("CUDA requested via --device but no GPU is available.")
    return chosen



In [27]:
def find_image_mask_pairs(images_dir: Path, masks_dir: Path) -> List[Tuple[Path, Path]]:
    if not images_dir.exists() or not masks_dir.exists():
        raise FileNotFoundError(f"Missing directory: {images_dir} or {masks_dir}")
    image_paths = [
        p for p in images_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
    ]
    mask_paths = [
        p for p in masks_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
    ]
    mask_lookup: Dict[str, Path] = {}
    for mask_path in mask_paths:
        stem = mask_path.stem
        candidate_keys = {stem}
        for suffix in ("_mask", "-mask", "_label", "-label"):
            if stem.endswith(suffix):
                candidate_keys.add(stem[: -len(suffix)])
        for key in candidate_keys:
            mask_lookup.setdefault(key, mask_path)
    pairs: List[Tuple[Path, Path]] = []
    missing_images: List[str] = []
    for image_path in sorted(image_paths):
        base = image_path.stem
        candidates = [
            base,
            f"{base}_mask",
            f"{base}-mask",
            f"{base}_label",
            f"{base}-label",
        ]
        mask_path = None
        for candidate in candidates:
            mask_path = mask_lookup.get(candidate)
            if mask_path is not None:
                break
        if mask_path is None:
            missing_images.append(image_path.name)
            continue
        pairs.append((image_path, mask_path))
    if not pairs:
        raise RuntimeError(
            f"No image/mask pairs were found in {images_dir} and {masks_dir}"
        )
    if missing_images:
        preview = ", ".join(missing_images[:5])
        print(
            f"Warning: {len(missing_images)} image files did not have matching masks in {masks_dir}. "
            f"Examples: {preview}"
        )
    return pairs


def _summarize_positive_ratios(
    ratios: Sequence[float],
) -> Tuple[Dict[str, float], Dict[str, list]]:
    """Return summary statistics and a 10-bin histogram of per-image foreground ratios."""
    observed = np.asarray(ratios, dtype=np.float64)
    # mean/percentile need at least one value; an empty dataset reports zeros.
    summary_input = observed if observed.size else np.zeros(1, dtype=np.float64)
    summary = {
        "mean": float(summary_input.mean()),
        "std": float(summary_input.std()),
        "min": float(summary_input.min()),
        "max": float(summary_input.max()),
        "p10": float(np.percentile(summary_input, 10)),
        "p50": float(np.percentile(summary_input, 50)),
        "p90": float(np.percentile(summary_input, 90)),
    }
    # The histogram only counts real observations, so an empty dataset yields
    # all-zero counts instead of a phantom sample in the first bin.
    counts, edges = np.histogram(observed, bins=np.linspace(0.0, 1.0, 11))
    centers = ((edges[:-1] + edges[1:]) / 2).tolist()
    return summary, {"bins": centers, "counts": counts.tolist()}


def compute_dataset_statistics(
    pairs: Sequence[Tuple[Path, Path]],
    image_size: Tuple[int, int],
) -> Dict[str, object]:
    """Compute normalisation and class-balance statistics for a dataset.

    Each image is converted to RGB and resized bilinearly to ``image_size``.
    Each mask is converted to grayscale and resized with nearest-neighbour
    interpolation; mask pixels > 0 count as foreground.

    Args:
        pairs: ``(image_path, mask_path)`` tuples.
        image_size: Target size passed to ``TF.resize``.

    Returns:
        A dict with ``num_samples``, per-channel ``image_mean`` / ``image_std``
        (on the ``[0, 1]`` scale from ``TF.to_tensor``), ``class_counts`` and
        ``class_frequencies`` (background, foreground), ``image_size``, the
        per-image ``positive_pixel_ratios``, their summary, and a 10-bin
        histogram. The std is clamped from below to ``1e-3``
        (variance >= 1e-6).
    """
    class_counts = np.zeros(2, dtype=np.int64)
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    total_pixels = 0
    positive_pixel_ratios: List[float] = []
    for image_path, mask_path in pairs:
        # Context managers close the file handles promptly.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")
        image = TF.resize(image, image_size, interpolation=InterpolationMode.BILINEAR)
        mask = TF.resize(mask, image_size, interpolation=InterpolationMode.NEAREST)
        # Sum in float64 so E[x^2] - E[x]^2 does not lose precision.
        flat = TF.to_tensor(image).reshape(3, -1).double()
        channel_sum += flat.sum(dim=1).numpy()
        channel_sq_sum += flat.square().sum(dim=1).numpy()
        total_pixels += flat.shape[1]
        binary_mask = np.asarray(mask, dtype=np.uint8) > 0
        foreground = int(binary_mask.sum())
        background = int(binary_mask.size - foreground)
        class_counts[0] += background
        class_counts[1] += foreground
        positive_pixel_ratios.append(float(foreground / max(binary_mask.size, 1)))
    image_mean = (channel_sum / max(total_pixels, 1)).tolist()
    variance = channel_sq_sum / max(total_pixels, 1) - np.square(image_mean)
    image_std = np.sqrt(np.clip(variance, a_min=1e-6, a_max=None)).tolist()
    total_pixels_seen = int(class_counts.sum())
    if total_pixels_seen == 0:
        class_frequencies = [1.0, 0.0]
    else:
        class_frequencies = (class_counts / total_pixels_seen).tolist()
    ratio_summary, ratio_histogram = _summarize_positive_ratios(positive_pixel_ratios)
    return {
        "num_samples": len(pairs),
        "image_mean": image_mean,
        "image_std": image_std,
        "class_counts": class_counts.tolist(),
        "class_frequencies": class_frequencies,
        "image_size": list(image_size),
        "positive_pixel_ratios": positive_pixel_ratios,
        "positive_pixel_ratio_summary": ratio_summary,
        "positive_pixel_ratio_histogram": ratio_histogram,
    }


def log_dataset_statistics(dataset_name: str, stats: Dict[str, object]) -> None:
    """Print a readable report of a ``compute_dataset_statistics`` result.

    The sample count, mean/std and class counts/frequencies are always
    printed. The positive-pixel-ratio summary and histogram lines are printed
    only when those entries are present and non-empty.
    """
    tag = f"[{dataset_name}]"
    basic_lines = (
        f"\n{tag} samples: {stats['num_samples']}",
        f"{tag} image mean: {stats['image_mean']}",
        f"{tag} image std:  {stats['image_std']}",
        f"{tag} class counts (background, foreground): {stats['class_counts']}",
        f"{tag} class frequencies: {stats['class_frequencies']}",
    )
    for line in basic_lines:
        print(line)

    summary = stats.get("positive_pixel_ratio_summary")
    if summary:
        fields = " ".join(
            f"{key}={summary[key]:.4f}" for key in ("mean", "std", "min", "p50", "max")
        )
        print(f"{tag} positive pixel ratio {fields}")

    histogram = stats.get("positive_pixel_ratio_histogram")
    if histogram:
        pairs_text = ", ".join(
            f"{center:.2f}:{count}"
            for center, count in zip(histogram["bins"], histogram["counts"])
        )
        print(f"{tag} positive pixel ratio histogram (bin:count) -> {pairs_text}")



In [28]:
def apply_random_augmentations(
    image: Image.Image, mask: Image.Image
) -> Tuple[Image.Image, Image.Image]:
    """Apply the same random geometric augmentation to an image and its mask.

    - horizontal flip with probability 0.5
    - vertical flip with probability 0.1
    - rotation by an angle drawn uniformly from [-10, 10] degrees with
      probability 0.3. The image uses bilinear interpolation and a black fill;
      the mask uses nearest-neighbour interpolation (so no new label values
      are blended in) and a fill of 0.

    The three coin flips are drawn from Python's ``random`` module in a fixed
    order (hflip, vflip, rotate), followed by the angle if rotating.
    """
    do_hflip = random.random() < 0.5
    do_vflip = random.random() < 0.1
    do_rotate = random.random() < 0.3
    if do_hflip:
        image, mask = TF.hflip(image), TF.hflip(mask)
    if do_vflip:
        image, mask = TF.vflip(image), TF.vflip(mask)
    if do_rotate:
        angle = random.uniform(-10.0, 10.0)
        image = TF.rotate(
            image, angle, interpolation=InterpolationMode.BILINEAR, fill=(0, 0, 0)
        )
        mask = TF.rotate(
            mask, angle, interpolation=InterpolationMode.NEAREST, fill=0
        )
    return image, mask


class RoadLineDataset(Dataset):
    """Dataset of (normalised image tensor, binary mask tensor) samples.

    Args:
        pairs: ``(image_path, mask_path)`` tuples.
        image_size: Target size passed to ``TF.resize`` for image and mask.
        augment: When True, ``apply_random_augmentations`` runs before
            resizing.
        mean: Per-channel mean used by ``TF.normalize``.
        std: Per-channel std used by ``TF.normalize``.

    Each item is ``(image_tensor, mask_tensor)``. The image is a float tensor
    from ``TF.to_tensor`` normalised with ``mean``/``std``. The mask is an int64
    tensor where pixels > 0 in the grayscale mask become 1 and the rest 0.
    """

    def __init__(
        self,
        pairs: Sequence[Tuple[Path, Path]],
        image_size: Tuple[int, int],
        augment: bool,
        mean: Sequence[float],
        std: Sequence[float],
    ) -> None:
        self.pairs = list(pairs)
        self.image_size = image_size
        self.augment = augment
        self.mean = list(mean)
        self.std = list(std)

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image_path, mask_path = self.pairs[index]
        # `convert` returns an independent image, so the source files can be
        # closed right away; this matters with many DataLoader workers and
        # multi-frame formats (e.g. TIFF) that Pillow keeps open otherwise.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")
        if self.augment:
            image, mask = apply_random_augmentations(image, mask)
        image = TF.resize(
            image, self.image_size, interpolation=InterpolationMode.BILINEAR
        )
        # Nearest-neighbour keeps mask values from being blended.
        mask = TF.resize(mask, self.image_size, interpolation=InterpolationMode.NEAREST)
        image_tensor = TF.to_tensor(image)
        image_tensor = TF.normalize(image_tensor, mean=self.mean, std=self.std)
        mask_array = np.array(mask, dtype=np.uint8)
        binary_mask = (mask_array > 0).astype(np.int64)
        mask_tensor = torch.from_numpy(binary_mask).long()
        return image_tensor, mask_tensor



In [29]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    The convolutions use padding=1 and no bias (BatchNorm follows), so the
    spatial size is preserved and only the channel count changes.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend(
                [
                    nn.Conv2d(
                        stage_in, out_channels, kernel_size=3, padding=1, bias=False
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W) followed by ``DoubleConv``."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.block = nn.Sequential(pool, convs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class Up(nn.Module):
    """Decoder step: 2x transposed-conv upsample, align to the skip, concat, ``DoubleConv``.

    If the upsampled map is smaller than the skip connection (odd input
    sizes), it is zero-padded symmetrically (extra pixel on the right/bottom)
    before concatenation along the channel axis, with the skip first.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_channels + skip_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x1)
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        if pad_h or pad_w:
            left, top = pad_w // 2, pad_h // 2
            upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """1x1 convolution that maps feature channels to per-class logits."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.conv(x)
        return logits


class UNet(nn.Module):
    """Four-level U-Net producing ``num_classes`` logit maps.

    Encoder widths are c, 2c, 4c, 8c and 16c with c = ``base_channels``. The
    decoder mirrors them with skip connections, and a 1x1 head maps the last
    decoder features to class logits at the resolution of the first encoder
    stage, which is the input resolution.
    """

    def __init__(
        self, in_channels: int = 3, num_classes: int = 2, base_channels: int = 32
    ) -> None:
        super().__init__()
        c1, c2, c4, c8, c16 = (base_channels * m for m in (1, 2, 4, 8, 16))
        # Submodules are created in the same order as before so parameter
        # initialisation under a fixed seed and state-dict keys are unchanged.
        self.inc = DoubleConv(in_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.up1 = Up(c16, c8, c8)
        self.up2 = Up(c8, c4, c4)
        self.up3 = Up(c4, c2, c2)
        self.up4 = Up(c2, c1, c1)
        self.outc = OutConv(c1, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        out = skips.pop()  # bottleneck features
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())
        return self.outc(out)



In [30]:
class TverskyLoss(nn.Module):
    """Binary Tversky loss computed on the foreground (class 1) probability.

    Per sample, the Tversky index is
    ``(TP + smooth) / (TP + alpha * FP + beta * FN + smooth)``, where TP/FP/FN
    are soft counts summed over the two spatial dimensions. The loss is
    ``1 - mean(index)`` over the batch. With the defaults (alpha=0.7,
    beta=0.3), false positives are penalised more than false negatives.

    Args:
        alpha: Weight on false positives.
        beta: Weight on false negatives.
        smooth: Additive smoothing in numerator and denominator.
        from_logits: If True, inputs are raw scores: softmax over channels for
            multi-channel input, sigmoid for single-channel input. If False,
            inputs are already probabilities and are used as-is (channel 1 for
            multi-channel input).
    """

    def __init__(
        self,
        alpha: float = 0.7,
        beta: float = 0.3,
        smooth: float = 1.0,
        from_logits: bool = True,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
        self.from_logits = from_logits

    def _foreground_probability(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the foreground probability map (batch first, two spatial dims)."""
        if logits.dim() == 4 and logits.size(1) > 1:
            if self.from_logits:
                return torch.softmax(logits, dim=1)[:, 1, ...]
            # Already per-class probabilities: do not softmax them again.
            return logits[:, 1, ...]
        single = logits[:, 0, ...] if logits.dim() == 4 else logits
        return torch.sigmoid(single) if self.from_logits else single

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Clamp keeps probabilities strictly inside (0, 1).
        probs = self._foreground_probability(logits).clamp(min=1e-7, max=1 - 1e-7)
        targets = targets.float()
        tp = (probs * targets).sum(dim=(1, 2))
        fp = (probs * (1.0 - targets)).sum(dim=(1, 2))
        fn = ((1.0 - probs) * targets).sum(dim=(1, 2))
        score = (tp + self.smooth) / (
            tp + self.alpha * fp + self.beta * fn + self.smooth
        )
        return 1.0 - score.mean()

class SegmentationCriterion(nn.Module):
    """Weighted cross-entropy plus Tversky loss for segmentation logits.

    Cross-entropy class weights start as inverse frequency,
    ``total / (num_classes * count)`` with counts clamped to >= 1. When there
    are exactly two classes they are multiplied by ``[w_bg, w_fg]``, and they
    are finally rescaled to mean 1.

    The total is ``lambda_ce * CE + lambda_tv * Tversky``. CE is skipped and
    reported as 0 when ``lambda_ce <= 0``; the Tversky term is always
    evaluated. With ``return_components=True`` the detached ``"ce"`` and
    ``"tversky"`` terms are also returned.
    """

    def __init__(
        self,
        class_counts: Sequence[float],
        smooth: float = 1.0,
        w_bg: float = 1.5,
        w_fg: float = 1.0,
        alpha: float = 0.7,
        beta: float = 0.3,
        lambda_ce: float = 1.0,
        lambda_tv: float = 1.0,
    ) -> None:
        super().__init__()
        counts = torch.tensor(class_counts, dtype=torch.float32).clamp_min(1.0)
        inverse_frequency = counts.sum() / (len(counts) * counts)
        overrides = torch.tensor([w_bg, w_fg], dtype=inverse_frequency.dtype)
        if overrides.numel() == inverse_frequency.numel():
            weights = inverse_frequency * overrides
        else:
            # Overrides only make sense for the binary case; keep plain inverse frequency.
            weights = inverse_frequency
        self.register_buffer("class_weights", weights / weights.mean())
        self.smooth = smooth
        self.lambda_ce = lambda_ce
        self.lambda_tv = lambda_tv
        self.secondary_loss = TverskyLoss(
            alpha=alpha, beta=beta, smooth=smooth, from_logits=True
        )

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        return_components: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        if self.lambda_ce > 0:
            ce_term = F.cross_entropy(logits, targets, weight=self.class_weights)
        else:
            ce_term = logits.new_tensor(0.0)
        tversky_term = self.secondary_loss(logits, targets)
        combined = self.lambda_ce * ce_term + self.lambda_tv * tversky_term
        if not return_components:
            return combined
        return combined, {
            "ce": ce_term.detach(),
            "tversky": tversky_term.detach(),
        }

# === Appendix / Optional: alternative losses (reference only, not executed) ===
# These classes are stored in a plain string, so they are never defined or run
# by this notebook. They are kept here as reference implementations.
_OPTIONAL_LOSSES_DOC = '''
class CrossEntropyDiceLoss(nn.Module):
    def __init__(self, class_counts: Sequence[float], smooth: float = 1.0) -> None:
        super().__init__()
        counts = torch.tensor(class_counts, dtype=torch.float32).clamp_min(1.0)
        total = counts.sum()
        class_weights = total / (len(counts) * counts)
        self.register_buffer("class_weights", class_weights)
        self.smooth = smooth

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        return_components: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        ce = F.cross_entropy(logits, targets, weight=self.class_weights)
        probabilities = torch.softmax(logits, dim=1)
        one_hot = (
            F.one_hot(targets, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
        )
        intersection = (probabilities * one_hot).sum(dim=(0, 2, 3))
        total = probabilities.sum(dim=(0, 2, 3)) + one_hot.sum(dim=(0, 2, 3))
        dice_loss_per_class = 1.0 - (
            (2.0 * intersection + self.smooth) / (total + self.smooth)
        )
        dice_loss = dice_loss_per_class.mean()
        total_loss = ce + dice_loss
        if return_components:
            return total_loss, {"ce": ce.detach(), "dice": dice_loss.detach()}
        return total_loss

class FocalTverskyLoss(nn.Module):
    def __init__(
        self,
        alpha: float = 0.7,
        beta: float = 0.3,
        gamma: float = 1.33,
        smooth: float = 1.0,
        from_logits: bool = True,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth
        self.from_logits = from_logits
        self.base_loss = TverskyLoss(
            alpha=alpha, beta=beta, smooth=smooth, from_logits=from_logits
        )

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        base_value = self.base_loss(logits, targets)
        return torch.pow(base_value.clamp(min=1e-8), self.gamma)
'''


In [31]:
# def calculate_iou(
#     logits: torch.Tensor, targets: torch.Tensor, num_classes: int = 2
# ) -> torch.Tensor:
#     predictions = torch.argmax(logits, dim=1)
#     if targets.ndim == 4:
#         ground_truth = torch.argmax(targets, dim=1)
#     else:
#         ground_truth = targets
#     ious: List[torch.Tensor] = []
#     for cls in range(num_classes):
#         pred_mask = predictions == cls
#         gt_mask = ground_truth == cls
#         intersection = (pred_mask & gt_mask).sum().float()
#         union = pred_mask.sum() + gt_mask.sum() - intersection
#         if union == 0:
#             continue
#         ious.append((intersection + 1e-6) / (union + 1e-6))
#     if not ious:
#         return torch.tensor(0.0, device=logits.device)
#     return torch.stack(ious).mean()
#
#
# def calculate_pixel_accuracy(
#     logits: torch.Tensor, targets: torch.Tensor
# ) -> torch.Tensor:
#     predictions = torch.argmax(logits, dim=1)
#     if targets.ndim == 4:
#         ground_truth = torch.argmax(targets, dim=1)
#     else:
#         ground_truth = targets
#     correct = (predictions == ground_truth).sum().float()
#     total = ground_truth.numel()
#     return correct / max(total, 1)
#
#
def _extract_boundary(mask: torch.Tensor) -> torch.Tensor:
    mask = mask.float().unsqueeze(1)
    inv_mask = 1.0 - mask
    dilated_inv = F.max_pool2d(inv_mask, kernel_size=3, stride=1, padding=1)
    eroded = 1.0 - dilated_inv
    boundary = (mask - eroded).clamp(min=0.0)
    return boundary.squeeze(1)


def _dilate_mask(mask: torch.Tensor, radius: int) -> torch.Tensor:
    if radius <= 0:
        return mask
    kernel_size = radius * 2 + 1
    return F.max_pool2d(
        mask.unsqueeze(1), kernel_size=kernel_size, stride=1, padding=radius
    ).squeeze(1)


def compute_boundary_counts(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    tolerance_ratio: float = 0.02,
) -> Tuple[float, float, float, float]:
    """Count boundary pixels and tolerance-based boundary matches.

    The boundaries of both masks are extracted. A predicted boundary pixel is
    matched if it lies within ``radius`` pixels (square window) of a
    ground-truth boundary pixel, and vice versa. ``radius`` is
    ``tolerance_ratio`` times the image diagonal, rounded, and at least 1.

    Returns:
        ``(matched_pred, matched_gt, pred_boundary_pixels, gt_boundary_pixels)``
        summed over the whole batch. All zeros for empty input.
    """
    if predictions.numel() == 0:
        return 0.0, 0.0, 0.0, 0.0
    pred_edges = _extract_boundary(predictions.detach().to(dtype=torch.float32))
    true_edges = _extract_boundary(targets.detach().to(dtype=torch.float32))
    height, width = predictions.shape[-2:]
    diagonal = math.sqrt(height**2 + width**2)
    radius = max(1, int(round(tolerance_ratio * diagonal)))
    true_zone = _dilate_mask(true_edges, radius)
    pred_zone = _dilate_mask(pred_edges, radius)
    return (
        (pred_edges * true_zone).sum().item(),
        (true_edges * pred_zone).sum().item(),
        pred_edges.sum().item(),
        true_edges.sum().item(),
    )



In [32]:
class MetricsAccumulator:
    def __init__(
        self, collect_probabilities: bool = False, boundary_tolerance: float = 0.02
    ) -> None:
        self.collect_probabilities = collect_probabilities
        self.boundary_tolerance = boundary_tolerance
        self.reset()

    def reset(self) -> None:
        self.tp = 0.0
        self.fp = 0.0
        self.fn = 0.0
        self.tn = 0.0
        self.matches_pred = 0.0
        self.matches_gt = 0.0
        self.pred_boundary_pixels = 0.0
        self.gt_boundary_pixels = 0.0
        self.pixel_count = 0.0
        if self.collect_probabilities:
            self.probabilities: List[torch.Tensor] = []
            self.labels: List[torch.Tensor] = []

    def update(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        threshold: float | None = None,
    ) -> None:
        with torch.no_grad():
            probability_fg: torch.Tensor | None = None
            if logits.dim() == 4 and logits.size(1) > 1:
                probs = torch.softmax(logits, dim=1)
                if logits.size(1) == 2 and threshold is not None:
                    probability_fg = probs[:, 1, ...]
                    preds = (probability_fg >= threshold).long()
                else:
                    preds = torch.argmax(probs, dim=1)
                    if probs.size(1) > 1:
                        probability_fg = probs[:, 1, ...]
            else:
                logits_flat = logits if logits.dim() == 3 else logits.squeeze(1)
                probability_fg = torch.sigmoid(logits_flat)
                if threshold is None:
                    preds = (probability_fg >= 0.5).long()
                else:
                    preds = (probability_fg >= threshold).long()
            fg_pred = preds == 1
            fg_true = targets == 1
            bg_true = targets == 0
            tp = (fg_pred & fg_true).sum().item()
            fp = (fg_pred & ~fg_true).sum().item()
            fn = (~fg_pred & fg_true).sum().item()
            tn = ((preds == 0) & bg_true).sum().item()
            self.tp += tp
            self.fp += fp
            self.fn += fn
            self.tn += tn
            self.pixel_count += preds.numel()
            matches_pred, matches_gt, pred_boundary_pixels, gt_boundary_pixels = (
                compute_boundary_counts(
                    fg_pred.float(),
                    fg_true.float(),
                    tolerance_ratio=self.boundary_tolerance,
                )
            )
            self.matches_pred += matches_pred
            self.matches_gt += matches_gt
            self.pred_boundary_pixels += pred_boundary_pixels
            self.gt_boundary_pixels += gt_boundary_pixels
            if self.collect_probabilities:
                if probability_fg is not None:
                    self.probabilities.append(probability_fg.detach().cpu().reshape(-1))
                    self.labels.append(fg_true.detach().cpu().reshape(-1))

    def _safe_div(self, numerator: float, denominator: float) -> float:
        if denominator <= 0.0:
            return 0.0
        return numerator / denominator

    def compute(self) -> Dict[str, float | Dict[str, float]]:
        precision = self._safe_div(self.tp, self.tp + self.fp)
        recall = self._safe_div(self.tp, self.tp + self.fn)
        f1 = self._safe_div(2 * precision * recall, precision + recall)
        iou = self._safe_div(self.tp, self.tp + self.fp + self.fn)
        dice = self._safe_div(2 * self.tp, 2 * self.tp + self.fp + self.fn)
        pixel_accuracy = self._safe_div(self.tp + self.tn, self.pixel_count)
        boundary_precision = self._safe_div(
            self.matches_pred, self.pred_boundary_pixels
        )
        boundary_recall = self._safe_div(self.matches_gt, self.gt_boundary_pixels)
        boundary_f1 = self._safe_div(
            2 * boundary_precision * boundary_recall,
            boundary_precision + boundary_recall,
        )
        confusion_matrix = {
            "tp": int(self.tp),
            "fp": int(self.fp),
            "fn": int(self.fn),
            "tn": int(self.tn),
        }
        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "foreground_iou": iou,
            "foreground_dice": dice,
            "pixel_accuracy": pixel_accuracy,
            "boundary_precision": boundary_precision,
            "boundary_recall": boundary_recall,
            "boundary_f1": boundary_f1,
            "confusion_matrix": confusion_matrix,
        }

    def pr_curve(
        self, num_thresholds: int = 11
    ) -> List[Tuple[float, float, float, float]]:
        if not self.collect_probabilities or not self.probabilities:
            return []
        scores = torch.cat(self.probabilities)
        labels = torch.cat(self.labels).float()
        thresholds = torch.linspace(0.0, 1.0, steps=num_thresholds)
        curve: List[Tuple[float, float, float, float]] = []
        for threshold in thresholds:
            preds = (scores >= threshold).to(labels.dtype)
            tp = (preds * labels).sum().item()
            fp = (preds * (1 - labels)).sum().item()
            fn = ((1 - preds) * labels).sum().item()
            precision = 0.0 if (tp + fp) == 0 else tp / (tp + fp)
            recall = 0.0 if (tp + fn) == 0 else tp / (tp + fn)
            f1 = (
                0.0
                if (precision + recall) == 0
                else (2 * precision * recall) / (precision + recall)
            )
            curve.append((float(threshold), precision, recall, f1))
        return curve


class EarlyStopping:
    """Signal a stop after ``patience`` consecutive non-improving steps.

    Higher values are better. A value improves only if it is greater than the
    best seen so far plus ``min_delta``. The first value always counts as an
    improvement.
    """

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_score: float | None = None
        self.counter = 0

    def step(self, value: float) -> bool:
        """Record ``value``; return True once the patience budget is used up."""
        improved = self.best_score is None or value > self.best_score + self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience
        self.best_score, self.counter = value, 0
        return False



In [33]:
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    threshold: float | None,
    optimizer: torch.optim.Optimizer | None = None,
    collect_probabilities: bool = False,
) -> Tuple[Dict[str, float], "MetricsAccumulator"]:
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns the per-sample averaged losses (keys ``"loss"``, ``"ce"`` and
    ``"tversky"``) and the filled ``MetricsAccumulator``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is eval mode
    metrics = MetricsAccumulator(collect_probabilities=collect_probabilities)
    loss_sum = ce_sum = tversky_sum = 0.0
    samples = 0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            if optimizer is not None:
                optimizer.zero_grad()
            logits = model(images)
            loss, components = criterion(logits, masks, return_components=True)
            if optimizer is not None:
                loss.backward()
                optimizer.step()
            batch = images.size(0)
            loss_sum += loss.item() * batch
            ce_sum += components["ce"].item() * batch
            tversky_sum += components["tversky"].item() * batch
            samples += batch
            metrics.update(logits.detach(), masks, threshold=threshold)
    denom = max(samples, 1)
    losses = {
        "loss": loss_sum / denom,
        "ce": ce_sum / denom,
        "tversky": tversky_sum / denom,
    }
    return losses, metrics


def _print_epoch_report(
    epoch: int,
    train_losses: Dict[str, float],
    train_summary: Dict[str, object],
    val_losses: Dict[str, float],
    val_summary: Dict[str, object],
    val_pr_curve: List[Tuple[float, float, float, float]],
    threshold: float,
    current_lr: float,
) -> None:
    """Print the epoch summary line, both confusion matrices and every 5th PR-curve point."""
    print(
        f"Epoch {epoch:03d} | "
        f"train_loss: {train_losses['loss']:.4f} (ce {train_losses['ce']:.4f}, tversky {train_losses['tversky']:.4f}) | "
        f"train_fg_iou: {train_summary['foreground_iou']:.4f} | "
        f"train_fg_dice: {train_summary['foreground_dice']:.4f} | "
        f"train_prec: {train_summary['precision']:.4f} | "
        f"train_rec: {train_summary['recall']:.4f} | "
        f"val_loss: {val_losses['loss']:.4f} (ce {val_losses['ce']:.4f}, tversky {val_losses['tversky']:.4f}) | "
        f"val_fg_iou: {val_summary['foreground_iou']:.4f}@thr={threshold:.2f} | "
        f"val_fg_dice: {val_summary['foreground_dice']:.4f} | "
        f"val_prec: {val_summary['precision']:.4f} | "
        f"val_rec: {val_summary['recall']:.4f} | "
        f"val_boundary_f1: {val_summary['boundary_f1']:.4f} | "
        f"lr: {current_lr:.2e}"
    )
    train_conf = train_summary["confusion_matrix"]
    val_conf = val_summary["confusion_matrix"]
    print(
        f"    Train confusion (px): TP={train_conf['tp']} FP={train_conf['fp']} "
        f"FN={train_conf['fn']} TN={train_conf['tn']}"
    )
    print(
        f"    Val   confusion (px): TP={val_conf['tp']} FP={val_conf['fp']} "
        f"FN={val_conf['fn']} TN={val_conf['tn']} | "
        f"Boundary P/R={val_summary['boundary_precision']:.4f}/{val_summary['boundary_recall']:.4f}"
    )
    coarse_curve = val_pr_curve[::5]
    if coarse_curve:
        formatted_curve = ", ".join(
            f"{thr:.2f}:{prec:.2f}/{rec:.2f}/{f1:.2f}"
            for thr, prec, rec, f1 in coarse_curve
        )
        print(f"    Val PR curve (thr:prec/rec/f1) -> {formatted_curve}")


def train_and_validate(args: argparse.Namespace) -> None:
    """Train the U-Net on ``<data_root>/train`` and validate on ``<data_root>/valid``.

    Both splits are normalised with statistics computed on the training split.
    After each epoch, metrics are printed and the learning rate is reduced on
    a validation-IoU plateau. Whenever validation foreground IoU improves by
    more than 1e-4, a checkpoint is written to
    ``<output_dir>/best_unet.pth``. Training stops early after
    ``args.early_stopping_patience`` epochs without improvement. At the end,
    the best weights are loaded back into the model.

    NOTE(review): ``args.seed`` is not applied here; presumably the caller
    invokes ``set_seed`` first — confirm.
    """
    device = resolve_device(args.device)
    image_size = tuple(args.image_size)
    data_root = Path(args.data_root)
    train_pairs = find_image_mask_pairs(
        data_root / "train" / "images",
        data_root / "train" / "masks",
    )
    val_pairs = find_image_mask_pairs(
        data_root / "valid" / "images",
        data_root / "valid" / "masks",
    )
    train_stats = compute_dataset_statistics(train_pairs, image_size)
    val_stats = compute_dataset_statistics(val_pairs, image_size)
    log_dataset_statistics("train", train_stats)
    log_dataset_statistics("valid", val_stats)
    # Validation reuses TRAIN mean/std so no validation statistics leak in.
    train_dataset = RoadLineDataset(
        pairs=train_pairs,
        image_size=image_size,
        augment=True,
        mean=train_stats["image_mean"],
        std=train_stats["image_std"],
    )
    val_dataset = RoadLineDataset(
        pairs=val_pairs,
        image_size=image_size,
        augment=False,
        mean=train_stats["image_mean"],
        std=train_stats["image_std"],
    )
    # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    model = UNet(in_channels=3, num_classes=2, base_channels=args.base_channels).to(
        device
    )
    criterion = SegmentationCriterion(
        class_counts=train_stats["class_counts"],
        w_bg=args.w_bg,
        w_fg=args.w_fg,
        alpha=args.tv_alpha,
        beta=args.tv_beta,
        lambda_ce=args.lambda_ce,
        lambda_tv=args.lambda_tv,
    ).to(device)
    optimizer = Adam(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
    )
    # `verbose` is deprecated in recent PyTorch; the epoch log prints the lr anyway.
    scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)
    early_stopper = EarlyStopping(patience=args.early_stopping_patience, min_delta=1e-4)
    best_iou = 0.0
    best_checkpoint: Dict[str, object] | None = None
    for epoch in range(1, args.epochs + 1):
        train_losses, train_metrics = _run_epoch(
            model,
            train_loader,
            criterion,
            device,
            args.val_threshold,
            optimizer=optimizer,
        )
        train_summary = train_metrics.compute()
        val_losses, val_metrics = _run_epoch(
            model,
            val_loader,
            criterion,
            device,
            args.val_threshold,
            collect_probabilities=True,
        )
        val_summary = val_metrics.compute()
        val_iou = float(val_summary["foreground_iou"])
        val_accuracy = float(val_summary["pixel_accuracy"])
        # 21 thresholds; the report shows every 5th (0.00, 0.25, ..., 1.00).
        val_pr_curve = val_metrics.pr_curve(num_thresholds=21)
        current_lr = optimizer.param_groups[0]["lr"]  # read before the scheduler may lower it
        scheduler.step(val_iou)
        _print_epoch_report(
            epoch,
            train_losses,
            train_summary,
            val_losses,
            val_summary,
            val_pr_curve,
            args.val_threshold,
            current_lr,
        )
        if val_iou > best_iou + 1e-4:
            best_iou = val_iou
            best_checkpoint = {
                "epoch": epoch,
                # state_dict() holds references to the live tensors; without a
                # copy this "best" snapshot would keep changing in later epochs.
                "model_state_dict": copy.deepcopy(model.state_dict()),
                "optimizer_state_dict": copy.deepcopy(optimizer.state_dict()),
                "val_loss": val_losses["loss"],
                "val_loss_components": {
                    "ce": val_losses["ce"],
                    "tversky": val_losses["tversky"],
                },
                "val_iou": val_iou,
                "val_accuracy": val_accuracy,
                "val_metrics": val_summary,
                "val_pr_curve": val_pr_curve,
                "train_loss": train_losses["loss"],
                "train_loss_components": {
                    "ce": train_losses["ce"],
                    "tversky": train_losses["tversky"],
                },
                "train_metrics": train_summary,
                "training_args": vars(args),
                "train_stats": train_stats,
            }
            output_dir = Path(args.output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            checkpoint_path = output_dir / "best_unet.pth"
            torch.save(best_checkpoint, checkpoint_path)
            print(f"Saved new best model to {checkpoint_path} (IoU={val_iou:.4f})")
        if early_stopper.step(val_iou):
            print(f"Early stopping triggered at epoch {epoch}")
            break
    if best_checkpoint is not None:
        model.load_state_dict(best_checkpoint["model_state_dict"])  # type: ignore[arg-type]
        print(
            f"Best validation IoU: {best_checkpoint['val_iou']:.4f} "
            f"(epoch {best_checkpoint['epoch']})"
        )
    else:
        print(
            "Training finished without improving validation IoU; no checkpoint saved."
        )





In [34]:

def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Simplified U-Net pipeline for road line segmentation."
    )
    parser.add_argument(
        "--data-root",
        "--data_root",
        "--data_dir",
        dest="data_root",
        type=str,
        default="dataset",
        help="Root directory containing train/valid/test folders.",
    )
    parser.add_argument(
        "--image-size",
        "--image_size",
        dest="image_size",
        type=int,
        nargs=2,
        default=(256, 256),
        metavar=("HEIGHT", "WIDTH"),
    )
    parser.add_argument(
        "--batch-size", "--batch_size", dest="batch_size", type=int, default=4
    )
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument(
        "--learning-rate",
        "--learning_rate",
        dest="learning_rate",
        type=float,
        default=1e-3,
    )
    parser.add_argument(
        "--weight-decay",
        "--weight_decay",
        dest="weight_decay",
        type=float,
        default=1e-5,
    )
    parser.add_argument(
        "--base-channels", "--base_channels", dest="base_channels", type=int, default=32
    )
    parser.add_argument(
        "--num-workers",
        "--num_workers",
        dest="num_workers",
        type=int,
        default=0,
        help="Set >0 if torch dataloader workers are available.",
    )
    parser.add_argument(
        "--early-stopping-patience",
        "--early_stopping_patience",
        dest="early_stopping_patience",
        type=int,
        default=10,
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        dest="output_dir",
        type=str,
        default="Model",
        help="Directory to store checkpoints.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use ('auto', 'cpu', 'cuda', 'cuda:0', ...).",
    )
    # Loss function is fixed to 'ce+tversky'; see appendix for other variants.
    parser.add_argument(
        "--w-bg",
        dest="w_bg",
        type=float,
        default=1.5,
        help="Background weight multiplier for CE component.",
    )
    parser.add_argument(
        "--w-fg",
        dest="w_fg",
        type=float,
        default=1.0,
        help="Foreground weight multiplier for CE component.",
    )
    parser.add_argument(
        "--tv-alpha",
        dest="tv_alpha",
        type=float,
        default=0.7,
        help="Alpha coefficient for Tversky-based losses (FP penalty).",
    )
    parser.add_argument(
        "--tv-beta",
        dest="tv_beta",
        type=float,
        default=0.3,
        help="Beta coefficient for Tversky-based losses (FN penalty).",
    )
    parser.add_argument(
        "--lambda-ce",
        dest="lambda_ce",
        type=float,
        default=1.0,
        help="Weight for cross-entropy component.",
    )
    parser.add_argument(
        "--lambda-tv",
        dest="lambda_tv",
        type=float,
        default=1.0,
        help="Weight for the Tversky component.",
    )
    parser.add_argument(
        "--val-threshold",
        "--val_threshold",
        dest="val_threshold",
        type=float,
        default=0.75,
        help="Probability threshold for validation metrics.",
    )

    arguments: List[str]
    if argv is None:
        arguments = list(sys.argv[1:])
    else:
        arguments = list(argv)

    cleaned_arguments: List[str] = []
    skip_next = False
    for arg in arguments:
        if skip_next:
            skip_next = False
            continue
        # Filter out notebook-injected kernel arguments so argparse works inside Jupyter.
        if arg in {"--f", "-f"}:
            skip_next = True
            continue
        if arg.startswith("--f=") or arg.startswith("-f="):
            continue
        cleaned_arguments.append(arg)

    return parser.parse_args(cleaned_arguments)



In [None]:
from pathlib import Path

def resolve_dataset_root(dataset_dir: str = "dataset") -> Path:
    """Locate the dataset directory by walking up from likely starting points.

    The search starts from the current working directory. When ``__file__``
    is defined (script execution rather than a notebook kernel), it also
    starts from the directory containing this file. For each starting point,
    ``<start>/<dataset_dir>`` is checked, then the same name under every
    ancestor directory. The first match that is a directory is returned.

    Args:
        dataset_dir: Name (or relative path) of the directory to look for.

    Returns:
        Path of the first matching directory found.

    Raises:
        FileNotFoundError: If no directory called ``dataset_dir`` exists in
            any searched location.
    """
    search_roots = [Path.cwd()]
    # ``__file__`` is not defined inside a notebook kernel, so the script's
    # own directory is only searched when it is available.
    if "__file__" in globals():
        search_roots.append(Path(__file__).resolve().parent)
    for root in search_roots:
        for base in (root, *root.parents):
            candidate = base / dataset_dir
            # is_dir() rather than exists(): a stray *file* with the same name
            # must not be mistaken for the dataset root.
            if candidate.is_dir():
                return candidate
    raise FileNotFoundError(
        f"Unable to locate '{dataset_dir}' directory from {Path.cwd()}"
    )

# Locate the dataset folder relative to the notebook's working directory.
data_root_path = resolve_dataset_root()

# Command-line equivalent for running inside the notebook; any option not
# listed here falls back to the default declared in parse_args.
# "--device auto" selects CUDA when a GPU is present and CPU otherwise.
# Hard-coding "cuda" makes resolve_device raise on CPU-only machines
# (presumably the training code passes args.device to it - confirm).
notebook_args = [
    "--data_root", str(data_root_path),
    "--epochs", "30",
    "--batch_size", "4",
    "--device", "auto",
    "--learning_rate", "1e-3",
    "--weight_decay", "1e-5",
    "--output_dir", "Model",
    "--val_threshold", "0.75",
]

args = parse_args(notebook_args)


In [None]:
# Launch training using the `args` namespace parsed in the previous cell.
# train_and_validate is defined elsewhere in this notebook (presumably the
# training/validation loop cell); it prints dataset statistics before training.
train_and_validate(args)


[train] samples: 751
[train] image mean: [0.4312999436802934, 0.44875362059089063, 0.4239335453879341]
[train] image std:  [0.23804450725234275, 0.23794269857501038, 0.24852641466394798]
[train] class counts (background, foreground): [48278140, 939396]
[train] class frequencies: [0.9809133882687666, 0.019086611731233355]
[train] positive pixel ratio mean=0.0191 std=0.0214 min=0.0000 p50=0.0131 max=0.2426
[train] positive pixel ratio histogram (bin:count) -> 0.05:743, 0.15:7, 0.25:1, 0.35:0, 0.45:0, 0.55:0, 0.65:0, 0.75:0, 0.85:0, 0.95:0

[valid] samples: 216
[valid] image mean: [0.4364930471681334, 0.45583993368954573, 0.42995701616423]
[valid] image std:  [0.2397509437434158, 0.23891895409026215, 0.25044919758865736]
[valid] class counts (background, foreground): [13868481, 287295]
[valid] class frequencies: [0.9797047509087456, 0.02029524909125434]
[valid] positive pixel ratio mean=0.0203 std=0.0202 min=0.0012 p50=0.0133 max=0.1521
[valid] positive pixel ratio histogram (bin:count) 

KeyboardInterrupt: 