# Final Project Notebook
This notebook implements the project pipeline in the intended execution order: **data → targets → model → loss/metrics → training → evaluation**. Each section below adds the corresponding building blocks, starting from data loading and augmentations and ending with training and evaluation placeholders.


## Imports & Seeds

In [None]:
import random
from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import solt
import solt.transforms as slt
from tqdm.auto import tqdm

# Single seed shared by every random number generator seeded below.
SEED = 42

# Seed torch (CPU and CUDA), NumPy and the stdlib `random` module.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Switch off OpenCV's OpenCL path and set its thread count to 0.
# NOTE(review): presumably done to avoid clashes with DataLoader workers — confirm.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


## Data Loading

In [None]:
class KneeSegmentationDataset(Dataset):
    """Serve individual slices as ``{"image", "mask"}`` tensor pairs.

    Each row of the index frame must expose ``img`` and ``segmask`` paths.
    The image is read as a 3-channel array and the mask as a single-channel
    array; an optional SOLT stream is applied to both together.
    """

    def __init__(self, df: pd.DataFrame, transforms: solt.Stream | None):
        # Positional lookups in __getitem__ need a clean 0..N-1 index.
        self.dataset = df.reset_index(drop=True)
        self.trf = transforms

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict:
        """Load, transform and tensorize the slice at position ``idx``.

        Returns:
            dict with ``"image"`` (float32, channels first, scaled by 1/255)
            and ``"mask"`` (int64 with a leading singleton channel).

        Raises:
            FileNotFoundError: if either file could not be read.
        """
        row = self.dataset.iloc[idx]
        image = cv2.imread(str(row.img), cv2.IMREAD_COLOR)
        seg = cv2.imread(str(row.segmask), cv2.IMREAD_GRAYSCALE)
        if image is None or seg is None:
            raise FileNotFoundError(f"Missing image/mask for index {idx}")

        if self.trf is not None:
            # The stream receives image and mask together so both get the same transform.
            image, seg = self.trf({"image": image, "mask": seg}, return_torch=False).data

        image_t = torch.from_numpy(image.astype(np.float32) / 255.0).permute(2, 0, 1)
        seg_t = torch.from_numpy(seg.astype(np.int64)).unsqueeze(0)
        return {"image": image_t, "mask": seg_t}


