# Kaggle End-to-End Pipeline

This notebook rebuilds the entire angiogram segmentation project inside the Kaggle runtime, installs all dependencies, preprocesses the dataset, trains UNet++, UNet 3+, and TransUNet, and produces evaluation plots. Attach the angiogram dataset (folder name must remain `Database_134_Angiograms`) when creating the Kaggle notebook, enable GPU + Internet, and run the cells in order.



## 0. Runtime checklist

- Kaggle accelerator: GPU (T4/P100) with Internet ON
- Data panel: add the angiogram dataset so `/kaggle/input/.../Database_134_Angiograms` exists
- No code upload needed—this notebook reconstructs the repo inside `/kaggle/working/angiogram-segmentation`



In [None]:
from pathlib import Path

# Every generated source file is written below this directory.
PROJECT_ROOT = Path("/kaggle/working") / "angiogram-segmentation"
# exist_ok keeps the cell re-runnable once the folder has been created.
PROJECT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Project root: {PROJECT_ROOT}")



In [None]:
files = {
    "requirements.txt": '''torch>=2.1.0
torchvision>=0.16.0
torchaudio>=2.1.0
numpy==1.26.4
pandas==2.2.2
scikit-learn==1.6.1
matplotlib==3.8.4
seaborn==0.13.2
albumentations==1.4.6
opencv-python==4.9.0.80
imageio==2.34.1
Pillow==10.4.0
tqdm==4.66.4
''',
    "src/__init__.py": '''"""
Core package for coronary angiogram segmentation project.
"""


''',
    "src/utils/__init__.py": '''"""
Utility helpers for environment-specific behaviour.
"""

from .env import resolve_data_dir

__all__ = ["resolve_data_dir"]

''',
    "src/utils/env.py": '''from __future__ import annotations

import os
from pathlib import Path


def resolve_data_dir(data_dir: str | Path) -> Path:
    """
    Resolve the dataset directory, with Kaggle auto-discovery fallback.

    ``data_dir`` is returned unchanged (as a ``Path``) when it exists.
    Otherwise, when the ``KAGGLE_KERNEL_RUN_TYPE`` environment variable is set
    and ``/kaggle/input`` exists, that tree is searched recursively for a
    directory whose name equals the last component of ``data_dir``; the first
    match in sorted order is returned.

    Parameters
    ----------
    data_dir:
        Requested dataset location.

    Returns
    -------
    Path
        The existing dataset path.

    Raises
    ------
    FileNotFoundError
        If neither the given path nor a Kaggle input match can be found.
    """

    path = Path(data_dir)
    if path.exists():
        return path

    kaggle_root = Path("/kaggle/input")
    if "KAGGLE_KERNEL_RUN_TYPE" in os.environ and kaggle_root.exists():
        # Keep directories only: a regular file that happens to share the
        # dataset folder's name must never be handed back as the data root.
        matches = sorted(
            candidate
            for candidate in kaggle_root.glob(f"**/{path.name}")
            if candidate.is_dir()
        )
        if matches:
            print(f"[INFO] Resolved data directory to {matches[0]} (Kaggle input).")
            return matches[0]

    raise FileNotFoundError(
        f"Could not locate data directory '{data_dir}'. "
        "Pass --data_dir with an existing path or attach the dataset in Kaggle."
    )
''',
    "src/data/__init__.py": '''"""
Data loading and preprocessing utilities.
"""

from .dataset import AngiogramSegmentationDataset, load_image_mask_pairs

__all__ = ["AngiogramSegmentationDataset", "load_image_mask_pairs"]


''',
    "src/data/dataset.py": '''"""
Dataset utilities for coronary angiogram segmentation.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import imageio.v2 as imageio
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

DEFAULT_IMAGE_SUFFIX = ".pgm"
DEFAULT_MASK_SUFFIX = "_gt.pgm"


def load_image_mask_pairs(
    data_dir: Path | str,
    image_suffix: str = DEFAULT_IMAGE_SUFFIX,
    mask_suffix: str = DEFAULT_MASK_SUFFIX,
) -> List[Tuple[Path, Path]]:
    """
    Collect every angiogram image that has a matching mask file.

    Parameters
    ----------
    data_dir:
        Directory holding both images and masks (searched non-recursively).
    image_suffix:
        Glob suffix selecting candidate images. Default: ``.pgm``.
    mask_suffix:
        Text appended to an image's stem to build its mask filename; it
        includes the file extension. Default: ``_gt.pgm``.

    Returns
    -------
    list of tuple(Path, Path)
        ``(image_path, mask_path)`` tuples ordered by image path.

    Raises
    ------
    FileNotFoundError
        If ``data_dir`` does not exist or holds no image/mask pair.
    """

    root = Path(data_dir)
    if not root.exists():
        raise FileNotFoundError(f"Data directory not found: {root}")

    # The image glob also matches mask files; stems ending in "_gt" are
    # ground-truth files and are never treated as inputs.
    candidates = (
        (image_path, root / f"{image_path.stem}{mask_suffix}")
        for image_path in sorted(root.glob(f"*{image_suffix}"))
        if not image_path.stem.endswith("_gt")
    )
    pairs: List[Tuple[Path, Path]] = [
        (image_path, mask_path)
        for image_path, mask_path in candidates
        if mask_path.exists()
    ]
    if pairs:
        return pairs

    raise FileNotFoundError(
        f"No image/mask pairs found in {root}. "
        "Make sure the dataset is extracted correctly."
    )


@dataclass(frozen=True)
class DatasetSplits:
    """
    Immutable bundle of the three dataset partitions.

    Each field is a list of ``(image_path, mask_path)`` tuples.
    """

    # Pairs used for fitting the model.
    train: List[Tuple[Path, Path]]
    # Pairs held out for validation.
    val: List[Tuple[Path, Path]]
    # Pairs held out for the final test evaluation.
    test: List[Tuple[Path, Path]]


def create_splits(
    pairs: Sequence[Tuple[Path, Path]],
    val_size: float = 0.15,
    test_size: float = 0.15,
    seed: int = 42,
) -> DatasetSplits:
    """
    Partition image/mask pairs into train, validation and test subsets.

    Both fractions are expressed relative to the whole dataset. The test
    subset is carved off first; the validation fraction is then rescaled so
    that, applied to the remainder, it still corresponds to ``val_size`` of
    the full dataset.

    Parameters
    ----------
    pairs:
        Sequence of image/mask path tuples.
    val_size:
        Fraction of the entire dataset to reserve for validation.
    test_size:
        Fraction of the entire dataset to reserve for testing.
    seed:
        Random seed shared by both shuffled splits.

    Raises
    ------
    ValueError
        If either fraction lies outside ``(0, 1)`` or their sum reaches 1.
    """

    for label, fraction in (("val_size", val_size), ("test_size", test_size)):
        if not 0.0 < fraction < 1.0:
            raise ValueError(f"{label} must be between 0 and 1")
    if val_size + test_size >= 1.0:
        raise ValueError("val_size + test_size must be less than 1.0")

    shuffle_options = {"random_state": seed, "shuffle": True}
    remainder, held_out = train_test_split(
        pairs, test_size=test_size, **shuffle_options
    )
    # Re-express val_size as a share of what is left after removing the test set.
    val_share_of_remainder = val_size / (1.0 - test_size)
    fit_pairs, val_pairs = train_test_split(
        remainder, test_size=val_share_of_remainder, **shuffle_options
    )
    return DatasetSplits(
        train=list(fit_pairs),
        val=list(val_pairs),
        test=list(held_out),
    )


class AngiogramSegmentationDataset(Dataset):
    """
    Torch ``Dataset`` for coronary angiogram segmentation.

    Each item returns a dict containing:

    ``image`` (torch.float32 tensor): Image of shape (1, H, W); min-max scaled
        to [0, 1] when ``normalize`` is true, raw intensities otherwise.
    ``mask`` (torch.float32 tensor): Binary vessel mask of shape (1, H, W)
    ``path`` (str): Path to the original image.

    Parameters
    ----------
    samples:
        Sequence of ``(image_path, mask_path)`` tuples; must not be empty.
    transform:
        Optional callable invoked as ``transform(image=..., mask=...)`` after
        augmentation; must return a mapping with ``image`` and ``mask``.
    augment:
        Optional callable with the same calling convention, applied to uint8
        copies of the image and mask.
    normalize:
        Whether to min-max scale every image to [0, 1].

    Raises
    ------
    ValueError
        If ``samples`` is empty.
    """

    def __init__(
        self,
        samples: Sequence[Tuple[Path, Path]],
        transform: Optional[Callable] = None,
        augment: Optional[Callable] = None,
        normalize: bool = True,
    ) -> None:
        self.samples = list(samples)
        if not self.samples:
            raise ValueError("Dataset received an empty list of samples.")
        self.transform = transform
        self.augment = augment
        self.normalize = normalize

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        image_path, mask_path = self.samples[index]
        image = imageio.imread(image_path).astype(np.float32)
        mask = imageio.imread(mask_path).astype(np.float32)

        if self.normalize:
            image = self._normalize_image(image)
        # Any non-zero mask pixel counts as vessel.
        mask = (mask > 0).astype(np.float32)

        # Expand dims for albumentations (expects HWC)
        image = np.expand_dims(image, axis=-1)
        mask = np.expand_dims(mask, axis=-1)

        if self.augment is not None:
            image, mask = self._apply_augmentation(image, mask)

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]

        # A transform may hand back plain (H, W) arrays; restore the channel
        # axis so the CHW transpose below cannot fail.
        image = np.transpose(self._ensure_channel_axis(image), (2, 0, 1))
        mask = np.transpose(self._ensure_channel_axis(mask), (2, 0, 1))

        return {
            "image": torch.from_numpy(image).float(),
            "mask": torch.from_numpy(mask).float(),
            "path": str(image_path),
        }

    def _apply_augmentation(
        self, image: np.ndarray, mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run ``self.augment`` on uint8 copies and restore the input scale."""
        # Normalised images live in [0, 1] and are stretched to the uint8
        # range; raw images are cast directly so they are not saturated by
        # an extra factor of 255 and keep their original scale afterwards.
        scale = 255.0 if self.normalize else 1.0
        aug_image = (image * scale).clip(0, 255).astype(np.uint8)
        aug_mask = (mask * 255.0).clip(0, 255).astype(np.uint8)
        augmented = self.augment(image=aug_image, mask=aug_mask)
        image = augmented["image"].astype(np.float32) / scale
        mask = (augmented["mask"] > 0).astype(np.float32)
        return image, mask

    @staticmethod
    def _ensure_channel_axis(array: np.ndarray) -> np.ndarray:
        """Return ``array`` with a trailing channel axis when it is 2-D."""
        if array.ndim == 2:
            return array[..., np.newaxis]
        return array

    @staticmethod
    def _normalize_image(image: np.ndarray) -> np.ndarray:
        """Min-max scale ``image`` to [0, 1]; constant images become zeros."""
        min_val = image.min()
        max_val = image.max()
        if max_val - min_val < 1e-6:
            return np.zeros_like(image, dtype=np.float32)
        return (image - min_val) / (max_val - min_val)


def default_resize_transform(size: Tuple[int, int]) -> Callable:
    """Build an Albumentations pipeline that resizes to ``(height, width)``."""
    import albumentations as A

    target_height, target_width = size[0], size[1]
    # interpolation=1 is passed through as an integer flag
    # (presumably OpenCV's INTER_LINEAR -- confirm against the cv2 constants).
    resize_step = A.Resize(
        height=target_height, width=target_width, interpolation=1
    )
    return A.Compose([resize_step])


def default_augmentation_pipeline() -> Callable:
    """
    Assemble the default Albumentations training augmentations.

    The pipeline chains flips, a shift/scale/rotate warp, brightness/contrast
    jitter, CLAHE, an elastic deformation and Gaussian noise, each applied
    with its own probability.
    """

    import albumentations as A

    geometric_warp = A.ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.1,
        rotate_limit=25,
        border_mode=0,
        p=0.75,
    )
    elastic_warp = A.ElasticTransform(
        alpha=120, sigma=120 * 0.07, alpha_affine=10, p=0.2
    )
    # NOTE(review): the dataset class in this file feeds uint8 images to the
    # pipeline, while this variance range looks tuned for [0, 1] floats, so
    # the noise is presumably negligible there -- confirm the intended scale.
    sensor_noise = A.GaussNoise(var_limit=(0.0, 0.001), p=0.2)

    steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        geometric_warp,
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(p=0.2),
        elastic_warp,
        sensor_noise,
    ]
    return A.Compose(steps)


def create_datasets(
    data_dir: Path | str,
    image_size: Tuple[int, int] = (512, 512),
    val_size: float = 0.15,
    test_size: float = 0.15,
    seed: int = 42,
    augment: bool = True,
) -> Tuple[AngiogramSegmentationDataset, AngiogramSegmentationDataset, AngiogramSegmentationDataset]:
    """
    Build the train/val/test ``Dataset`` objects for ``data_dir``.

    Image/mask pairs are discovered, split with the given fractions and seed,
    and wrapped in datasets that all share one resize transform. Only the
    training dataset receives the augmentation pipeline, and only when
    ``augment`` is true.
    """

    splits = create_splits(
        load_image_mask_pairs(data_dir),
        val_size=val_size,
        test_size=test_size,
        seed=seed,
    )
    resize = default_resize_transform(image_size)
    if augment:
        train_augmentation = default_augmentation_pipeline()
    else:
        train_augmentation = None

    def build(samples, augmentation):
        # Every partition shares the same resize step.
        return AngiogramSegmentationDataset(
            samples, transform=resize, augment=augmentation
        )

    train_dataset = build(splits.train, train_augmentation)
    val_dataset = build(splits.val, None)
    test_dataset = build(splits.test, None)
    return train_dataset, val_dataset, test_dataset


def create_dataloaders(
    datasets: Tuple[Dataset, Dataset, Dataset],
    batch_size: int = 4,
    num_workers: int = 0,
) -> Tuple[torch.utils.data.DataLoader, ...]:
    """
    Wrap the ``(train, val, test)`` datasets in ``DataLoader`` objects.

    Only the training loader shuffles; all three share ``batch_size`` and
    ``num_workers``. The loaders are returned in the same order as the input.
    """

    train_dataset, val_dataset, test_dataset = datasets
    loader_plan = (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
    return tuple(
        torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=reshuffle, num_workers=num_workers
        )
        for dataset, reshuffle in loader_plan
    )

''',
    "src/preprocessing.py": '''"""
Preprocessing and exploratory utilities for angiogram segmentation dataset.
"""

from __future__ import annotations

import random
from pathlib import Path
from typing import Iterable, List, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.data.dataset import (
    AngiogramSegmentationDataset,
    create_dataloaders,
    create_datasets,
    create_splits,
    load_image_mask_pairs,
)


def seed_everything(seed: int = 42) -> None:
    """
    Seed ``random``, ``numpy`` and ``torch`` (CPU and CUDA) with ``seed``.

    cuDNN is additionally switched to deterministic mode with benchmarking
    turned off.
    """

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def describe_dataset(pairs: Sequence[Tuple[Path, Path]]) -> dict:
    """
    Compute simple dataset statistics.

    Parameters
    ----------
    pairs:
        Sequence of ``(image_path, mask_path)`` tuples.

    Returns
    -------
    dict
        ``num_samples`` is the number of pairs, ``image_mean`` and
        ``image_std`` are averages of the per-image mean and standard
        deviation of the raw intensities, and ``mask_mean`` is the average
        per-image fraction of non-zero mask pixels. All statistics are ``0.0``
        for an empty input.
    """

    stats = {
        "num_samples": len(pairs),
        "image_mean": 0.0,
        "image_std": 0.0,
        "mask_mean": 0.0,
    }
    if not pairs:
        return stats

    # Imported only once there is something to read, so the empty case has no
    # dependency on the image reader.
    import imageio.v2 as imageio

    means: List[float] = []
    stds: List[float] = []
    foreground_fractions: List[float] = []
    for image_path, mask_path in pairs:
        image = imageio.imread(image_path).astype(np.float32)
        mask = imageio.imread(mask_path)
        means.append(float(image.mean()))
        stds.append(float(image.std()))
        # Binarise the same way the dataset class does (any non-zero pixel is
        # foreground) instead of assuming masks are stored as 0/255.
        foreground_fractions.append(float((mask > 0).mean()))

    stats["image_mean"] = float(np.mean(means))
    stats["image_std"] = float(np.mean(stds))
    stats["mask_mean"] = float(np.mean(foreground_fractions))
    return stats


def visualize_samples(
    dataset: AngiogramSegmentationDataset,
    num_samples: int = 4,
    figsize: Tuple[int, int] = (12, 6),
    save_path: Path | None = None,
    show: bool = True,
) -> None:
    """
    Draw randomly chosen dataset items as image / mask-overlay rows.

    Parameters
    ----------
    dataset:
        Dataset to sample from; items must expose ``image`` and ``mask``.
    num_samples:
        Upper bound on the number of rows (capped at ``len(dataset)``).
    figsize:
        Matplotlib figure size.
    save_path:
        When given, the figure is written there (parent folders are created)
        before it is shown or closed.
    show:
        Display the figure with ``plt.show()`` when true; otherwise close it.
    """

    row_count = min(num_samples, len(dataset))
    chosen = random.sample(range(len(dataset)), k=row_count)
    # squeeze=False keeps the axes grid two-dimensional even for a single row.
    fig, grid = plt.subplots(len(chosen), 2, figsize=figsize, squeeze=False)
    for (image_ax, overlay_ax), idx in zip(grid, chosen):
        sample = dataset[idx]
        image = sample["image"].squeeze().numpy()
        mask = sample["mask"].squeeze().numpy()

        image_ax.imshow(image, cmap="gray")
        image_ax.set_title("Image")
        image_ax.axis("off")

        # Same image again, with the mask blended on top.
        overlay_ax.imshow(image, cmap="gray")
        overlay_ax.imshow(mask, cmap="jet", alpha=0.4)
        overlay_ax.set_title("Mask Overlay")
        overlay_ax.axis("off")

    plt.tight_layout()
    if save_path is not None:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=200, bbox_inches="tight")
    if not show:
        plt.close(fig)
        return
    plt.show()


__all__ = [
    "seed_everything",
    "describe_dataset",
    "visualize_samples",
    "load_image_mask_pairs",
    "create_splits",
    "create_datasets",
    "create_dataloaders",
]

''',
    "src/metrics.py": '''"""
Losses and metrics for segmentation models.
"""

from __future__ import annotations

from typing import Callable, Dict, Tuple

import torch
import torch.nn.functional as F


def dice_coefficient(
    preds: torch.Tensor, targets: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    """
    Soft Dice coefficient between logits and targets, averaged over the batch.

    ``preds`` are passed through a sigmoid (no thresholding); both tensors are
    flattened per sample (first dimension is the batch). ``epsilon`` keeps the
    ratio defined when a sample has no foreground at all.
    """
    probs = torch.sigmoid(preds)
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or permuted tensors.
    probs = probs.reshape(probs.size(0), -1)
    targets = targets.reshape(targets.size(0), -1)
    intersection = (probs * targets).sum(dim=1)
    union = probs.sum(dim=1) + targets.sum(dim=1)
    dice = (2 * intersection + epsilon) / (union + epsilon)
    return dice.mean()


def iou_score(
    preds: torch.Tensor, targets: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    """
    Soft intersection-over-union between logits and targets, batch-averaged.

    ``preds`` are passed through a sigmoid (no thresholding); both tensors are
    flattened per sample (first dimension is the batch). ``epsilon`` keeps the
    ratio defined when the union is empty.
    """
    probs = torch.sigmoid(preds)
    # reshape (unlike view) also accepts non-contiguous inputs.
    probs = probs.reshape(probs.size(0), -1)
    targets = targets.reshape(targets.size(0), -1)
    intersection = (probs * targets).sum(dim=1)
    total = probs.sum(dim=1) + targets.sum(dim=1)
    union = total - intersection
    iou = (intersection + epsilon) / (union + epsilon)
    return iou.mean()


def precision_recall(
    preds: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5, epsilon: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Batch-averaged precision and recall of thresholded predictions.

    ``preds`` are logits: they go through a sigmoid and are binarised with
    ``threshold``. Counts are accumulated per sample over every non-batch
    dimension, so any input rank with a leading batch dimension is accepted.
    ``epsilon`` is added to numerator and denominator, which yields 1.0 when
    a sample has neither predictions nor targets.

    Returns
    -------
    tuple(torch.Tensor, torch.Tensor)
        ``(precision, recall)`` scalars.
    """
    hard_preds = (torch.sigmoid(preds) > threshold).float()

    # Flatten per sample instead of hard-coding dims [1, 2, 3]; this matches
    # the other metrics and also works for (B, H, W) inputs.
    hard_preds = hard_preds.reshape(hard_preds.size(0), -1)
    targets = targets.reshape(targets.size(0), -1)

    tp = (hard_preds * targets).sum(dim=1)
    fp = (hard_preds * (1 - targets)).sum(dim=1)
    fn = ((1 - hard_preds) * targets).sum(dim=1)

    precision = (tp + epsilon) / (tp + fp + epsilon)
    recall = (tp + epsilon) / (tp + fn + epsilon)
    return precision.mean(), recall.mean()


def dice_loss(
    preds: torch.Tensor, targets: torch.Tensor, smooth: float = 1e-6
) -> torch.Tensor:
    """
    Soft Dice loss (one minus the per-sample Dice ratio), averaged over the batch.

    ``preds`` are logits and are squashed with a sigmoid; ``smooth`` is added
    to numerator and denominator of the ratio.
    """
    probs = torch.sigmoid(preds)
    batch_probs = probs.reshape(probs.size(0), -1)
    batch_targets = targets.reshape(targets.size(0), -1)
    overlap = torch.sum(batch_probs * batch_targets, dim=1)
    mass = torch.sum(batch_probs, dim=1) + torch.sum(batch_targets, dim=1)
    per_sample_dice = (2 * overlap + smooth) / (mass + smooth)
    return torch.mean(1 - per_sample_dice)


def bce_dice_loss(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return the unweighted sum of BCE-with-logits and the soft Dice loss."""
    cross_entropy_term = F.binary_cross_entropy_with_logits(preds, targets)
    overlap_term = dice_loss(preds, targets)
    return cross_entropy_term + overlap_term


def focal_loss(
    preds: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.8,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Binary focal loss computed from logits.

    Parameters
    ----------
    preds:
        Raw logits.
    targets:
        Float tensor of the same shape; elements equal to 1 are positives,
        everything else is treated as negative.
    alpha:
        Weight of the positive class (negatives get ``1 - alpha``). A negative
        value disables class weighting.
    gamma:
        Focusing exponent applied to ``1 - p_t``.
    reduction:
        ``"mean"`` or ``"sum"``; any other value returns the element-wise loss.
    """
    preds_prob = torch.sigmoid(preds)
    # Work on the logits directly: applying BCE to the sigmoid output clamps
    # the loss and loses the gradient once the sigmoid saturates.
    ce_loss = F.binary_cross_entropy_with_logits(preds, targets, reduction="none")
    is_positive = targets == 1
    pt = torch.where(is_positive, preds_prob, 1 - preds_prob)
    loss = ce_loss * ((1 - pt) ** gamma)
    if alpha >= 0:
        # Build the weights in the dtype/device of ``preds`` so the loss keeps
        # the precision of the logits.
        alpha_factor = torch.where(
            is_positive, preds.new_tensor(alpha), preds.new_tensor(1 - alpha)
        )
        loss = alpha_factor * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss


MetricFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def compute_metrics(
    preds: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """
    Evaluate all segmentation metrics for one batch of logits.

    Returns a dict with the keys ``dice``, ``iou``, ``precision`` and
    ``recall``.
    """
    results: Dict[str, torch.Tensor] = {}
    results["dice"] = dice_coefficient(preds, targets)
    results["iou"] = iou_score(preds, targets)
    results["precision"], results["recall"] = precision_recall(preds, targets)
    return results

''',
    "src/models/__init__.py": '''"""
Model architectures for coronary artery segmentation.
"""

from .unetpp import UNetPlusPlus
from .unet3plus import UNet3Plus
from .transunet import TransUNet

__all__ = ["UNetPlusPlus", "UNet3Plus", "TransUNet"]


''',
    "src/models/unetpp.py": '''"""
UNet++ implementation for coronary angiogram segmentation.
"""

from __future__ import annotations

from typing import Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """
    Two 3x3 convolution -> batch-norm -> ReLU stages.

    When ``dropout`` is positive a ``Dropout2d`` layer is placed between the
    two stages.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0) -> None:
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout > 0:
            # Index 3 is right after the first ReLU.
            layers.insert(3, nn.Dropout2d(dropout))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        return self.block(x)


class UNetPlusPlus(nn.Module):
    """
    Implementation of UNet++ (Nested U-Net) with optional deep supervision.

    Parameters
    ----------
    in_channels:
        Number of input image channels.
    out_channels:
        Number of output logit channels.
    filters:
        Channel widths of the five encoder levels (indices 0 to 4 are used).
    deep_supervision:
        When true, ``forward`` returns one logit map per nested decoder depth
        instead of a single map.
    up_mode:
        Interpolation mode handed to ``F.interpolate`` for every upsampling
        step, e.g. ``"bilinear"`` (default) or ``"nearest"``.
    dropout:
        Dropout probability used inside the five backbone (``x*_0``) blocks.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        filters: Sequence[int] = (32, 64, 128, 256, 512),
        deep_supervision: bool = False,
        up_mode: str = "bilinear",
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.deep_supervision = deep_supervision
        # Stored so that forward() honours the requested interpolation mode.
        self.up_mode = up_mode

        self.conv0_0 = ConvBlock(in_channels, filters[0], dropout=dropout)
        self.conv1_0 = ConvBlock(filters[0], filters[1], dropout=dropout)
        self.conv2_0 = ConvBlock(filters[1], filters[2], dropout=dropout)
        self.conv3_0 = ConvBlock(filters[2], filters[3], dropout=dropout)
        self.conv4_0 = ConvBlock(filters[3], filters[4], dropout=dropout)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Nested blocks: convI_J fuses J same-level maps with one upsampled
        # map from level I + 1, hence filters[I] * J + filters[I + 1] inputs.
        self.conv0_1 = ConvBlock(filters[0] + filters[1], filters[0])
        self.conv1_1 = ConvBlock(filters[1] + filters[2], filters[1])
        self.conv2_1 = ConvBlock(filters[2] + filters[3], filters[2])
        self.conv3_1 = ConvBlock(filters[3] + filters[4], filters[3])

        self.conv0_2 = ConvBlock(filters[0] * 2 + filters[1], filters[0])
        self.conv1_2 = ConvBlock(filters[1] * 2 + filters[2], filters[1])
        self.conv2_2 = ConvBlock(filters[2] * 2 + filters[3], filters[2])

        self.conv0_3 = ConvBlock(filters[0] * 3 + filters[1], filters[0])
        self.conv1_3 = ConvBlock(filters[1] * 3 + filters[2], filters[1])

        self.conv0_4 = ConvBlock(filters[0] * 4 + filters[1], filters[0])

        if deep_supervision:
            self.final1 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.final2 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.final3 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.final4 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
        else:
            self.final = nn.Conv2d(filters[0], out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
        """
        Run the nested encoder/decoder.

        Returns a single logit map, or a 4-tuple of logit maps (from the
        shallowest to the deepest nested decoder) under deep supervision.
        """
        up = self._up

        x0_0 = self.conv0_0(x)  # down path
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, up(x1_0, x0_0)], dim=1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, up(x2_0, x1_0)], dim=1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, up(x1_1, x0_0)], dim=1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, up(x3_0, x2_0)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, up(x2_1, x1_0)], dim=1))
        x0_3 = self.conv0_3(
            torch.cat([x0_0, x0_1, x0_2, up(x1_2, x0_0)], dim=1)
        )

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, up(x4_0, x3_0)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, up(x3_1, x2_0)], dim=1))
        x1_3 = self.conv1_3(
            torch.cat([x1_0, x1_1, x1_2, up(x2_2, x1_0)], dim=1)
        )
        x0_4 = self.conv0_4(
            torch.cat([x0_0, x0_1, x0_2, x0_3, up(x1_3, x0_0)], dim=1)
        )

        if self.deep_supervision:
            outputs = (
                self.final1(x0_1),
                self.final2(x0_2),
                self.final3(x0_3),
                self.final4(x0_4),
            )
            return outputs
        return self.final(x0_4)

    def _up(self, source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Upsample ``source`` to ``target``'s size with the configured mode."""
        return self._upsample(source, target, self.up_mode)

    @staticmethod
    def _upsample(
        source: torch.Tensor, target: torch.Tensor, mode: str = "bilinear"
    ) -> torch.Tensor:
        """Resize ``source`` to the spatial size of ``target``."""
        # align_corners is only accepted by the interpolating modes; it must
        # stay None for modes such as "nearest" or "area".
        interpolating = mode in ("linear", "bilinear", "bicubic", "trilinear")
        return F.interpolate(
            source,
            size=target.shape[2:],
            mode=mode,
            align_corners=True if interpolating else None,
        )

''',
    "src/models/unet3plus.py": '''"""
UNet 3+ implementation for coronary angiogram segmentation.
"""

from __future__ import annotations

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        return self.block(x)


def _resize(
    tensor: torch.Tensor, target_shape: torch.Size, mode: str = "bilinear"
) -> torch.Tensor:
    """
    Resample ``tensor`` to the spatial (last two) dimensions of ``target_shape``.

    The input is returned untouched when the sizes already agree.
    """
    if tensor.shape[-2:] == target_shape[-2:]:
        return tensor
    return F.interpolate(tensor, size=target_shape[-2:], mode=mode, align_corners=True)


class UNet3Plus(nn.Module):
    """
    UNet 3+ implementation with deep supervision support.

    Every decoder stage fuses five feature maps resized to its own scale:
    encoder maps from the same and shallower levels, decoder maps from the
    deeper levels, and the bottleneck. Each map is first projected to
    ``cat_channels`` channels.

    Parameters
    ----------
    in_channels:
        Number of input image channels.
    out_channels:
        Number of output logit channels.
    filters:
        Channel widths of the five encoder levels.
    cat_channels:
        Channels produced by every per-scale projection.
    deep_supervision:
        When true, ``forward`` returns one logit map per decoder stage.

    Raises
    ------
    ValueError
        If ``filters`` does not contain exactly five values.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        filters: Sequence[int] = (32, 64, 128, 256, 512),
        cat_channels: int = 32,
        deep_supervision: bool = False,
    ) -> None:
        super().__init__()
        if len(filters) != 5:
            raise ValueError("UNet3Plus expects five filter values for the encoder.")
        self.deep_supervision = deep_supervision
        self.cat_channels = cat_channels

        self.encoder1 = ConvBlock(in_channels, filters[0])
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = ConvBlock(filters[0], filters[1])
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = ConvBlock(filters[1], filters[2])
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = ConvBlock(filters[2], filters[3])
        self.pool4 = nn.MaxPool2d(2)
        self.encoder5 = ConvBlock(filters[3], filters[4])

        # Projection h{k}_d{level} maps the level-k feature map to
        # cat_channels for decoder stage `level`. The h*_d0 projections are
        # never used in forward(); they are kept so parameter names and
        # state_dict keys stay unchanged.
        for level, width in zip((4, 3, 2, 1, 0), filters):
            for source, channels in enumerate(filters, start=1):
                setattr(self, f"h{source}_d{level}", self._make_stage(channels, width))

        concat_channels_1 = cat_channels * 5
        # Each decoder stage emits as many channels as the encoder level of
        # the same scale, so its output can feed the h{k}_d{level}
        # projections of the shallower stages.
        self.conv_d4 = ConvBlock(concat_channels_1, filters[3])
        self.conv_d3 = ConvBlock(concat_channels_1, filters[2])
        self.conv_d2 = ConvBlock(concat_channels_1, filters[1])
        self.conv_d1 = ConvBlock(concat_channels_1, filters[0])

        self.final = nn.Conv2d(filters[0], out_channels, kernel_size=1)

        if deep_supervision:
            self.ds_out1 = nn.Conv2d(filters[0], out_channels, kernel_size=1)
            self.ds_out2 = nn.Conv2d(filters[1], out_channels, kernel_size=1)
            self.ds_out3 = nn.Conv2d(filters[2], out_channels, kernel_size=1)
            self.ds_out4 = nn.Conv2d(filters[3], out_channels, kernel_size=1)

    def _make_stage(self, in_channels: int, target_filters: int) -> nn.Sequential:
        """
        Build a 3x3 conv -> batch-norm -> ReLU projection to ``cat_channels``.

        ``target_filters`` is accepted for call-site symmetry but is not used.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, self.cat_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(self.cat_channels),
            nn.ReLU(inplace=True),
        )

    def _fuse(self, level: int, features: Sequence[torch.Tensor], size: torch.Size) -> torch.Tensor:
        """Resize the five ``features`` to ``size``, project and concatenate them."""
        branches = [
            getattr(self, f"h{source}_d{level}")(_resize(feature, size))
            for source, feature in enumerate(features, start=1)
        ]
        return torch.cat(branches, dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """
        Run the network.

        Returns the full-resolution logit map, or under deep supervision the
        tuple ``(ds1, ds2, ds3, ds4)`` whose entries keep the resolution of
        their decoder stage (they are not upsampled here).
        """
        h1 = self.encoder1(x)
        h2 = self.encoder2(self.pool1(h1))
        h3 = self.encoder3(self.pool2(h2))
        h4 = self.encoder4(self.pool3(h3))
        h5 = self.encoder5(self.pool4(h4))

        # Deeper-scale inputs come from the decoder (d4, d3, d2) rather than
        # the encoder, so every decoder stage contributes to the final output
        # and receives gradients even without deep supervision.
        d4 = self.conv_d4(self._fuse(4, (h1, h2, h3, h4, h5), h4.size()))
        d3 = self.conv_d3(self._fuse(3, (h1, h2, h3, d4, h5), h3.size()))
        d2 = self.conv_d2(self._fuse(2, (h1, h2, d3, d4, h5), h2.size()))
        d1 = self.conv_d1(self._fuse(1, (h1, d2, d3, d4, h5), h1.size()))

        if self.deep_supervision:
            ds1 = self.ds_out1(d1)
            ds2 = self.ds_out2(d2)
            ds3 = self.ds_out3(d3)
            ds4 = self.ds_out4(d4)
            return ds1, ds2, ds3, ds4
        return self.final(d1)

''',
    "src/models/transunet.py": '''"""
Simplified TransUNet implementation combining CNN encoder and Transformer bottleneck.
"""

from __future__ import annotations

from typing import Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Stack of two 3x3 convolution -> batch-norm -> ReLU stages."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        stages = []
        # First stage changes the channel count, second one keeps it.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        features = self.block(x)
        return features


class PatchEmbedding(nn.Module):
    """
    Turn a feature map into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to ``embed_dim`` channels; the spatial
    grid is then flattened into the token axis.
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int,
        patch_size: int,
    ) -> None:
        super().__init__()
        patch_conv = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.proj = patch_conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        patch_grid = self.proj(x)
        # Collapse the two spatial axes, then move channels last: B, N, C
        tokens = torch.flatten(patch_grid, start_dim=2)
        return tokens.permute(0, 2, 1)


class TransformerBottleneck(nn.Module):
    """Stack of ``depth`` GELU transformer encoder layers over batch-first tokens.

    The feed-forward width of every layer is ``int(embed_dim * mlp_ratio)``.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        feedforward_dim = int(embed_dim * mlp_ratio)
        layer_options = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        template_layer = nn.TransformerEncoderLayer(**layer_options)
        self.encoder = nn.TransformerEncoder(template_layer, num_layers=depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        """Encode the token sequence ``x`` and return the result."""
        encoded = self.encoder(x)
        return encoded


class UpBlock(nn.Module):
    """Decoder stage: upsample to the skip's size, concatenate, then convolve."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        super().__init__()
        merged_channels = in_channels + skip_channels
        self.conv = ConvBlock(merged_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``x`` to ``skip``'s height/width and fuse the two."""
        target_hw = skip.shape[-2:]
        upsampled = F.interpolate(x, size=target_hw, mode="bilinear", align_corners=True)
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv(fused)


class TransUNet(nn.Module):
    """
    Transformer-based U-Net variant for angiogram segmentation.

    Four ``ConvBlock`` encoder stages (separated by three 2x max-pools) feed a
    transformer bottleneck; three ``UpBlock`` decoder stages with skip
    connections and a final 1x1 convolution produce the output map.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        img_size: ``(height, width)`` the model is built for. ``forward`` only
            accepts inputs with exactly this spatial size.
        filters: Exactly four encoder channel widths, shallowest first.
        embed_dim: Token width inside the transformer.
        transformer_depth: Number of transformer encoder layers.
        num_heads: Attention heads per transformer layer.
        patch_size: Patch size in input pixels; divided by the encoder
            downscale factor to obtain the kernel applied to the bottleneck.

    Raises:
        ValueError: If ``filters`` does not have four entries or the image and
            patch sizes do not yield a whole number of tokens.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        img_size: Tuple[int, int] = (512, 512),
        filters: Sequence[int] = (32, 64, 128, 256),
        embed_dim: int = 256,
        transformer_depth: int = 4,
        num_heads: int = 8,
        patch_size: int = 16,
    ) -> None:
        super().__init__()
        # The encoder/decoder below are hard-wired to four stages; any other
        # length would index out of range or mis-compute the downscale factor.
        if len(filters) != 4:
            raise ValueError("filters must contain exactly four channel widths.")
        height, width = img_size
        if height % patch_size != 0 or width % patch_size != 0:
            raise ValueError("img_size must be divisible by patch_size.")

        self.img_size = img_size
        self.enc1 = ConvBlock(in_channels, filters[0])
        self.enc2 = ConvBlock(filters[0], filters[1])
        self.enc3 = ConvBlock(filters[1], filters[2])
        self.enc4 = ConvBlock(filters[2], filters[3])

        self.pool = nn.MaxPool2d(2)

        # One 2x pool between consecutive encoder stages.
        downscale = 2 ** (len(filters) - 1)
        bottleneck_hw = (height // downscale, width // downscale)
        # patch_size is expressed in input pixels; convert it to bottleneck pixels.
        patch_kernel = max(1, patch_size // downscale)
        if bottleneck_hw[0] % patch_kernel != 0 or bottleneck_hw[1] % patch_kernel != 0:
            raise ValueError(
                "Incompatible patch/kernel configuration. "
                "Ensure patch_size leads to integer number of tokens."
            )
        token_hw = (bottleneck_hw[0] // patch_kernel, bottleneck_hw[1] // patch_kernel)
        self.patch_embed = PatchEmbedding(filters[3], embed_dim, patch_size=patch_kernel)

        num_patches = token_hw[0] * token_hw[1]
        # Learnable positional embedding, initialised to zeros.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.transformer = TransformerBottleneck(
            embed_dim=embed_dim,
            depth=transformer_depth,
            num_heads=num_heads,
        )
        self.proj_back = nn.Conv2d(embed_dim, filters[3], kernel_size=1)
        self._bottleneck_hw = bottleneck_hw
        self._token_hw = token_hw

        self.up1 = UpBlock(filters[3], filters[2], filters[2])
        self.up2 = UpBlock(filters[2], filters[1], filters[1])
        self.up3 = UpBlock(filters[1], filters[0], filters[0])

        self.final = nn.Conv2d(filters[0], out_channels, kernel_size=1)

        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-initialise every Conv2d and reset every BatchNorm2d to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output map for ``x``.

        Raises:
            ValueError: If the spatial size of ``x`` differs from ``img_size``.
        """
        # The token grid is fixed at construction time. A different input size
        # would either fail in view() or, when the token count happens to match
        # (e.g. a transposed aspect ratio), silently scramble the spatial layout.
        if tuple(x.shape[-2:]) != tuple(self.img_size):
            raise ValueError(
                f"Expected input of spatial size {tuple(self.img_size)}, "
                f"got {tuple(x.shape[-2:])}."
            )

        h1 = self.enc1(x)
        h2 = self.enc2(self.pool(h1))
        h3 = self.enc3(self.pool(h2))
        h4 = self.enc4(self.pool(h3))

        tokens = self.patch_embed(h4)  # B, N, C
        tokens = tokens + self.pos_embed[:, : tokens.size(1), :]
        tokens = self.transformer(tokens)

        # Fold the token sequence back into a 2-D grid and restore the
        # bottleneck resolution before projecting to the encoder width.
        b, _, c = tokens.shape
        tokens = tokens.transpose(1, 2).contiguous()
        tokens = tokens.view(b, c, self._token_hw[0], self._token_hw[1])
        tokens = F.interpolate(
            tokens,
            size=self._bottleneck_hw,
            mode="bilinear",
            align_corners=True,
        )
        bottleneck = self.proj_back(tokens)

        d3 = self.up1(bottleneck, h3)
        d2 = self.up2(d3, h2)
        d1 = self.up3(d2, h1)
        return self.final(d1)

''',
    "src/preprocess_dataset.py": '''"""
Command-line preprocessing utility to create persistent dataset splits
and optionally export resized numpy arrays.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import imageio.v2 as imageio
import numpy as np

from src.data.dataset import (
    AngiogramSegmentationDataset,
    create_splits,
    default_resize_transform,
    load_image_mask_pairs,
)
from src.preprocessing import seed_everything
from src.utils.env import resolve_data_dir


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the preprocessing utility."""
    cli = argparse.ArgumentParser(
        description="Generate train/val/test splits and optionally export preprocessed arrays."
    )
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        (
            "--data_dir",
            dict(
                type=Path,
                default=Path("Database_134_Angiograms"),
                help="Directory containing angiogram *.pgm files and *_gt.pgm masks.",
            ),
        ),
        (
            "--output_dir",
            dict(
                type=Path,
                default=Path("results/preprocessed"),
                help="Directory where splits, manifests, and exported arrays will be stored.",
            ),
        ),
        (
            "--image_size",
            dict(
                type=int,
                nargs=2,
                default=(512, 512),
                metavar=("HEIGHT", "WIDTH"),
                help="Target height and width for resizing images/masks.",
            ),
        ),
        (
            "--val_size",
            dict(type=float, default=0.15, help="Fraction of data reserved for validation."),
        ),
        (
            "--test_size",
            dict(type=float, default=0.15, help="Fraction of data reserved for testing."),
        ),
        (
            "--seed",
            dict(type=int, default=42, help="Random seed controlling the split."),
        ),
        (
            "--export",
            dict(
                action="store_true",
                help="If set, export resized numpy arrays for each split.",
            ),
        ),
        (
            "--format",
            dict(
                choices=("npy", "npz"),
                default="npy",
                help="File format used when exporting arrays.",
            ),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def normalize(image: np.ndarray) -> np.ndarray:
    """Min-max scale ``image`` into the range [0, 1].

    The result is always float32, whatever the input dtype. An image whose
    value range is below 1e-6 maps to all zeros, and an empty array is
    returned as an empty float32 array instead of failing in ``min()``.
    """
    image = np.asarray(image, dtype=np.float32)
    if image.size == 0:
        return image
    min_val = image.min()
    span = image.max() - min_val
    # Guard against division by (almost) zero on constant images.
    if span < 1e-6:
        return np.zeros_like(image, dtype=np.float32)
    return (image - min_val) / span


def export_split(
    split_name: str,
    pairs: Iterable[Tuple[Path, Path]],
    resize_transform,
    output_dir: Path,
    file_format: str,
) -> List[Dict[str, str]]:
    """Export every image/mask pair of one split as channel-first arrays.

    Images are min-max normalised and masks are binarised (``> 0``). Both get a
    trailing channel axis, are passed through ``resize_transform`` when one is
    given, and are written to ``<output_dir>/<split_name>/images`` and
    ``.../masks`` as ``.npy`` or compressed ``.npz`` files.

    Returns:
        One manifest record per pair with the original and exported paths.
    """
    split_root = output_dir / split_name
    images_dir = split_root / "images"
    masks_dir = split_root / "masks"
    for folder in (images_dir, masks_dir):
        folder.mkdir(parents=True, exist_ok=True)

    save_as_npy = file_format == "npy"
    records: List[Dict[str, str]] = []

    for image_path, mask_path in pairs:
        raw_image = imageio.imread(image_path).astype(np.float32)
        raw_mask = imageio.imread(mask_path).astype(np.float32)

        # Scale the image, binarise the mask, then append a channel axis (HWC).
        image = normalize(raw_image)[..., np.newaxis]
        mask = (raw_mask > 0).astype(np.float32)[..., np.newaxis]

        if resize_transform is not None:
            resized = resize_transform(image=image, mask=mask)
            image = resized["image"]
            mask = resized["mask"]

        # HWC -> CHW
        image = np.transpose(image, (2, 0, 1))
        mask = np.transpose(mask, (2, 0, 1))

        stem = image_path.stem.replace("_gt", "")
        file_name = f"{stem}.{file_format}"
        image_file = images_dir / file_name
        mask_file = masks_dir / file_name

        if save_as_npy:
            np.save(image_file, image, allow_pickle=False)
            np.save(mask_file, mask, allow_pickle=False)
        else:
            np.savez_compressed(image_file, image=image)
            np.savez_compressed(mask_file, mask=mask)

        record = {
            "stem": stem,
            "original_image": str(image_path),
            "original_mask": str(mask_path),
            "image_file": str(image_file),
            "mask_file": str(mask_file),
        }
        records.append(record)

    return records


def save_manifest(manifest: Dict[str, List[Dict[str, str]]], output_dir: Path) -> Path:
    """Write ``manifest`` as indented JSON to ``<output_dir>/manifest.json``.

    Returns:
        The path of the file that was written.
    """
    target = output_dir.joinpath("manifest.json")
    with open(target, "w") as handle:
        json.dump(manifest, handle, indent=2)
    return target


def save_splits(splits: Dict[str, List[Tuple[Path, Path]]], output_dir: Path) -> Path:
    """Write the split definitions to ``<output_dir>/splits.json``.

    Each ``(image_path, mask_path)`` pair becomes an ``{"image", "mask"}``
    record of strings, grouped under its split name.

    Returns:
        The path of the file that was written.
    """
    payload = {}
    for split_name, split_pairs in splits.items():
        entries = []
        for image_path, mask_path in split_pairs:
            entries.append({"image": str(image_path), "mask": str(mask_path)})
        payload[split_name] = entries

    target = output_dir.joinpath("splits.json")
    with open(target, "w") as handle:
        json.dump(payload, handle, indent=2)
    return target


def main() -> None:
    """Create and persist the dataset splits, optionally exporting arrays.

    The split definitions are always written. Arrays and the export manifest
    are written only when ``--export`` is given.
    """
    args = parse_args()
    seed_everything(args.seed)

    out_root = args.output_dir
    out_root.mkdir(parents=True, exist_ok=True)

    data_dir = resolve_data_dir(args.data_dir)
    print(f"Collecting image/mask pairs from {data_dir}...")
    splits = create_splits(
        load_image_mask_pairs(data_dir),
        val_size=args.val_size,
        test_size=args.test_size,
        seed=args.seed,
    )

    splits_dict = {name: getattr(splits, name) for name in ("train", "val", "test")}
    splits_path = save_splits(splits_dict, out_root)
    print(f"Saved split definitions to {splits_path}")

    if args.export:
        resizer = default_resize_transform(tuple(args.image_size))
        manifest: Dict[str, List[Dict[str, str]]] = {}
        for split_name, split_pairs in splits_dict.items():
            print(f"Exporting {split_name} split with {len(split_pairs)} samples...")
            manifest[split_name] = export_split(
                split_name,
                split_pairs,
                resizer,
                output_dir=out_root,
                file_format=args.format,
            )
        manifest_path = save_manifest(manifest, out_root)
        print(f"Saved export manifest to {manifest_path}")
    else:
        print("Export flag not set; skipping array export.")


if __name__ == "__main__":
    main()

''',
    "src/train.py": '''"""
Training script for coronary angiogram segmentation models.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda import amp
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from src.metrics import bce_dice_loss, compute_metrics, dice_loss, focal_loss
from src.models import TransUNet, UNet3Plus, UNetPlusPlus
from src.preprocessing import create_dataloaders, create_datasets, seed_everything
from src.utils.env import resolve_data_dir


# Maps the value of the --loss CLI option to the loss callable used as criterion.
LOSS_MAP = {
    "dice": dice_loss,
    "bce_dice": bce_dice_loss,
    "focal": focal_loss,
}

# Maps the names accepted by the --models CLI option to the model class to build.
MODEL_MAP = {
    "unetpp": UNetPlusPlus,
    "unet3plus": UNet3Plus,
    "transunet": TransUNet,
}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script."""
    cli = argparse.ArgumentParser(description="Train segmentation models on angiograms.")
    add = cli.add_argument

    # Data location and geometry.
    add(
        "--data_dir",
        type=str,
        default="Database_134_Angiograms",
        help="Directory containing angiogram images and masks.",
    )
    add(
        "--output_dir",
        type=str,
        default="results",
        help="Directory to store checkpoints and metrics.",
    )
    add(
        "--image_size",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resize all images to this size (height width).",
    )

    # Which architectures to train.
    add(
        "--models",
        type=str,
        nargs="+",
        default=["unetpp", "unet3plus", "transunet"],
        choices=list(MODEL_MAP),
        help="Models to train.",
    )

    # Optimisation hyperparameters.
    add("--epochs", type=int, default=50)
    add("--batch_size", type=int, default=4)
    add("--learning_rate", type=float, default=1e-3)
    add("--weight_decay", type=float, default=1e-5)
    add("--num_workers", type=int, default=0)
    add("--loss", type=str, default="bce_dice", choices=list(LOSS_MAP))
    add("--seed", type=int, default=42)

    # Runtime behaviour.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    add("--device", type=str, default=default_device)
    add("--patience", type=int, default=10, help="Early stopping patience.")
    add("--amp", action="store_true", help="Use automatic mixed precision.")
    return cli.parse_args()


def prepare_output_dir(output_dir: Path) -> None:
    """Create ``output_dir`` with its ``checkpoints`` and ``metrics`` subfolders."""
    output_dir.mkdir(parents=True, exist_ok=True)
    for subfolder in ("checkpoints", "metrics"):
        output_dir.joinpath(subfolder).mkdir(exist_ok=True)


def instantiate_model(name: str, image_size: Tuple[int, int]) -> nn.Module:
    """Build the model registered under ``name`` in ``MODEL_MAP``.

    Only ``"transunet"`` receives ``image_size`` (as ``img_size``); every other
    model is constructed with its defaults.
    """
    model_cls = MODEL_MAP[name]
    extra_kwargs = {"img_size": image_size} if name == "transunet" else {}
    return model_cls(**extra_kwargs)


def train_one_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[amp.GradScaler],
    device: torch.device,
) -> Tuple[float, Dict[str, float]]:
    """Run one training pass over ``loader``.

    Mixed precision is used when ``scaler`` is provided. Returns the mean loss
    and the mean of each metric, averaged over batches.
    """
    model.train()
    use_amp = scaler is not None
    metric_totals = dict.fromkeys(("dice", "iou", "precision", "recall"), 0.0)
    loss_total = 0.0
    steps = 0

    for batch in tqdm(loader, desc="Train", leave=False):
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)

        optimizer.zero_grad(set_to_none=True)
        with amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, masks)

        if use_amp:
            # Scaled backward pass; the scaler also drives the optimizer step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            batch_metrics = compute_metrics(outputs.detach(), masks)
        loss_total += loss.item()
        for name in metric_totals:
            metric_totals[name] += float(batch_metrics[name])
        steps += 1

    # Avoid dividing by zero when the loader yields nothing.
    denominator = max(1, steps)
    mean_metrics = {name: total / denominator for name, total in metric_totals.items()}
    return loss_total / denominator, mean_metrics


def evaluate(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, Dict[str, float]]:
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Returns the mean loss and the mean of each metric, averaged over batches.
    """
    model.eval()
    metric_names = ("dice", "iou", "precision", "recall")
    metric_totals = {name: 0.0 for name in metric_names}
    loss_total = 0.0
    steps = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval", leave=False):
            images = batch["image"].to(device)
            masks = batch["mask"].to(device)
            predictions = model(images)
            loss = criterion(predictions, masks)
            scores = compute_metrics(predictions, masks)
            loss_total += loss.item()
            for name in metric_names:
                metric_totals[name] += float(scores[name])
            steps += 1

    # Avoid dividing by zero when the loader yields nothing.
    denominator = max(1, steps)
    mean_metrics = {name: total / denominator for name, total in metric_totals.items()}
    return loss_total / denominator, mean_metrics


def fit_model(
    model_name: str,
    args: argparse.Namespace,
    datasets: Tuple[
        torch.utils.data.Dataset,
        torch.utils.data.Dataset,
        torch.utils.data.Dataset,
    ],
) -> Dict[str, float]:
    """Train one model with early stopping and evaluate it on the test split.

    The weights with the best validation Dice are saved to
    ``<output_dir>/checkpoints/<model_name>_best.pth``. After training, that
    checkpoint is restored only if it was written during this call, so a file
    left over from an earlier run is never evaluated by mistake. If no
    checkpoint was written, the weights from the last epoch are tested.
    The per-epoch history and the test results are written to
    ``<output_dir>/metrics/<model_name>_metrics.json``.

    Args:
        model_name: Key into ``MODEL_MAP``.
        args: Parsed CLI options (batch size, learning rate, patience, ...).
        datasets: Train, validation and test datasets, in that order.

    Returns:
        The test metrics as plain floats.
    """
    train_loader, val_loader, test_loader = create_dataloaders(
        datasets,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
    device = torch.device(args.device)
    model = instantiate_model(model_name, tuple(args.image_size)).to(device)
    criterion = LOSS_MAP[args.loss]
    optimizer = Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)
    # Mixed precision is only meaningful on CUDA devices.
    scaler = amp.GradScaler() if args.amp and device.type == "cuda" else None

    best_path = Path(args.output_dir) / "checkpoints" / f"{model_name}_best.pth"
    checkpoint_written = False
    best_val_dice = 0.0
    epochs_no_improve = 0
    history: List[Dict[str, float]] = []

    for epoch in range(1, args.epochs + 1):
        print()
        print(f"Epoch {epoch}/{args.epochs} - Model: {model_name}")
        train_loss, train_metrics = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device
        )
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        log_entry = {
            "epoch": epoch,
            "train_loss": train_loss,
            **{f"train_{k}": v for k, v in train_metrics.items()},
            "val_loss": val_loss,
            **{f"val_{k}": v for k, v in val_metrics.items()},
        }
        history.append(log_entry)
        print(json.dumps(log_entry, indent=2))

        current_val_dice = val_metrics["dice"]
        # Require a minimum gain so noise does not reset the patience counter.
        if current_val_dice > best_val_dice + 1e-4:
            best_val_dice = current_val_dice
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_path)
            checkpoint_written = True
            print(f"Saved new best checkpoint to {best_path}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= args.patience:
                print("Early stopping triggered.")
                break

    # Load best checkpoint for testing, but only one produced by this run.
    if checkpoint_written:
        # The file holds a plain state_dict, so weights-only loading suffices.
        state_dict = torch.load(best_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded best checkpoint from {best_path}")
    elif best_path.exists():
        print(f"Ignoring stale checkpoint {best_path}; none was saved in this run.")

    test_loss, test_metrics = evaluate(model, test_loader, criterion, device)
    print("Test metrics:", test_metrics)

    test_metrics_plain = {k: float(v) for k, v in test_metrics.items()}
    metrics_path = (
        Path(args.output_dir) / "metrics" / f"{model_name}_metrics.json"
    )
    with metrics_path.open("w") as f:
        json.dump(
            {
                "history": history,
                "best_val_dice": best_val_dice,
                "test_loss": test_loss,
                "test_metrics": test_metrics_plain,
            },
            f,
            indent=2,
        )
    return dict(test_metrics_plain)


def main() -> None:
    """Train every requested model and write a combined test-metric summary."""
    args = parse_args()
    seed_everything(args.seed)

    results_root = Path(args.output_dir)
    prepare_output_dir(results_root)

    dataset_root = resolve_data_dir(args.data_dir)
    datasets = create_datasets(
        dataset_root,
        image_size=tuple(args.image_size),
        val_size=0.15,
        test_size=0.15,
        seed=args.seed,
        augment=True,
    )

    # Models are trained one after another on the same datasets.
    summary: Dict[str, Dict[str, float]] = {
        name: fit_model(name, args, datasets) for name in args.models
    }

    summary_path = results_root / "metrics" / "summary.json"
    with open(summary_path, "w") as handle:
        json.dump(summary, handle, indent=2)
    print("Saved summary metrics to", summary_path)


if __name__ == "__main__":
    main()

''',
    "src/plot_results.py": '''"""
Utility to aggregate and plot model comparison metrics.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import pandas as pd


def load_metrics(metrics_dir: Path) -> pd.DataFrame:
    """Collect the test metrics of every ``*_metrics.json`` file in ``metrics_dir``.

    Files are read in sorted name order so the table (and the bar order in the
    plots) is deterministic. The model name is the file stem with only the
    trailing ``_metrics`` removed.

    Returns:
        One row per file: a ``model`` column plus one column per test metric.

    Raises:
        FileNotFoundError: If no matching file exists.
    """
    suffix = "_metrics"
    rows: List[Dict[str, float | str]] = []
    for metrics_file in sorted(metrics_dir.glob("*_metrics.json")):
        with metrics_file.open(encoding="utf-8") as f:
            data = json.load(f)
        # The glob guarantees the stem ends with the suffix; strip only that
        # occurrence so names that contain "_metrics" elsewhere stay intact.
        model_name = metrics_file.stem[: -len(suffix)]
        rows.append({"model": model_name, **data.get("test_metrics", {})})
    if not rows:
        raise FileNotFoundError(f"No metric files found in {metrics_dir}")
    return pd.DataFrame(rows)


def plot_metrics(df: pd.DataFrame, output_dir: Path) -> None:
    """Save one bar chart per available metric, comparing all models.

    For each of dice, iou, precision and recall that is a column of ``df``, a
    ``<metric>_comparison.png`` file is written to ``output_dir`` (created if
    needed). Metrics missing from ``df`` are skipped.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    for metric in ["dice", "iou", "precision", "recall"]:
        if metric not in df.columns:
            continue
        ax = df.plot(
            x="model",
            y=metric,
            kind="bar",
            legend=False,
            rot=0,
            title=f"{metric.upper()} comparison",
        )
        # str has no UPPER() method; the lower-case spelling is the real one.
        ax.set_ylabel(metric.upper())
        ax.set_ylim(0, 1.0)
        plt.tight_layout()
        figure_path = output_dir / f"{metric}_comparison.png"
        plt.savefig(figure_path)
        # Close the figure so repeated plots do not accumulate in memory.
        plt.close()
        print(f"Saved {metric} plot to {figure_path}")


def save_table(df: pd.DataFrame, output_dir: Path) -> None:
    """Write ``df`` to ``<output_dir>/metrics_summary.csv`` without the index.

    ``output_dir`` is created when missing, so the function also works when it
    is called before anything else has produced that directory.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_path = output_dir / "metrics_summary.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved metrics table to {csv_path}")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the plotting utility."""
    cli = argparse.ArgumentParser(description="Plot segmentation metric comparisons.")
    directory_options = (
        ("--metrics_dir", "results/metrics", "Directory containing *_metrics.json files."),
        ("--output_dir", "results/plots", "Directory to store generated figures and tables."),
    )
    for flag, default_value, help_text in directory_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()


def main() -> None:
    """Load the metric files, then write the summary table and the plots."""
    args = parse_args()
    plots_dir = Path(args.output_dir)
    df = load_metrics(Path(args.metrics_dir))
    save_table(df, plots_dir)
    plot_metrics(df, plots_dir)


if __name__ == "__main__":
    main()

'''
}

# Materialise every embedded source file below the project root.
for relative_path in files:
    content = files[relative_path]
    destination = PROJECT_ROOT.joinpath(relative_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Normalise each file to end with exactly one newline.
    normalised = content.rstrip("\n") + "\n"
    destination.write_text(normalised, encoding="utf-8")

print(f"Wrote {len(files)} files under {PROJECT_ROOT}")



In [None]:
# Hotfix for the generated plotting script: str has no UPPER() method, so any
# `metric.UPPER()` call in the written file is rewritten to `metric.upper()`.
# The replacement is a no-op when the text is not present.
plot_file = PROJECT_ROOT / "src/plot_results.py"
text = plot_file.read_text(encoding="utf-8")
text = text.replace("metric.UPPER()", "metric.upper()")
plot_file.write_text(text, encoding="utf-8")
print("Patched plot_results.py")



In [None]:
%%capture
# %%capture hides this cell's output, including pip's log; remove the first
# line when you need to debug a failed install.
%cd /kaggle/working/angiogram-segmentation
!pip install -r requirements.txt



## 1. Configure dataset + hyperparameters
Update the values below if you renamed the attached dataset or want to tweak training behaviour. By default, the scripts auto-discover `Database_134_Angiograms` anywhere under `/kaggle/input`. Results are written to `/kaggle/working/results` so they persist when you save a Notebook version.



In [None]:
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Settings shared by the preprocessing, training and plotting cells below.
    # Several defaults here differ from the CLI defaults in src/train.py
    # (epochs, batch size, learning rate, num_workers).

    # Dataset folder; forwarded as --data_dir to preprocessing and training.
    data_dir: str = "Database_134_Angiograms"
    # Root for all outputs; the later cells use its preprocessed/, metrics/ and plots/ subfolders.
    output_dir: str = "/kaggle/working/results"
    # (height, width) forwarded as the two --image_size values.
    image_size: tuple[int, int] = (512, 512)
    # Forwarded as --epochs.
    epochs: int = 100
    # Forwarded as --batch_size.
    batch_size: int = 8
    # Forwarded as --learning_rate.
    learning_rate: float = 1e-4
    # Forwarded as --weight_decay.
    weight_decay: float = 1e-5
    # Forwarded as --num_workers.
    num_workers: int = 2
    # Model names forwarded after --models; each must be a choice accepted by src/train.py.
    models: tuple[str, ...] = ("unetpp", "unet3plus", "transunet")
    # Forwarded as --loss.
    loss: str = "bce_dice"
    # Forwarded as --seed.
    seed: int = 42

CFG = TrainConfig()
CFG



## 2. Preprocess + export optional `.npy` tensors
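This step creates train/val/test splits (`splits.json`) and, because `--export` is passed, also exports resized `.npy` arrays for faster experimentation. Remove `--export` from the command below if you only need the split definitions. Note that the training command in section 3 is pointed at the raw dataset via `--data_dir`, not at these exported arrays.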



In [None]:
import subprocess

# Assemble the preprocessing CLI call option by option.
preprocess_output = Path(CFG.output_dir) / "preprocessed"
target_height = str(CFG.image_size[0])
target_width = str(CFG.image_size[1])

preprocess_cmd = ["python", "-m", "src.preprocess_dataset"]
preprocess_cmd += ["--data_dir", CFG.data_dir]
preprocess_cmd += ["--output_dir", str(preprocess_output)]
preprocess_cmd += ["--image_size", target_height, target_width]
preprocess_cmd += ["--export", "--format", "npy"]

print("Running:", " ".join(preprocess_cmd))
# Run from the project root so that `-m src.preprocess_dataset` resolves.
subprocess.run(preprocess_cmd, cwd=str(PROJECT_ROOT), check=True)



## 3. Train UNet++, UNet 3+, TransUNet
This cell sequentially trains all three architectures with shared hyperparameters from `CFG`. Adjust `CFG.models` if you only need a subset. The `--amp` flag is always passed; `src.train` only turns mixed precision on when the selected device is CUDA, so the flag has no effect on CPU.



In [None]:
# Assemble the training CLI call; option order matches the printed command.
train_cmd = ["python", "-m", "src.train"]
train_cmd += ["--data_dir", CFG.data_dir, "--output_dir", CFG.output_dir]
train_cmd += ["--image_size", str(CFG.image_size[0]), str(CFG.image_size[1])]

numeric_options = (
    ("--epochs", CFG.epochs),
    ("--batch_size", CFG.batch_size),
    ("--learning_rate", CFG.learning_rate),
    ("--weight_decay", CFG.weight_decay),
    ("--num_workers", CFG.num_workers),
)
for flag, value in numeric_options:
    train_cmd += [flag, str(value)]

train_cmd += ["--loss", CFG.loss, "--seed", str(CFG.seed)]
train_cmd += ["--models", *CFG.models]
train_cmd.append("--amp")

print("Running:", " ".join(train_cmd))
# Run from the project root so that `-m src.train` resolves.
subprocess.run(train_cmd, cwd=str(PROJECT_ROOT), check=True)



## 4. Aggregate metrics + plots
Generates `metrics_summary.csv` plus per-metric bar charts under `/kaggle/working/results/plots`.



In [None]:
# Assemble the plotting CLI call from the shared output directory.
metrics_dir_arg = str(Path(CFG.output_dir) / "metrics")
plots_dir_arg = str(Path(CFG.output_dir) / "plots")

plot_cmd = ["python", "-m", "src.plot_results"]
plot_cmd += ["--metrics_dir", metrics_dir_arg]
plot_cmd += ["--output_dir", plots_dir_arg]

print("Running:", " ".join(plot_cmd))
# Run from the project root so that `-m src.plot_results` resolves.
subprocess.run(plot_cmd, cwd=str(PROJECT_ROOT), check=True)



## 5. Quick glance at metrics + artifacts
The snippet below prints the summary JSON, shows the metrics table, and lists which files were generated so you can download or publish them as a Kaggle Dataset.



In [None]:
import json
import pandas as pd

results_root = Path(CFG.output_dir)

# 1) Combined test metrics written by the training script.
summary_path = results_root / "metrics" / "summary.json"
if not summary_path.exists():
    print("Summary file not found:", summary_path)
else:
    summary = json.loads(summary_path.read_text())
    print("Summary metrics:\n", json.dumps(summary, indent=2))

# 2) Metrics table written by the plotting script.
summary_csv = results_root / "plots" / "metrics_summary.csv"
if not summary_csv.exists():
    print("metrics_summary.csv not found yet")
else:
    display(pd.read_csv(summary_csv))

# 3) Every generated file, relative to the results directory.
print("\nArtifacts under", CFG.output_dir)
for path in sorted(results_root.rglob("*")):
    if not path.is_file():
        continue
    print("-", path.relative_to(results_root))