class KneeSegmentationDatasetVol(Dataset):
    """Serve whole scans: slices are grouped by ``(ID, SIDE, VISIT)``.

    Every item is a dict with ``"scan"`` (stacked per-slice image tensors) and
    ``"mask"`` (stacked per-slice label tensors, or ``None`` when masks are
    not loaded).
    """

    def __init__(self, df: pd.DataFrame, transforms: solt.Stream | None, load_mask: bool = True):
        super().__init__()
        self.df = df
        self.trf = transforms
        self.load_mask = load_mask

        # One frame per scan, with slices ordered by their index in the volume.
        self.scans = [
            grp.sort_values(by="slice_idx").reset_index(drop=True)
            for _, grp in self.df.groupby(["ID", "SIDE", "VISIT"])
        ]

    def __len__(self) -> int:
        return len(self.scans)

    def __getitem__(self, index: int) -> dict:
        return self.load_volume(self.scans[index], self.trf, load_mask=self.load_mask)

    @staticmethod
    def _to_chw_tensor(img: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert one slice image to a float tensor, channels first where detectable.

        A 3-D input whose last axis has size 1 or 3 is treated as channels-last
        and permuted. NumPy inputs are always divided by 255; tensor inputs are
        divided by 255 only when their maximum exceeds 1.0.

        Raises:
            TypeError: for inputs that are neither ndarray nor tensor.
        """
        is_array = isinstance(img, np.ndarray)
        if not is_array and not torch.is_tensor(img):
            raise TypeError("Unsupported image type for volume loading")

        out = torch.from_numpy(img) if is_array else img.float()
        if out.dim() == 3 and out.shape[-1] in (1, 3):
            out = out.permute(2, 0, 1)

        if is_array:
            return out.float() / 255.0
        if out.max() > 1.0:
            out = out / 255.0
        return out

    @staticmethod
    def _to_hw_long(mask: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert one slice mask to an int64 tensor, dropping a leading singleton axis.

        Raises:
            TypeError: for inputs that are neither ndarray nor tensor.
        """
        if isinstance(mask, np.ndarray):
            out = torch.from_numpy(mask)
        elif torch.is_tensor(mask):
            out = mask
        else:
            raise TypeError("Unsupported mask type for volume loading")

        squeezable = out.dim() == 3 and out.size(0) == 1
        return (out.squeeze(0) if squeezable else out).long()

    @staticmethod
    def load_volume(vol_df: pd.DataFrame, transform: solt.Stream | None, load_mask: bool = True) -> dict:
        """Read every slice listed in ``vol_df`` and stack them into volume tensors.

        Args:
            vol_df: rows of one scan; each row needs ``img`` (and ``segmask``
                when ``load_mask`` is true).
            transform: optional SOLT stream applied to the full list of slices.
            load_mask: whether masks are read and returned.

        Returns:
            ``{"scan": tensor, "mask": tensor or None}``.

        Raises:
            FileNotFoundError: if any image or mask file could not be read.
        """
        images, masks = [], []
        for _, row in vol_df.iterrows():
            slice_img = cv2.imread(str(row.img), cv2.IMREAD_COLOR)
            if slice_img is None:
                raise FileNotFoundError(f"Missing image: {row.img}")
            images.append(slice_img)

            if not load_mask:
                continue
            slice_mask = cv2.imread(str(row.segmask), cv2.IMREAD_GRAYSCALE)
            if slice_mask is None:
                raise FileNotFoundError(f"Missing mask: {row.segmask}")
            masks.append(slice_mask)

        if transform is not None:
            payload = {"images": images}
            if load_mask:
                payload["masks"] = masks
            res = transform(payload)
            images = res["images"]
            if load_mask:
                masks = res["masks"]

        to_img = KneeSegmentationDatasetVol._to_chw_tensor
        to_mask = KneeSegmentationDatasetVol._to_hw_long
        scan = torch.stack([to_img(item) for item in images], dim=0)
        mask = torch.stack([to_mask(item) for item in masks], dim=0) if load_mask else None

        return {"scan": scan, "mask": mask}


## Augmentation

In [None]:
# Training pipeline: resize to 256x256, flip along axis 1 with probability 0.5,
# crop to 224x224 and apply gamma correction on every sample (p=1).
# NOTE(review): crop_mode="r" is presumably SOLT's random-crop mode — confirm.
train_trf = solt.Stream([
    slt.Resize((256, 256)),
    slt.Flip(p=0.5, axis=1),
    slt.Crop((224, 224), crop_mode="r"),
    slt.GammaCorrection(gamma_range=0.1, p=1),
])

# Validation pipeline: resize only, no cropping or other augmentation.
val_trf = solt.Stream([
    slt.Resize((256, 256)),
])


## Config

In [None]:
@dataclass
class Cfg:
    """Experiment configuration consumed by the trainers.

    Annotated attributes are dataclass fields and can be passed to the
    constructor; un-annotated attributes (dataset classes, optimizer class,
    loss) are plain class attributes shared by all instances.
    """

    # The default device is evaluated once, when the class body is executed.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    in_channels: int = 3
    # Overwritten in __post_init__ with len(class_names).
    num_classes: int = 2

    dataset_train_cls = KneeSegmentationDataset
    dataset_val_cls = KneeSegmentationDataset

    optimizer_cls = torch.optim.AdamW
    # A single loss-module instance created at class-definition time.
    loss_fn = torch.nn.CrossEntropyLoss()

    lr: float = 1e-3
    wd: float = 5e-5
    num_epochs: int = 1
    n_workers: int = 0
    train_bs: int = 2
    val_bs: int = 2

    # None acts as a sentinel; the real defaults are filled in by __post_init__.
    lr_drop_milestones: list[int] | None = None
    class_names: list[str] | None = None

    train_trf: solt.Stream | None = train_trf
    val_trf: solt.Stream | None = val_trf

    def __post_init__(self):
        """Fill in list-valued defaults and derive ``num_classes`` from ``class_names``."""
        if self.lr_drop_milestones is None:
            self.lr_drop_milestones = [30]
        if self.class_names is None:
            self.class_names = ["BG", "FG"]
        # Any num_classes passed by the caller is replaced by the class-name count.
        self.num_classes = len(self.class_names)


@dataclass
class CfgHighRes(Cfg):
    """Variant of ``Cfg`` without resizing: training crops 320x320, validation is untouched.

    ``train_trf`` and ``val_trf`` are dataclass fields of ``Cfg``. A plain
    subclass attribute would be overwritten by the inherited ``__init__``
    (which assigns the base-class defaults to every instance), so the class is
    re-decorated and the overrides are annotated to become the field defaults.
    """

    train_trf: solt.Stream | None = solt.Stream([
        slt.Flip(p=0.5, axis=1),
        slt.Crop((320, 320), crop_mode="r"),
        slt.GammaCorrection(gamma_range=0.1, p=1),
    ])

    # Empty stream: validation samples pass through unchanged.
    val_trf: solt.Stream | None = solt.Stream([])


@dataclass
class CfgVol(Cfg):
    """Variant of ``Cfg`` that validates on whole volumes, one scan per batch.

    ``val_bs`` is a dataclass field of ``Cfg``; it has to be re-declared with
    an annotation under ``@dataclass`` so that the new default reaches the
    generated ``__init__``. Without that, ``CfgVol().val_bs`` stays at 2.
    """

    # Not a field (un-annotated), so a plain class-attribute override is enough.
    dataset_val_cls = KneeSegmentationDatasetVol
    val_bs: int = 1


## Trainer

### Subtask: Define INR regression model
This subtask introduces a minimal coordinate-based INR MLP that outputs a single continuous value per coordinate, without classification activations.

**Files to modify/create:**
- `project/final_project.ipynb`


In [None]:
class INRMLP(nn.Module):
    """Minimal coordinate-based INR: a ReLU MLP with a linear output layer.

    Args:
        in_features: size of the last axis of the coordinate tensor.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden layers; at least one is always built,
            so values below 1 behave like 1.
        out_features: number of values predicted per coordinate.
    """

    def __init__(
        self,
        in_features: int = 2,
        hidden_features: int = 128,
        hidden_layers: int = 4,
        out_features: int = 1,
    ):
        super().__init__()
        layers = [nn.Linear(in_features, hidden_features), nn.ReLU(inplace=True)]
        for _ in range(hidden_layers - 1):
            layers.append(nn.Linear(hidden_features, hidden_features))
            layers.append(nn.ReLU(inplace=True))
        # No activation after the last layer: the output is an unbounded regression value.
        layers.append(nn.Linear(hidden_features, out_features))
        self.net = nn.Sequential(*layers)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: (..., in_features) tensor of coordinates.
        Returns:
            (...,) tensor when the network has a single output, otherwise
            (..., out_features).
        """
        lead_shape = coords.shape[:-1]
        # reshape (not view) so permuted or sliced, non-contiguous inputs work too.
        out = self.net(coords.reshape(-1, coords.shape[-1]))
        if out.shape[-1] == 1:
            return out.reshape(lead_shape)
        # Keep the feature axis when several values are predicted per coordinate.
        return out.reshape(*lead_shape, out.shape[-1])


In [None]:
class SimpleSegmentationNet(nn.Module):
    """Small fully-convolutional network that emits per-pixel class logits.

    Two 3x3 convolutions (16 and 32 channels, padding 1) with ReLU, followed
    by a 1x1 convolution down to ``num_classes`` channels. The padding keeps
    the spatial size of the input.
    """

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        stem = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        body = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        head = nn.Conv2d(32, num_classes, kernel_size=1)
        self.net = nn.Sequential(
            stem,
            nn.ReLU(inplace=True),
            body,
            nn.ReLU(inplace=True),
            head,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits with ``num_classes`` channels for input ``x``."""
        logits = self.net(x)
        return logits


class BaseTrainer:
    """Skeleton of the train/validate loop.

    Subclasses supply ``init_model`` and ``val_epoch``; everything else
    (loaders, optimizer, epoch loop, logging) is implemented here.
    """

    def __init__(self, train_df: pd.DataFrame, val_df: pd.DataFrame, cfg: Cfg):
        self.train_df, self.val_df, self.cfg = train_df, val_df, cfg

        # Everything below is created lazily by init_run().
        self.train_loader = self.val_loader = None
        self.loss_fn = self.optimizer = self.model = None

    def init_model(self):
        """Create ``self.model``; must be provided by subclasses."""
        raise NotImplementedError

    def init_run(self):
        """Build the model, both data loaders, the optimizer and the loss."""
        self.init_model()

        cfg = self.cfg
        train_ds = cfg.dataset_train_cls(self.train_df, cfg.train_trf)
        val_ds = cfg.dataset_val_cls(self.val_df, cfg.val_trf)

        self.train_loader = DataLoader(
            train_ds,
            batch_size=cfg.train_bs,
            shuffle=True,
            num_workers=cfg.n_workers,
            pin_memory=True,
        )
        self.val_loader = DataLoader(
            val_ds,
            batch_size=cfg.val_bs,
            shuffle=False,
            num_workers=cfg.n_workers,
            pin_memory=True,
        )

        self.optimizer = cfg.optimizer_cls(
            self.model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.wd,
        )
        self.loss_fn = cfg.loss_fn

    def adjust_lr(self):
        """Multiply every param-group learning rate by 0.1 on milestone epochs."""
        milestones = self.cfg.lr_drop_milestones
        if not milestones or self.epoch not in milestones:
            return
        for group in self.optimizer.param_groups:
            group["lr"] *= 0.1

    def run(self, n_epochs: int | None = None):
        """Train for ``n_epochs`` (default ``cfg.num_epochs``) and return per-epoch records."""
        self.init_run()
        if n_epochs is None:
            n_epochs = self.cfg.num_epochs

        history = []
        # The loop variable is stored on the instance so hooks can read self.epoch.
        for self.epoch in range(n_epochs):
            self.model.train()
            train_loss = self.train_epoch()
            self.model.eval()
            val_out = self.val_epoch()
            self.post_val_hook(train_loss, val_out)

            record = {"epoch": self.epoch, "train_loss": float(train_loss)}
            record.update(val_out)
            history.append(record)
        return history

    def train_epoch(self):
        """Run one optimisation pass over the training loader; return the mean batch loss."""
        progress = tqdm(self.train_loader, desc=f"[{self.epoch}] Train")
        total = 0.0
        self.adjust_lr()

        for step, batch in enumerate(progress, start=1):
            self.optimizer.zero_grad()
            batch_loss, _ = self.pass_batch(batch)
            batch_loss.backward()
            self.optimizer.step()

            total += float(batch_loss.item())
            progress.set_postfix({"loss": total / step})

        return total / max(1, len(self.train_loader))

    def val_epoch(self):
        """Evaluate on the validation loader; must be provided by subclasses."""
        raise NotImplementedError

    def post_val_hook(self, train_loss, val_out):
        """Print the train and validation loss of the current epoch."""
        rule = "=" * 50
        print(rule)
        print(f"[{self.epoch}] --> Train loss: {train_loss:.4f}")
        print(f"[{self.epoch}] --> Val loss: {val_out['val_loss']:.4f}")
        print(rule)

    def pass_batch(self, batch):
        """Forward one batch and return ``(loss, logits)``."""
        device = self.cfg.device
        images = batch["image"].to(device)
        labels = batch["mask"].to(device)

        logits = self.model(images)
        # The loss is called with labels that have no channel axis.
        loss = self.loss_fn(logits, labels.squeeze(1).long())
        return loss, logits


class SegmentationTrainer2D(BaseTrainer):
    """Slice-wise trainer built around ``SimpleSegmentationNet``."""

    def init_model(self):
        """Instantiate the network from the config and move it to the configured device."""
        net = SimpleSegmentationNet(
            in_channels=self.cfg.in_channels,
            num_classes=self.cfg.num_classes,
        )
        self.model = net.to(self.cfg.device)

    @torch.no_grad()
    def val_epoch(self):
        """Return ``{"val_loss": mean batch loss}`` over the validation loader."""
        total = 0.0
        progress = tqdm(self.val_loader, desc=f"[{self.epoch}] Val", leave=False)
        for step, batch in enumerate(progress, start=1):
            batch_loss, _ = self.pass_batch(batch)
            total += float(batch_loss.item())
            progress.set_postfix({"loss": total / step})
        return {"val_loss": total / max(1, len(self.val_loader))}


class SegmentationTrainer3D(SegmentationTrainer2D):
    """Evaluates the 2D model slice by slice on whole scans."""

    @torch.no_grad()
    def pass_eval_batch(self, batch, compute_loss: bool = False):
        """Run the model on every slice of a scan batch and assemble volume logits.

        Runs under ``torch.no_grad()``: this is an evaluation-only pass, and
        without it the logits volume would keep the autograd graph of every
        slice alive.

        Args:
            batch: dict with ``"scan"`` (5-D: batch, slices, channels, height,
                width) and optionally ``"mask"``, indexed along axis 1 per slice.
            compute_loss: when true and a mask is present, the per-slice loss
                is averaged over the slices.

        Returns:
            ``(loss, logits_vol, target)`` where ``loss`` is a float or
            ``None`` (when ``compute_loss`` is false), ``logits_vol`` has
            shape (batch, num_classes, slices, height, width) and ``target``
            is the mask on the configured device or ``None``.
        """
        scan = batch["scan"].to(self.cfg.device)
        target = batch.get("mask")
        if target is not None:
            target = target.to(self.cfg.device)

        batch_size, n_slices, _, height, width = scan.shape
        logits_vol = torch.zeros(
            batch_size, self.cfg.num_classes, n_slices, height, width,
            device=self.cfg.device,
        )

        total_loss = 0.0
        for slice_idx in range(n_slices):
            x_s = scan[:, slice_idx, ...]
            logits_s = self.model(x_s)
            logits_vol[:, :, slice_idx, :, :] = logits_s

            if compute_loss and target is not None:
                y_s = target[:, slice_idx, ...]
                loss_s = self.loss_fn(logits_s, y_s)
                total_loss += float(loss_s.item())

        # NOTE(review): with compute_loss=True but no mask this reports 0.0 — confirm intended.
        loss_i = total_loss / max(1, n_slices) if compute_loss else None
        return loss_i, logits_vol, target


## Sanity Check

In [None]:
# INR forward-pass sanity check
coords = torch.rand(2, 8, 2) * 2 - 1  # (batch, points, xy) in [-1, 1]
# Small network (3 hidden layers of width 64) placed on the default config device.
inr = INRMLP(in_features=2, hidden_features=64, hidden_layers=3, out_features=1).to(Cfg().device)
with torch.no_grad():
    preds = inr(coords.to(Cfg().device))
# Shapes are printed so the two can be compared by eye.
print("Coords shape:", coords.shape)
print("Preds shape:", preds.shape)


In [None]:
# Synthetic four-slice "scan" written to disk to exercise the dataset and model.
tmp_root = Path("/tmp/seg_sanity")
tmp_root.mkdir(parents=True, exist_ok=True)

samples = []
for idx in range(4):
    img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    mask = (np.random.rand(256, 256) > 0.5).astype(np.uint8)
    img_path = tmp_root / f"img_{idx}.png"
    mask_path = tmp_root / f"mask_{idx}.png"
    cv2.imwrite(str(img_path), img)
    # Store class indices (0/1), not 0/255: the dataset feeds mask values
    # straight into the loss as labels, and 255 is out of range for 2 classes.
    cv2.imwrite(str(mask_path), mask)
    samples.append({
        "img": img_path,
        "segmask": mask_path,
        "ID": 0,
        "SIDE": "L",
        "VISIT": 0,
        "slice_idx": idx,
    })

df = pd.DataFrame(samples)
ds = KneeSegmentationDataset(df, val_trf)
sample = ds[0]
print("Image shape:", sample["image"].shape)
print("Mask shape:", sample["mask"].shape)

# Single forward pass to check that the logits come out with the expected shape.
model = SimpleSegmentationNet(in_channels=3, num_classes=2).to(Cfg().device)
with torch.no_grad():
    logits = model(sample["image"].unsqueeze(0).to(Cfg().device))
print("Logits shape:", logits.shape)


## Training Loop
**Subtask:** Build a segmentation trainer loop (based on Assignment 3) with epoch-level logging for train/val loss and metrics.

**Files modified:** `project/final_project.ipynb`

**Sanity check:** Run a short training loop (few epochs) on the tiny synthetic dataset to confirm convergence behavior and logging.


In [None]:
import gc
import sys


class BaseTrainer:
    """Train/validate loop skeleton with epoch-level logging.

    Subclasses supply ``init_model`` and ``val_epoch``. ``run`` can be called
    repeatedly; each call rebuilds the model, loaders and optimizer.
    """

    def __init__(self, train_df, val_df, cfg):
        self.train_df = train_df
        self.val_df = val_df
        self.cfg = cfg

        self.train_loader = None
        self.val_loader = None
        self.loss_fn = None
        self.optimizer = None
        self.model = None
        # Set by run(); defined here so hooks never hit a missing attribute.
        self.epoch = 0

    def init_model(self):
        """Create ``self.model``; must be provided by subclasses."""
        raise NotImplementedError

    def init_run(self):
        """Release the previous run's objects, then build model, loaders, optimizer, loss."""
        # Rebind to None instead of `del`: the references are dropped just the
        # same, but the attributes keep existing, so a second call still works
        # when an earlier one failed half-way (e.g. in init_model()).
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.init_model()

        train_ds = self.cfg.dataset_train_cls(
            self.train_df,
            self.cfg.train_trf,
        )
        val_ds = self.cfg.dataset_val_cls(
            self.val_df,
            self.cfg.val_trf,
        )

        self.train_loader = DataLoader(
            train_ds,
            batch_size=self.cfg.train_bs,
            shuffle=True,
            num_workers=self.cfg.n_workers,
            pin_memory=True,
        )
        self.val_loader = DataLoader(
            val_ds,
            batch_size=self.cfg.val_bs,
            shuffle=False,
            num_workers=self.cfg.n_workers,
            pin_memory=True,
        )

        self.optimizer = self.cfg.optimizer_cls(
            self.model.parameters(),
            lr=self.cfg.lr,
            weight_decay=self.cfg.wd,
        )
        self.loss_fn = self.cfg.loss_fn

    def adjust_lr(self):
        """Multiply every param-group learning rate by 0.1 on milestone epochs.

        A missing or empty milestone list means "never drop".
        """
        milestones = self.cfg.lr_drop_milestones
        if not milestones or self.epoch not in milestones:
            return
        for param_group in self.optimizer.param_groups:
            param_group["lr"] *= 0.1

    def run(self, n_epochs=None):
        """Train for ``n_epochs`` (default ``cfg.num_epochs``).

        Returns:
            list with one dict per epoch holding ``"epoch"``, ``"train_loss"``
            and every key returned by ``val_epoch``.
        """
        self.init_run()
        if n_epochs is None:
            n_epochs = self.cfg.num_epochs
        history = []
        # The loop variable is stored on the instance so hooks can read self.epoch.
        for self.epoch in range(n_epochs):
            self.model.train()
            train_loss = self.train_epoch()

            self.model.eval()
            val_out = self.val_epoch()
            self.post_val_hook(train_loss, val_out)

            history.append({
                "epoch": self.epoch,
                "train_loss": float(train_loss),
                **val_out,
            })
        return history

    def train_epoch(self):
        """Run one optimisation pass over the training loader; return the mean batch loss."""
        sys.stderr.flush()
        running_loss = 0.0

        self.adjust_lr()

        # Context manager so the bar is closed even if a batch raises.
        with tqdm(total=len(self.train_loader), position=0, leave=True) as pbar:
            for step, batch in enumerate(self.train_loader, start=1):
                self.optimizer.zero_grad()
                loss, _ = self.pass_batch(batch)
                loss.backward()
                self.optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss

                pbar.set_description(
                    f"[{self.epoch}] Train {batch_loss:.4f} / {running_loss / step:.4f}"
                )
                pbar.update()

        return running_loss / max(len(self.train_loader), 1)

    def val_epoch(self):
        """Evaluate on the validation loader; must be provided by subclasses."""
        raise NotImplementedError

    def post_val_hook(self, train_loss, val_out):
        """Print losses plus, when present, mean IoU and per-class Jaccard scores."""
        sys.stderr.flush()
        print("=" * 50)
        print(f"[{self.epoch}] --> Train loss : {train_loss:.4f}")
        print(f"[{self.epoch}] --> Val loss   : {val_out['val_loss']:.4f}")
        if "mean_iou" in val_out:
            print(f"[{self.epoch}] --> Mean IoU   : {val_out['mean_iou']:.4f}")
        for cname in self.cfg.class_names:
            key = f"jaccard/{cname}"
            if key in val_out:
                print(f"    Jaccard [{cname}] : {val_out[key]:.4f}")
        print("=" * 50)

    def pass_batch(self, batch):
        """Forward one batch and return ``(loss, logits)``."""
        img = batch["image"].to(self.cfg.device)
        segmask = batch["mask"].to(self.cfg.device)

        logits = self.model(img)
        # The loss is called with labels that have no channel axis.
        target = segmask.squeeze(1).long()
        loss = self.loss_fn(logits, target)

        return loss, logits


class SegmentationTrainer2D(BaseTrainer):
    """Slice-wise trainer that also reports per-class Jaccard (IoU) on validation."""

    def init_model(self):
        """Instantiate the network from the config and move it to the configured device."""
        net = SimpleSegmentationNet(
            in_channels=self.cfg.in_channels,
            num_classes=self.cfg.num_classes,
        )
        self.model = net.to(self.cfg.device)

    def compute_jaccard(self, logits, target):
        """Compute per-sample, per-class IoU from logits and integer labels.

        Args:
            logits: class scores with classes on axis 1.
            target: labels; a 4-D input has axis 1 squeezed away first.

        Returns:
            One dict per batch item mapping class name -> IoU. A class absent
            from both prediction and target scores 1.0.
        """
        labels = target.squeeze(1) if target.dim() == 4 else target
        hard = torch.argmax(logits, dim=1)

        per_item = []
        for i, pred_map in enumerate(hard):
            true_map = labels[i]
            scores = {}
            for cls_idx, cls_name in enumerate(self.cfg.class_names):
                in_pred = pred_map == cls_idx
                in_true = true_map == cls_idx

                overlap = (in_pred & in_true).sum().float()
                combined = (in_pred | in_true).sum().float()

                perfect = torch.tensor(1.0, device=combined.device)
                score = torch.where(combined == 0, perfect, overlap / (combined + 1e-7))
                scores[cls_name] = score.item()
            per_item.append(scores)

        return per_item

    def val_epoch(self):
        """Evaluate on the validation loader.

        Returns:
            dict with ``"val_loss"``, one ``"jaccard/<class>"`` entry per class
            (mean over all validation samples) and ``"mean_iou"`` (mean of the
            per-class means).
        """
        self.model.eval()
        total = 0.0
        records = []
        device = self.cfg.device

        with torch.no_grad():
            progress = tqdm(total=len(self.val_loader), position=0, leave=False)

            for batch in self.val_loader:
                images = batch["image"].to(device)
                labels = batch["mask"].to(device)

                logits = self.model(images)
                loss = self.loss_fn(logits, labels.squeeze(1).long())

                total += loss.item()
                records.extend(self.compute_jaccard(logits, labels))

                progress.set_description(f"[{self.epoch}] Val {loss.item():.4f}")
                progress.update()

            progress.close()

        out = {"val_loss": total / max(len(self.val_loader), 1)}

        class_means = []
        for name in self.cfg.class_names:
            values = [rec[name] for rec in records]
            # With no validation samples the class mean falls back to 0.0.
            class_mean = float(np.mean(values)) if values else 0.0
            out[f"jaccard/{name}"] = class_mean
            class_means.append(class_mean)

        out["mean_iou"] = float(np.mean(class_means)) if class_means else 0.0
        return out


# Smoke run: the same synthetic frame serves as both training and validation data.
cfg = Cfg()
trainer = SegmentationTrainer2D(df, df, cfg)
trainer.run(n_epochs=3)


## Evaluation / Visualization

In [None]:
# TODO: Add prediction visualization and slice-wise evaluation helpers.
# Example: overlay predictions on input slices or render scan-level summaries.


## Dataset setup (Colab-ready)
Define the data root and the volume-level index used to build the slice cache.

In [None]:
import os

# Data root comes from the KNEE_DATA_ROOT environment variable, with a Colab-style fallback.
DATA_ROOT = Path(os.environ.get("KNEE_DATA_ROOT", "/content/knee_data"))
VOLUME_CSV = DATA_ROOT / "volumes.csv"
SLICE_CACHE_DIR = DATA_ROOT / "slice_cache"
images_dir = SLICE_CACHE_DIR / "images"
masks_dir = SLICE_CACHE_DIR / "masks"
slice_index_path = SLICE_CACHE_DIR / "slice_index.csv"

# The cache folders are created up front, before the volume index is checked.
for d in [images_dir, masks_dir]:
    d.mkdir(parents=True, exist_ok=True)

# Fail early with the expected schema when the volume index is absent.
if not VOLUME_CSV.exists():
    raise FileNotFoundError(
        f"Missing {VOLUME_CSV}. Expected a CSV with columns: ID, VISIT, SIDE, img, segmask."
    )

vol_df = pd.read_csv(VOLUME_CSV)
# Every column read later by the caching step must be present.
required_cols = {"ID", "VISIT", "SIDE", "img", "segmask"}
missing = required_cols - set(vol_df.columns)
if missing:
    raise ValueError(f"volumes.csv missing columns: {sorted(missing)}")

print(f"Loaded {len(vol_df)} volumes from {VOLUME_CSV}")


## Subtask 1: Recreate slice caching and data indexing

**Goal:** Convert 3D volumes into cached 2D slices and build a slice-level index.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
# Basic imports
from pathlib import Path
import random

import numpy as np
import pandas as pd
import nibabel as nib
import cv2

# Re-seed the stdlib and NumPy generators for this part of the notebook
# (torch is not re-seeded in this cell).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Fixing one of the most annoying "features" of opencv
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


In [None]:
def vis_slice(img, lp=0, hp=99.9):
    """Rescale a slice to uint8 using a percentile window, for PNG storage.

    Intensities are mapped linearly so that the ``lp``-th percentile becomes 0
    and the ``hp``-th percentile becomes 255; values outside the window are
    clipped.

    Args:
        img: array-like slice of intensities.
        lp: lower percentile of the window.
        hp: upper percentile of the window.

    Returns:
        uint8 array with the same shape as ``img``. A slice whose window is
        empty (constant intensity, or NaN percentiles) maps to all zeros
        rather than dividing by zero.
    """
    img_float = np.asarray(img, dtype=np.float32)
    low, high = np.percentile(img_float, [lp, hp])
    span = high - low
    # `not span > 0` also catches NaN, which a plain `span <= 0` would let through.
    if not span > 0:
        return np.zeros(img_float.shape, dtype=np.uint8)

    img_norm = np.clip((img_float - low) / span, 0, 1)
    return (img_norm * 255).astype(np.uint8)


def orient_slice(slice_2d):
    """Reorient a slice to match the assignment 3 visualizations.

    Rotating by three quarter-turns and then mirroring left-right amounts to
    exchanging the first two axes, so that is done in a single step.
    """
    return np.swapaxes(slice_2d, 0, 1)


In [None]:
# Slice caching + indexing (skip if cache already exists)
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp


def process_volume(row_dict, images_dir, masks_dir):
    """Cache every slice of one volume as PNG files and return their index rows.

    Args:
        row_dict: one row of the volume index with keys ``ID``, ``VISIT``,
            ``SIDE``, ``img`` and ``segmask`` (the latter two are NIfTI paths).
        images_dir: directory receiving the image PNGs (``str`` or ``Path``).
        masks_dir: directory receiving the mask PNGs (``str`` or ``Path``).

    Returns:
        list of dicts, one per slice, with ``ID``, ``VISIT``, ``SIDE``,
        ``slice_idx`` and the written ``img`` / ``segmask`` paths.

    Raises:
        ValueError: if image and mask volumes disagree on their first three axes.
        OSError: if a PNG could not be written.
    """
    row = pd.Series(row_dict)
    images_dir = Path(images_dir)
    masks_dir = Path(masks_dir)

    img_data = nib.load(row.img).get_fdata()
    mask_data = nib.load(row.segmask).get_fdata()

    # A mismatch would otherwise pair slices with the wrong (or no) mask.
    if img_data.shape[:3] != mask_data.shape[:3]:
        raise ValueError(
            f"Image/mask shape mismatch for {row.img}: {img_data.shape} vs {mask_data.shape}"
        )

    records = []
    for slice_idx in range(img_data.shape[0]):
        img_slice = orient_slice(vis_slice(img_data[slice_idx, :, :]))
        mask_slice = orient_slice(mask_data[slice_idx, :, :])

        # Image and mask share a file name; they live in different directories.
        file_name = f"{row.ID}_{row.VISIT}_{row.SIDE}_slice{slice_idx:03d}.png"
        img_path = images_dir / file_name
        mask_path = masks_dir / file_name

        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(img_path), img_slice):
            raise OSError(f"Failed to write image slice: {img_path}")
        if not cv2.imwrite(str(mask_path), mask_slice.astype(np.uint8)):
            raise OSError(f"Failed to write mask slice: {mask_path}")

        records.append({
            "ID": row.ID,
            "VISIT": row.VISIT,
            "SIDE": row.SIDE,
            "slice_idx": slice_idx,
            "img": str(img_path),
            "segmask": str(mask_path),
        })

    return records


# Build the slice cache once; later runs reuse the saved index.
if slice_index_path.exists():
    slice_ds = pd.read_csv(slice_index_path)
else:
    rows = vol_df.to_dict(orient="records")
    all_records = []
    max_workers = min(4, mp.cpu_count())

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_volume, row, images_dir, masks_dir) for row in rows]
        # Collect in submission order rather than completion order, so the
        # index (and everything derived from its row order, such as the fold
        # split) is identical on every fresh cache build.
        for fut in futures:
            all_records.extend(fut.result())

    slice_ds = pd.DataFrame(all_records)
    slice_ds.to_csv(slice_index_path, index=False)

slice_ds.head()


## Subtask 2: Patient-aware k-fold split logic

**Goal:** Create patient-level folds without leakage and select a reference fold for downstream sanity checks.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
def make_patient_folds(slice_df, n_splits=5, seed=SEED, group_cols=("ID", "SIDE")):
    """Split slices into folds such that no group appears in both train and val.

    Args:
        slice_df: slice-level index containing the columns in ``group_cols``.
        n_splits: number of folds.
        seed: seed of the shuffle that assigns groups to folds.
        group_cols: columns whose joined values identify a group. The default
            groups by knee (ID + SIDE).
            NOTE(review): with the default, the two knees of one person can
            land on opposite sides of a split; pass ``("ID",)`` for strictly
            patient-level folds — confirm which is intended.

    Returns:
        list of dicts with ``fold``, ``train_df``, ``val_df``,
        ``train_patients`` and ``val_patients`` (sets of group keys).
    """
    group_cols = list(group_cols)
    patients = slice_df[group_cols[0]].astype(str)
    for col in group_cols[1:]:
        patients = patients + "_" + slice_df[col].astype(str)

    # Sort before shuffling so the split depends only on the seed and the set
    # of groups, not on the row order of the index.
    shuffled = np.array(sorted(patients.unique()), dtype=object)
    rng = np.random.default_rng(seed)
    rng.shuffle(shuffled)
    all_patients = set(shuffled.tolist())

    folds = []
    for fold_idx, val_chunk in enumerate(np.array_split(shuffled, n_splits)):
        val_patients = set(val_chunk.tolist())
        train_patients = all_patients - val_patients

        train_df = slice_df[patients.isin(train_patients)].reset_index(drop=True)
        val_df = slice_df[patients.isin(val_patients)].reset_index(drop=True)

        folds.append({
            "fold": fold_idx,
            "train_df": train_df,
            "val_df": val_df,
            "train_patients": train_patients,
            "val_patients": val_patients,
        })

    return folds


# Build five folds and print the number of groups and slices on each side of every split.
folds = make_patient_folds(slice_ds, n_splits=5)
for fold in folds:
    print(
        f"Fold {fold['fold']}: "
        f"{len(fold['train_patients'])} train patients → {len(fold['train_df'])} slices, "
        f"{len(fold['val_patients'])} val patients → {len(fold['val_df'])} slices"
    )

# Fold 0 provides the train/val frames used by the sanity checks further down.
reference_fold = folds[0]
train_df = reference_fold["train_df"]
val_df = reference_fold["val_df"]


### Sanity check 2
Ensure each fold has disjoint train/val patients and all patients are covered.


In [None]:
# Collect every group key seen in any fold while checking train/val disjointness per fold.
all_patients = set()
for fold in folds:
    train_patients = fold["train_patients"]
    val_patients = fold["val_patients"]
    assert train_patients.isdisjoint(val_patients), "Train/val leakage in fold."
    all_patients.update(train_patients)
    all_patients.update(val_patients)

# Rebuild the ID_SIDE keys from the slice index and compare against the union over folds.
unique_patients = set((slice_ds["ID"].astype(str) + "_" + slice_ds["SIDE"]).unique())
assert all_patients == unique_patients, "Not all patients are covered by folds."

print("All folds are disjoint and cover the full patient set.")


## Subtask 3: K-fold training/evaluation loop

**Goal:** Train and evaluate the model across folds, logging per-fold metrics and aggregate mean/std.

**Files modified/created:** `project/final_project.ipynb`


In [ ]:
def run_cross_validation(folds, cfg):
    """Train one model per fold and summarise the last-epoch metrics.

    Args:
        folds: iterable of fold dicts with ``fold``, ``train_df`` and ``val_df``.
        cfg: trainer configuration; ``cfg.num_epochs`` sets the epochs per fold.

    Returns:
        ``(metrics_df, summary)``: a frame with one row per fold, and a dict
        mapping each metric column to its mean and population standard
        deviation (ddof=0) across folds.
    """
    n_folds = len(folds)
    per_fold = []
    for fold in folds:
        print(f"\n=== Fold {fold['fold'] + 1}/{n_folds} ===")
        trainer = SegmentationTrainer2D(fold["train_df"], fold["val_df"], cfg)
        history = trainer.run(n_epochs=cfg.num_epochs)

        # Only the final epoch of each fold enters the report.
        row = {"fold": fold["fold"]}
        if history:
            row.update(history[-1])
        per_fold.append(row)
        print({key: value for key, value in row.items() if key != "fold"})

    metrics_df = pd.DataFrame(per_fold)

    summary = {}
    for col in metrics_df.columns:
        if col in {"fold", "epoch"}:
            continue
        observed = metrics_df[col].dropna().astype(float)
        if len(observed) == 0:
            continue
        summary[col] = {
            "mean": float(observed.mean()),
            "std": float(observed.std(ddof=0)),
        }

    return metrics_df, summary


# Cross-validate with the default configuration and show both result tables.
cv_cfg = Cfg()
fold_metrics_df, cv_summary = run_cross_validation(folds, cv_cfg)
print("\nPer-fold metrics:")
display(fold_metrics_df)
print("\nAggregate metrics (mean/std):")
# Transposed so each metric becomes a row with mean/std columns.
display(pd.DataFrame(cv_summary).T)


## Subtask 4: Minimal batch sanity check

**Goal:** Load a batch and confirm shapes + mask alignment.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


class SliceDataset(Dataset):
    """Transform-free slice dataset used for shape and alignment checks.

    Each row of the index frame must expose ``img`` and ``segmask`` paths.
    """

    def __init__(self, df):
        # Positional lookups in __getitem__ need a clean 0..N-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"image": float CHW tensor scaled by 1/255, "mask": int64 HW tensor}``.

        Raises:
            FileNotFoundError: if the image or the mask could not be read.
        """
        row = self.df.iloc[idx]
        # str() so that both string and Path entries are accepted, as in the
        # other datasets of this notebook.
        img = cv2.imread(str(row.img))
        if img is None:
            raise FileNotFoundError(f"Missing image: {row.img}")
        mask = cv2.imread(str(row.segmask), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Missing mask: {row.segmask}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()
        return {"image": img, "mask": mask}


# Pull a single shuffled batch of two slices from the reference training fold.
train_ds = SliceDataset(train_df)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)

batch = next(iter(train_loader))
images = batch["image"]
masks = batch["mask"]

print("Image batch shape:", images.shape)
print("Mask batch shape:", masks.shape)
# Only the last two (spatial) axes are compared; images carry an extra channel axis.
assert images.shape[-2:] == masks.shape[-2:], "Image/mask spatial dimensions do not match."


### Sanity check 3
Visualize a single image/mask pair for alignment.

In [None]:
# Channels-last copy of the first image of the batch for matplotlib.
img_np = images[0].permute(1, 2, 0).numpy()
mask_np = masks[0].numpy()

# Image and mask side by side for a visual alignment check.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
axes[0].imshow(img_np)
axes[0].set_title("Image")
axes[0].axis("off")

axes[1].imshow(mask_np, cmap="viridis")
axes[1].set_title("Mask")
axes[1].axis("off")

plt.tight_layout()
plt.show()


## Subtask 5: Signed distance field targets

**Goal:** Convert each 2D hard mask slice into a signed distance field (inside negative, outside positive) with stable normalization/clipping.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
def hard_mask_to_sdf(mask, clip_value=20.0, normalize=True, eps=1e-6):
    """Turn a hard 2D mask into a signed distance field (negative inside, positive outside).

    Args:
        mask: tensor or array-like; a leading singleton axis on a 3-D input is
            dropped. Every value above zero counts as foreground.
        clip_value: distances are clipped to ``[-clip_value, clip_value]``;
            ``None`` disables clipping.
        normalize: divide by the largest absolute value of this field (plus
            ``eps``).
        eps: guard against division by zero during normalization.

    Returns:
        float32 array with the spatial shape of the mask.
    """
    mask_np = mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
    if mask_np.ndim == 3 and mask_np.shape[0] == 1:
        mask_np = mask_np[0]

    foreground = (mask_np > 0).astype(np.uint8)
    background = 1 - foreground

    # Each transform is non-zero on one side only: background pixels get a
    # positive distance, foreground pixels a negative one.
    outside = cv2.distanceTransform(background, cv2.DIST_L2, 5)
    inside = cv2.distanceTransform(foreground, cv2.DIST_L2, 5)
    sdf = outside - inside

    if clip_value is not None:
        sdf = np.clip(sdf, -clip_value, clip_value)
    if normalize:
        # NOTE(review): the scale is computed per slice, so the same pixel
        # distance can map to different values on different slices — confirm intended.
        scale = np.max(np.abs(sdf)) + eps
        sdf = sdf / scale

    return sdf.astype(np.float32)


### Sanity check 4
Visualize SDFs for a few slices and confirm boundary values are near zero.


In [None]:
def boundary_mask(binary_mask, kernel_size=3):
    """Return a boolean map of the foreground pixels removed by one erosion.

    Args:
        binary_mask: Array whose non-zero entries mark the foreground.
        kernel_size: Side length of the square all-ones erosion kernel.
    """
    mask_u8 = binary_mask.astype(np.uint8)
    structuring_element = np.ones((kernel_size, kernel_size), np.uint8)
    interior = cv2.erode(mask_u8, structuring_element, iterations=1)
    # What erosion strips away is the inner rim of the foreground.
    return (mask_u8 - interior).astype(bool)


# Plot image / mask / SDF for up to three slices of the current batch and
# check that the SDF is close to zero on the mask boundary.
num_slices = min(3, masks.shape[0])
# squeeze=False keeps `axes` two-dimensional even when only one row is drawn,
# so `axes[i, j]` indexing works for any number of slices.
fig, axes = plt.subplots(num_slices, 3, figsize=(9, 3 * num_slices), squeeze=False)

for i in range(num_slices):
    img_np = images[i].permute(1, 2, 0).numpy()
    mask_np = masks[i].numpy()

    sdf = hard_mask_to_sdf(mask_np, clip_value=20.0, normalize=True)

    axes[i, 0].imshow(img_np)
    axes[i, 0].set_title(f"Image {i}")
    axes[i, 0].axis("off")

    axes[i, 1].imshow(mask_np, cmap="gray")
    axes[i, 1].set_title("Mask")
    axes[i, 1].axis("off")

    im = axes[i, 2].imshow(sdf, cmap="coolwarm", vmin=-1, vmax=1)
    axes[i, 2].set_title("SDF (normalized)")
    axes[i, 2].axis("off")

    boundary = boundary_mask(mask_np > 0)
    has_boundary = bool(boundary.any())
    boundary_mean = np.mean(np.abs(sdf[boundary])) if has_boundary else np.nan
    print(f"Slice {i}: SDF min={sdf.min():.3f}, max={sdf.max():.3f}, boundary mean |sdf|={boundary_mean:.3f}")
    # A slice without any foreground has no boundary; its mean is NaN and
    # `nan < 0.2` is always False, so the check is only meaningful (and only
    # applied) when a boundary actually exists.
    if has_boundary:
        assert boundary_mean < 0.2, "Boundary is not close to zero after normalization."

plt.tight_layout()
plt.show()


## Subtask 6: Losses and metrics

**Goal:** Define primary L1 loss and evaluation metrics for SDF regression.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
import torch.nn.functional as F


def l1_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error between ``pred`` and ``target``.

    This is the primary training loss for SDF regression.
    """
    return F.l1_loss(input=pred, target=target, reduction="mean")


def sdf_mae(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error over all elements as a scalar tensor."""
    residual = pred - target
    return residual.abs().mean()


def sdf_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the root-mean-square error over all elements as a scalar tensor."""
    squared_error = torch.square(pred - target)
    return squared_error.mean().sqrt()


def sdf_to_mask(sdf: torch.Tensor, thresh: float = 0.0) -> torch.Tensor:
    """Return a boolean mask that is true where ``sdf`` is strictly below ``thresh``.

    With the inside-negative convention and the default threshold, this
    selects the foreground; values exactly equal to ``thresh`` are excluded.
    """
    return torch.lt(sdf, thresh)


def dice_iou_from_sdf(
    pred_sdf: torch.Tensor,
    target_sdf: torch.Tensor,
    thresh: float = 0.0,
    eps: float = 1e-6,
) -> dict:
    """Compute Dice and IoU between the masks obtained by thresholding two SDFs.

    Both SDFs are binarised with ``sdf_to_mask`` and the counts are taken over
    every element of the tensors (no per-sample averaging).

    Args:
        pred_sdf: Predicted signed distance field.
        target_sdf: Reference signed distance field.
        thresh: Threshold forwarded to ``sdf_to_mask``.
        eps: Smoothing term added to numerators and denominators, so two
            empty masks score 1 instead of dividing by zero.

    Returns:
        ``{"dice": tensor, "iou": tensor}`` with scalar float tensors.
    """
    pred_fg = sdf_to_mask(pred_sdf, thresh=thresh)
    target_fg = sdf_to_mask(target_sdf, thresh=thresh)

    n_pred = pred_fg.sum().float()
    n_target = target_fg.sum().float()
    n_both = (pred_fg & target_fg).sum().float()

    # |A| + |B| is the Dice denominator; removing the overlap once gives |A ∪ B|.
    total = n_pred + n_target
    dice = (2 * n_both + eps) / (total + eps)
    iou = (n_both + eps) / (total - n_both + eps)
    return {"dice": dice, "iou": iou}


## Subtask 7: Tiny overfit sanity check

**Goal:** Overfit a single slice with the INR and an L1 loss to verify that the loss decreases.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
# Overfit the INR on the SDF of one slice; the only goal is to confirm that
# the L1 loss goes down over a handful of optimisation steps.
device = Cfg().device
slice_mask = masks[0].numpy()
sdf_target_np = hard_mask_to_sdf(slice_mask, clip_value=20.0, normalize=True)
sdf_target = torch.from_numpy(sdf_target_np).to(device)

# Pixel-centre coordinates normalised to [-1, 1], stacked as (x, y) pairs.
height, width = sdf_target.shape
ys = torch.linspace(-1, 1, steps=height, device=device)
xs = torch.linspace(-1, 1, steps=width, device=device)
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
coords = torch.stack([grid_x, grid_y], dim=-1)

inr = INRMLP(in_features=2, hidden_features=64, hidden_layers=3, out_features=1).to(device)
optimizer = torch.optim.Adam(inr.parameters(), lr=1e-3)
loss_fn = torch.nn.L1Loss()

loss_history = []
for step in range(60):
    optimizer.zero_grad()
    # NOTE(review): INRMLP is defined elsewhere; with out_features=1 it may
    # return a trailing singleton channel, i.e. (H, W, 1). L1Loss would then
    # broadcast against the (H, W) target instead of comparing element-wise.
    # Reshaping makes the comparison explicit and is a no-op if the output is
    # already (H, W).
    pred = inr(coords).reshape(height, width)
    loss = loss_fn(pred, sdf_target)
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

print(f"Overfit loss: start={loss_history[0]:.4f}, end={loss_history[-1]:.4f}")
assert loss_history[-1] < loss_history[0], "Overfit sanity check failed: loss did not decrease."


## Subtask 8: Baseline vs. SDF regression comparison

**Goal:** Train a hard-mask baseline and an SDF regression model with identical splits/preprocessing, then log metrics and produce plots/tables for the report.

**Files modified/created:** `project/final_project.ipynb`

---


In [ ]:
@dataclass
class CfgSDFRegression(Cfg):
    """Configuration for SDF regression with a single-channel output.

    ``num_classes`` and ``loss_fn`` are forced in ``__post_init__`` so every
    instance regresses one channel with the L1 loss, regardless of what was
    passed to the constructor.
    """

    num_classes: int = 1
    # Kept so class-level access (``CfgSDFRegression.loss_fn``) still works.
    # Without an annotation this is NOT a dataclass field; see __post_init__.
    loss_fn = l1_loss
    class_names: list[str] | None = None

    def __post_init__(self):
        if self.lr_drop_milestones is None:
            self.lr_drop_milestones = [30]
        if self.class_names is None:
            self.class_names = ["BG", "FG"]
        self.num_classes = 1
        # A plain function stored only as a class attribute turns into a bound
        # method on instance access (the config would be passed as the first
        # argument). If the parent declares ``loss_fn`` as a field, the
        # generated __init__ would instead overwrite it with the parent's
        # default. Storing the function on the instance avoids both.
        # NOTE(review): how Cfg/BaseTrainer consume ``loss_fn`` is not visible
        # from here - confirm that the trainer reads it from the config.
        self.loss_fn = l1_loss


def batch_masks_to_sdf(mask_batch, clip_value=20.0, normalize=True):
    """Convert a batch of hard masks into a batch of signed distance fields.

    Each entry along the first axis is converted independently with
    ``hard_mask_to_sdf`` on the CPU; the results are stacked and moved back to
    the device of ``mask_batch`` as a float tensor.

    Args:
        mask_batch: Tensor whose first axis indexes the samples.
        clip_value: Forwarded to ``hard_mask_to_sdf``.
        normalize: Forwarded to ``hard_mask_to_sdf``.
    """
    masks_np = mask_batch.detach().cpu().numpy()
    per_sample = [
        hard_mask_to_sdf(sample, clip_value=clip_value, normalize=normalize)
        for sample in masks_np
    ]
    stacked = np.stack(per_sample, axis=0)
    return torch.from_numpy(stacked).to(mask_batch.device).float()


def binary_iou(pred_mask: torch.Tensor, true_mask: torch.Tensor) -> dict:
    """Return the IoU of the foreground and of the background as Python floats.

    Both inputs are cast to bool. A class that is absent from both masks
    (empty union) scores 1.0.

    Returns:
        ``{"BG": float, "FG": float}``.
    """

    def _jaccard(a, b):
        overlap = torch.logical_and(a, b).sum().float()
        combined = torch.logical_or(a, b).sum().float()
        perfect = torch.tensor(1.0, device=combined.device)
        # The small constant guards the division; the empty-union case is
        # replaced by a perfect score.
        return torch.where(combined == 0, perfect, overlap / (combined + 1e-7))

    pred_fg = pred_mask.bool()
    true_fg = true_mask.bool()

    foreground = _jaccard(pred_fg, true_fg)
    background = _jaccard(~pred_fg, ~true_fg)
    return {"BG": background.item(), "FG": foreground.item()}


class SDFRegressionTrainer(BaseTrainer):
    """Trainer that regresses a signed distance field from the input image.

    SDF targets are built on the fly from the hard masks with
    ``batch_masks_to_sdf``. During validation the predicted SDF is thresholded
    at zero (``<= 0`` counts as foreground) and scored with ``binary_iou``.
    """

    def init_model(self):
        """Create the segmentation network on the configured device."""
        self.model = SimpleSegmentationNet(
            in_channels=self.cfg.in_channels,
            num_classes=self.cfg.num_classes,
        ).to(self.cfg.device)

    def _forward_sdf(self, batch):
        """Run the model on one batch and score it against its SDF target.

        Shared by training and validation so both use the same target
        construction. Returns ``(loss, pred, mask)`` where ``mask`` is the hard
        mask already moved to the configured device.
        """
        img = batch["image"].to(self.cfg.device)
        mask = batch["mask"].to(self.cfg.device)

        # squeeze(1) removes the channel axis when it has size 1.
        pred = self.model(img).squeeze(1)
        sdf_target = batch_masks_to_sdf(mask, clip_value=20.0, normalize=True)
        loss = self.loss_fn(pred, sdf_target)
        return loss, pred, mask

    def pass_batch(self, batch):
        """Return ``(loss, pred)`` for one batch."""
        loss, pred, _ = self._forward_sdf(batch)
        return loss, pred

    def val_epoch(self):
        """Evaluate on the validation loader.

        Returns:
            Dict with ``val_loss`` (mean batch loss), ``jaccard/<class>`` for
            every name in ``cfg.class_names`` (mean of per-sample IoUs) and
            ``mean_iou`` (mean over the classes).
        """
        self.model.eval()
        running_loss = 0.0
        all_ious = []

        # The context manager closes the progress bar even if a batch raises.
        with torch.no_grad(), tqdm(total=len(self.val_loader), position=0, leave=False) as pbar:
            for batch in self.val_loader:
                loss, pred, mask = self._forward_sdf(batch)
                running_loss += loss.item()

                # Inside-negative convention: non-positive SDF is foreground.
                pred_mask = pred <= 0
                true_mask = mask.squeeze(1) > 0
                for i in range(pred_mask.shape[0]):
                    all_ious.append(binary_iou(pred_mask[i], true_mask[i]))

                pbar.set_description(f"[{self.epoch}] Val {loss.item():.4f}")
                pbar.update()

        # max(..., 1) avoids a division by zero on an empty loader.
        mean_val_loss = running_loss / max(len(self.val_loader), 1)
        agg = {c: [] for c in self.cfg.class_names}
        for entry in all_ious:
            for cname, iou in entry.items():
                agg[cname].append(iou)

        mean_iou = {c: float(np.mean(agg[c])) if agg[c] else 0.0 for c in self.cfg.class_names}
        out = {"val_loss": mean_val_loss}
        for cname, val in mean_iou.items():
            out[f"jaccard/{cname}"] = val
        out["mean_iou"] = float(np.mean(list(mean_iou.values()))) if mean_iou else 0.0
        return out


### INR regression trainer
Uses the coordinate-based INR to regress SDF values from pixel coordinates alone; the same predicted field is compared against every slice in a batch.

In [None]:
@dataclass
class CfgINRRegression(Cfg):
    """Configuration for coordinate-based (INR) SDF regression.

    ``num_classes`` and ``loss_fn`` are forced in ``__post_init__`` so every
    instance regresses one channel with the L1 loss, regardless of what was
    passed to the constructor.
    """

    num_classes: int = 1
    # Kept so class-level access (``CfgINRRegression.loss_fn``) still works.
    # Without an annotation this is NOT a dataclass field; see __post_init__.
    loss_fn = l1_loss
    class_names: list[str] | None = None

    def __post_init__(self):
        if self.lr_drop_milestones is None:
            self.lr_drop_milestones = [30]
        if self.class_names is None:
            self.class_names = ["BG", "FG"]
        self.num_classes = 1
        # A plain function stored only as a class attribute turns into a bound
        # method on instance access (the config would be passed as the first
        # argument). If the parent declares ``loss_fn`` as a field, the
        # generated __init__ would instead overwrite it with the parent's
        # default. Storing the function on the instance avoids both.
        # NOTE(review): how Cfg/BaseTrainer consume ``loss_fn`` is not visible
        # from here - confirm that the trainer reads it from the config.
        self.loss_fn = l1_loss


class INRRegressionTrainer(BaseTrainer):
    """Trainer that fits a coordinate MLP (INR) to SDF targets.

    The model only receives pixel coordinates, never the image, so a single
    predicted field is broadcast across the batch and compared with every
    slice's SDF target.
    """

    def __init__(self, train_df: pd.DataFrame, val_df: pd.DataFrame, cfg: Cfg):
        super().__init__(train_df, val_df, cfg)
        # Coordinate grid of the most recently requested (height, width).
        self._coord_cache = None

    def init_model(self):
        """Create the coordinate MLP on the configured device."""
        self.model = INRMLP(in_features=2, hidden_features=128, hidden_layers=4, out_features=1).to(self.cfg.device)

    def _get_coords(self, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return an ``(height, width, 2)`` grid of (x, y) coordinates in [-1, 1].

        The grid is rebuilt only when the requested spatial size changes.
        """
        if self._coord_cache is None or self._coord_cache.shape[:2] != (height, width):
            ys = torch.linspace(-1, 1, steps=height, device=device)
            xs = torch.linspace(-1, 1, steps=width, device=device)
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
            self._coord_cache = torch.stack([grid_x, grid_y], dim=-1)
        return self._coord_cache

    def _predict_for_masks(self, mask):
        """Build SDF targets for ``mask`` and score the INR prediction.

        Shared by training and validation. Returns ``(loss, pred)`` where
        ``pred`` is the single predicted field expanded over the batch axis.
        """
        sdf_target = batch_masks_to_sdf(mask, clip_value=20.0, normalize=True)

        height, width = sdf_target.shape[-2:]
        coords = self._get_coords(height, width, self.cfg.device)
        # NOTE(review): INRMLP is defined elsewhere; with out_features=1 it may
        # return a trailing singleton channel, i.e. (H, W, 1), which the
        # three-argument expand below cannot handle. Reshaping is a no-op when
        # the output is already (H, W).
        field = self.model(coords).reshape(height, width)
        pred = field.unsqueeze(0).expand(sdf_target.shape[0], -1, -1)
        loss = self.loss_fn(pred, sdf_target)
        return loss, pred

    def pass_batch(self, batch):
        """Return ``(loss, pred)`` for one batch; only the masks are used."""
        mask = batch["mask"].to(self.cfg.device)
        return self._predict_for_masks(mask)

    def val_epoch(self):
        """Evaluate on the validation loader.

        Returns:
            Dict with ``val_loss`` (mean batch loss), ``jaccard/<class>`` for
            every name in ``cfg.class_names`` (mean of per-sample IoUs) and
            ``mean_iou`` (mean over the classes).
        """
        self.model.eval()
        running_loss = 0.0
        all_ious = []

        # The context manager closes the progress bar even if a batch raises.
        with torch.no_grad(), tqdm(total=len(self.val_loader), position=0, leave=False) as pbar:
            for batch in self.val_loader:
                mask = batch["mask"].to(self.cfg.device)
                loss, pred = self._predict_for_masks(mask)
                running_loss += loss.item()

                # Inside-negative convention: non-positive SDF is foreground.
                pred_mask = pred <= 0
                true_mask = mask.squeeze(1) > 0
                for i in range(pred_mask.shape[0]):
                    all_ious.append(binary_iou(pred_mask[i], true_mask[i]))

                pbar.set_description(f"[{self.epoch}] Val {loss.item():.4f}")
                pbar.update()

        # max(..., 1) avoids a division by zero on an empty loader.
        mean_val_loss = running_loss / max(len(self.val_loader), 1)
        agg = {c: [] for c in self.cfg.class_names}
        for entry in all_ious:
            for cname, iou in entry.items():
                agg[cname].append(iou)

        mean_iou = {c: float(np.mean(agg[c])) if agg[c] else 0.0 for c in self.cfg.class_names}
        out = {"val_loss": mean_val_loss}
        for cname, val in mean_iou.items():
            out[f"jaccard/{cname}"] = val
        out["mean_iou"] = float(np.mean(list(mean_iou.values()))) if mean_iou else 0.0
        return out


In [ ]:
def run_baseline_vs_sdf(folds, baseline_cfg, sdf_cfg):
    """Train the hard-mask baseline and the SDF regressor on every fold.

    Args:
        folds: Sequence of dicts with the keys ``"fold"``, ``"train_df"`` and
            ``"val_df"``.
        baseline_cfg: Config for ``SegmentationTrainer2D``.
        sdf_cfg: Config for ``SDFRegressionTrainer``.

    Returns:
        DataFrame with one row per (fold, method), holding the last history
        entry returned by the trainer's ``run`` (empty if there is none).
    """
    # Per fold, the baseline is trained first and the SDF model second.
    contenders = (
        ("hard_mask", SegmentationTrainer2D, baseline_cfg),
        ("sdf_regression", SDFRegressionTrainer, sdf_cfg),
    )

    rows = []
    for fold in folds:
        print(f"\n=== Fold {fold['fold'] + 1}/{len(folds)} ===")
        for method, trainer_cls, cfg in contenders:
            trainer = trainer_cls(fold["train_df"], fold["val_df"], cfg)
            history = trainer.run(n_epochs=cfg.num_epochs)
            final_metrics = history[-1] if history else {}
            rows.append({"fold": fold["fold"], "method": method, **final_metrics})

    return pd.DataFrame(rows)


def summarize_comparison(df: pd.DataFrame):
    """Aggregate fold-level metrics into per-method mean and standard deviation.

    Every column except ``fold``, ``method`` and ``epoch`` is treated as a
    metric. The result is indexed by method and has two-level columns
    ``(metric, "mean")`` / ``(metric, "std")``.
    """
    bookkeeping = {"fold", "method", "epoch"}
    metric_cols = [name for name in df.columns if name not in bookkeeping]
    per_method = df.groupby("method")[metric_cols]
    return per_method.agg(["mean", "std"])


def plot_comparison(summary_df, out_path, metrics=("mean_iou", "jaccard/FG", "val_loss")):
    """Draw one bar chart per metric (mean with std error bars) and save it.

    Args:
        summary_df: Frame indexed by method with two-level columns
            ``(metric, "mean")`` and ``(metric, "std")``.
        out_path: Destination handed to ``Figure.savefig``.
        metrics: Names of the metrics to plot, one panel each. A metric whose
            ``(metric, "mean")`` column is missing gets an empty panel.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``metrics`` is empty.
    """
    metrics = list(metrics)
    if not metrics:
        raise ValueError("metrics must contain at least one metric name.")

    # squeeze=False always yields a 2D axes array, so a single metric needs no
    # special handling.
    fig, axes = plt.subplots(1, len(metrics), figsize=(4 * len(metrics), 4), squeeze=False)
    for ax, metric in zip(axes[0], metrics):
        if (metric, "mean") not in summary_df.columns:
            # Keep the layout stable but leave the panel blank.
            ax.axis("off")
            continue
        means = summary_df[(metric, "mean")]
        stds = summary_df[(metric, "std")]
        means.plot(kind="bar", yerr=stds, ax=ax, capsize=4)
        ax.set_title(metric)
        ax.set_xlabel("Method")
        ax.set_ylabel(metric)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    return fig


### Sanity check 5
Run a tiny baseline, SDF regression, and INR regression comparison on a small slice subset to confirm that metrics are logged.

---


In [ ]:
# Tiny end-to-end run: ten rows per split and one epoch per method, only to
# confirm that every trainer produces metrics that can be tabulated.
quick_train_df = reference_fold["train_df"].head(10).reset_index(drop=True)
quick_val_df = reference_fold["val_df"].head(10).reset_index(drop=True)
quick_folds = [{"fold": 0, "train_df": quick_train_df, "val_df": quick_val_df}]

quick_run_kwargs = {"num_epochs": 1, "train_bs": 2, "val_bs": 2}
baseline_cfg = Cfg(**quick_run_kwargs)
sdf_cfg = CfgSDFRegression(**quick_run_kwargs)
inr_cfg = CfgINRRegression(**quick_run_kwargs)

# Baseline and SDF regression go through the shared fold runner; the INR
# trainer is run separately and appended as its own row.
sanity_rows = [run_baseline_vs_sdf(quick_folds, baseline_cfg, sdf_cfg)]
inr_trainer = INRRegressionTrainer(quick_train_df, quick_val_df, inr_cfg)
inr_hist = inr_trainer.run(n_epochs=inr_cfg.num_epochs)
inr_metrics = inr_hist[-1] if inr_hist else {}
inr_row = {"fold": 0, "method": "inr_regression", **inr_metrics}
sanity_rows.append(pd.DataFrame([inr_row]))

sanity_df = pd.concat(sanity_rows, ignore_index=True)
display(sanity_df)

summary = summarize_comparison(sanity_df)
display(summary)

iou_key = ("mean_iou", "mean")
if iou_key in summary.columns:

    def _mean_iou_or_none(method):
        # None marks a method that produced no row in the summary.
        return summary.loc[method, iou_key] if method in summary.index else None

    baseline_iou = _mean_iou_or_none("hard_mask")
    sdf_iou = _mean_iou_or_none("sdf_regression")
    inr_iou = _mean_iou_or_none("inr_regression")
    print(f"Baseline mean IoU: {baseline_iou}")
    print(f"SDF mean IoU: {sdf_iou}")
    print(f"INR mean IoU: {inr_iou}")


### Full baseline vs. SDF regression comparison (report-ready artifacts)
Set the epoch counts below to your desired training length. This section logs fold-level metrics,
computes mean ± std tables, and saves plots/tables for the report.

---


In [ ]:
RUN_FULL_COMPARISON = False  # flip to True when ready for full experiments
BASELINE_EPOCHS = 10
SDF_EPOCHS = 10

if RUN_FULL_COMPARISON:
    baseline_cfg = Cfg(num_epochs=BASELINE_EPOCHS, train_bs=2, val_bs=2)
    sdf_cfg = CfgSDFRegression(num_epochs=SDF_EPOCHS, train_bs=2, val_bs=2)

    comparison_df = run_baseline_vs_sdf(folds, baseline_cfg, sdf_cfg)
    summary_df = summarize_comparison(comparison_df)

    artifacts_dir = Path("project/artifacts")
    artifacts_dir.mkdir(parents=True, exist_ok=True)

    comparison_csv = artifacts_dir / "baseline_vs_sdf_folds.csv"
    summary_csv = artifacts_dir / "baseline_vs_sdf_summary.csv"
    summary_md = artifacts_dir / "baseline_vs_sdf_summary.md"
    plot_path = artifacts_dir / "baseline_vs_sdf_metrics.png"

    # Write the raw fold-level table first so the training results are on
    # disk before any optional export runs.
    comparison_df.to_csv(comparison_csv, index=False)
    summary_df.to_csv(summary_csv)
    try:
        summary_df.to_markdown(summary_md)
    except ImportError as exc:
        # DataFrame.to_markdown relies on the optional `tabulate` package.
        # A missing optional dependency should not abort the cell after a
        # full training run and skip the plot below.
        print(f"Skipping Markdown summary ({summary_md}): {exc}")

    plot_comparison(summary_df, plot_path)
    display(comparison_df)
    display(summary_df)
