# Final Project Notebook
This notebook implements the project pipeline in the intended execution order: **data → targets → model → loss/metrics → training → evaluation**. Each section below adds the corresponding building blocks, starting with data loading and augmentations and ending with a one-epoch training run and an evaluation placeholder.


## Imports & Seeds

In [None]:
import random
from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import solt
import solt.transforms as slt
from tqdm.auto import tqdm

# One seed shared by every random number generator the notebook touches.
SEED = 42

# Seed the PyTorch, NumPy and Python RNGs with the same value.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Disable OpenCV's OpenCL path and set its internal thread count to 0.
# NOTE(review): presumably done to avoid clashes with DataLoader workers — confirm.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


## Data Loading

In [None]:
class KneeSegmentationDataset(Dataset):
    """Per-slice segmentation dataset backed by a pandas frame.

    Every row must expose an ``img`` and a ``segmask`` attribute holding the
    paths of the slice image and of its label mask. The optional SOLT stream
    is applied to the image/mask pair jointly.
    """

    def __init__(self, df: pd.DataFrame, transforms: solt.Stream | None):
        self.trf = transforms
        # Re-index so that positional access in __getitem__ matches 0..len-1.
        self.dataset = df.reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict:
        """Load, transform and tensorize the slice at position ``idx``.

        Returns:
            A dict with ``"image"`` (float32, scaled by 1/255, channel axis
            moved to the front) and ``"mask"`` (int64 with a leading singleton
            axis).

        Raises:
            FileNotFoundError: if OpenCV could not read the image or the mask.
        """
        row = self.dataset.iloc[idx]
        image = cv2.imread(str(row.img), cv2.IMREAD_COLOR)
        labels = cv2.imread(str(row.segmask), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if any(arr is None for arr in (image, labels)):
            raise FileNotFoundError(f"Missing image/mask for index {idx}")

        if self.trf is not None:
            # return_torch=False keeps the outputs as NumPy arrays.
            image, labels = self.trf(
                {"image": image, "mask": labels},
                return_torch=False,
            ).data

        image_t = torch.from_numpy(image.astype(np.float32) / 255.0).permute(2, 0, 1)
        labels_t = torch.from_numpy(labels.astype(np.int64)).unsqueeze(0)
        return {"image": image_t, "mask": labels_t}


class KneeSegmentationDatasetVol(Dataset):
    """Scan-level dataset: one item is the ordered stack of a scan's slices.

    Rows are grouped by ``(ID, SIDE, VISIT)`` and ordered by ``slice_idx``
    inside each group.
    """

    def __init__(self, df: pd.DataFrame, transforms: solt.Stream | None, load_mask: bool = True):
        super().__init__()
        self.df = df
        self.trf = transforms
        self.load_mask = load_mask

        # One sorted, re-indexed frame per scan.
        self.scans = [
            group.sort_values(by="slice_idx").reset_index(drop=True)
            for _, group in self.df.groupby(["ID", "SIDE", "VISIT"])
        ]

    def __len__(self) -> int:
        return len(self.scans)

    def __getitem__(self, index: int) -> dict:
        volume_df = self.scans[index]
        return self.load_volume(volume_df, self.trf, load_mask=self.load_mask)

    @staticmethod
    def _to_chw_tensor(img: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert one slice image to a float tensor with channels first.

        NumPy input is always divided by 255. Tensor input is divided by 255
        only when its maximum exceeds 1.0. A 3-D input whose last axis has
        size 1 or 3 is treated as channel-last and permuted.

        Raises:
            TypeError: if ``img`` is neither a NumPy array nor a tensor.
        """
        is_array = isinstance(img, np.ndarray)
        if not is_array and not torch.is_tensor(img):
            raise TypeError("Unsupported image type for volume loading")

        out = torch.from_numpy(img) if is_array else img.float()
        if out.dim() == 3 and out.shape[-1] in (1, 3):
            out = out.permute(2, 0, 1)

        if is_array:
            return out.float() / 255.0
        return out / 255.0 if out.max() > 1.0 else out

    @staticmethod
    def _to_hw_long(mask: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert one slice mask to an int64 tensor, dropping a leading singleton axis.

        Raises:
            TypeError: if ``mask`` is neither a NumPy array nor a tensor.
        """
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask)
        if not torch.is_tensor(mask):
            raise TypeError("Unsupported mask type for volume loading")

        has_unit_lead = mask.dim() == 3 and mask.size(0) == 1
        return (mask.squeeze(0) if has_unit_lead else mask).long()

    @staticmethod
    def load_volume(vol_df: pd.DataFrame, transform: solt.Stream | None, load_mask: bool = True) -> dict:
        """Read every slice listed in ``vol_df`` and stack them along a new first axis.

        Args:
            vol_df: rows of one scan; each row exposes ``img`` and, when
                ``load_mask`` is set, ``segmask``.
            transform: optional SOLT stream applied to the whole list of
                slices at once.
            load_mask: whether masks are read and returned.

        Returns:
            ``{"scan": tensor, "mask": tensor or None}``; ``"mask"`` is
            ``None`` when ``load_mask`` is false.

        Raises:
            FileNotFoundError: if OpenCV could not read an image or a mask.
        """
        slice_images, slice_masks = [], []
        for _, row in vol_df.iterrows():
            image = cv2.imread(str(row.img), cv2.IMREAD_COLOR)
            if image is None:
                raise FileNotFoundError(f"Missing image: {row.img}")
            slice_images.append(image)

            if not load_mask:
                continue
            labels = cv2.imread(str(row.segmask), cv2.IMREAD_GRAYSCALE)
            if labels is None:
                raise FileNotFoundError(f"Missing mask: {row.segmask}")
            slice_masks.append(labels)

        if transform is not None:
            payload = {"images": slice_images}
            if load_mask:
                payload["masks"] = slice_masks
            # NOTE(review): the stream is called with its default output mode,
            # so the results may be tensors rather than arrays; the converters
            # above accept both.
            transformed = transform(payload)
            slice_images = transformed["images"]
            if load_mask:
                slice_masks = transformed["masks"]

        to_image = KneeSegmentationDatasetVol._to_chw_tensor
        to_labels = KneeSegmentationDatasetVol._to_hw_long

        scan = torch.stack([to_image(item) for item in slice_images], dim=0)
        mask = torch.stack([to_labels(item) for item in slice_masks], dim=0) if load_mask else None
        return {"scan": scan, "mask": mask}


## Augmentation

In [None]:
# Training pipeline: resize to 256x256, flip along axis 1 with probability 0.5,
# crop to 224x224 and apply gamma correction on every sample (p=1).
# NOTE(review): crop_mode="r" is presumably solt's random-crop mode and
# gamma_range=0.1 presumably samples gamma around 1 — confirm against the solt docs.
train_trf = solt.Stream([
    slt.Resize((256, 256)),
    slt.Flip(p=0.5, axis=1),
    slt.Crop((224, 224), crop_mode="r"),
    slt.GammaCorrection(gamma_range=0.1, p=1),
])

# Validation pipeline: resize only, no flip, crop or gamma change.
val_trf = solt.Stream([
    slt.Resize((256, 256)),
])


## Config

In [None]:
@dataclass
class Cfg:
    """Experiment configuration consumed by the trainers in this notebook.

    Only the annotated attributes are dataclass fields (and therefore
    constructor arguments). The un-annotated ones (``dataset_train_cls``,
    ``dataset_val_cls``, ``optimizer_cls``, ``loss_fn``) are plain class
    attributes shared by all instances.

    A subclass that wants a different default for a *field* must itself be
    decorated with ``@dataclass`` and re-annotate that field; otherwise the
    inherited ``__init__`` keeps assigning the default declared here.
    """

    # Evaluated once, when the class body is executed.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    in_channels: int = 3
    # Always overwritten in __post_init__ with len(class_names).
    num_classes: int = 2

    # Dataset classes instantiated as cls(dataframe, transforms) by the trainer.
    dataset_train_cls = KneeSegmentationDataset
    dataset_val_cls = KneeSegmentationDataset

    # Optimizer class (not instance) and a single shared loss-module instance.
    optimizer_cls = torch.optim.AdamW
    loss_fn = torch.nn.CrossEntropyLoss()

    lr: float = 1e-3
    wd: float = 5e-5
    num_epochs: int = 1
    n_workers: int = 0
    train_bs: int = 2
    val_bs: int = 2

    # None is a sentinel replaced in __post_init__, which avoids a shared
    # mutable default list.
    # NOTE(review): the annotations omit "| None" although None is the default.
    lr_drop_milestones: list[int] = None
    class_names: list[str] = None

    # Default to the module-level pipelines defined in the Augmentation cell.
    train_trf: solt.Stream | None = train_trf
    val_trf: solt.Stream | None = val_trf

    def __post_init__(self):
        """Fill in list defaults and derive ``num_classes`` from ``class_names``."""
        if self.lr_drop_milestones is None:
            self.lr_drop_milestones = [30]
        if self.class_names is None:
            self.class_names = ["BG", "FG"]
        # A num_classes value passed to the constructor is discarded here.
        self.num_classes = len(self.class_names)


@dataclass
class CfgHighRes(Cfg):
    """Configuration variant that skips the resize and crops to 320x320.

    ``train_trf`` and ``val_trf`` are dataclass fields of ``Cfg``. The subclass
    must be a dataclass and annotate them so the generated ``__init__`` uses
    these defaults; with a plain class-attribute override the inherited
    ``__init__`` would re-assign the base pipelines on every instance.
    """

    train_trf: solt.Stream | None = solt.Stream([
        slt.Flip(p=0.5, axis=1),
        slt.Crop((320, 320), crop_mode="r"),
        slt.GammaCorrection(gamma_range=0.1, p=1),
    ])

    # Empty stream: validation slices are passed through without any transform.
    val_trf: solt.Stream | None = solt.Stream([])


@dataclass
class CfgVol(Cfg):
    """Configuration variant that validates on whole scans, one scan per batch.

    ``val_bs`` is a dataclass field of ``Cfg``, so it has to be re-declared as
    an annotated field of a dataclass subclass. A bare ``val_bs = 1`` class
    attribute would be shadowed by the inherited ``__init__``, which assigns
    the base default of 2 to every instance.
    """

    # Not a dataclass field in Cfg (no annotation there), so a plain
    # class-attribute override is sufficient.
    dataset_val_cls = KneeSegmentationDatasetVol
    val_bs: int = 1


## Trainer

In [None]:
class SimpleSegmentationNet(nn.Module):
    """Tiny fully-convolutional network producing per-pixel class logits.

    Two 3x3 convolutions (16 and 32 channels, padding 1, each followed by an
    in-place ReLU) are followed by a 1x1 convolution with ``num_classes``
    output channels.
    """

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        layers = []
        prev_width = in_channels
        for width in (16, 32):
            layers.append(nn.Conv2d(prev_width, width, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            prev_width = width
        # 1x1 projection from the last hidden width to the class logits.
        layers.append(nn.Conv2d(prev_width, num_classes, kernel_size=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the convolutional stack for ``x``."""
        logits = self.net(x)
        return logits


class BaseTrainer:
    """Skeleton of a train/validate loop driven by a ``Cfg`` instance.

    Subclasses must implement ``init_model`` (assigning ``self.model``) and
    ``val_epoch`` (returning a dict that contains at least ``"val_loss"``).
    """

    def __init__(self, train_df: pd.DataFrame, val_df: pd.DataFrame, cfg: Cfg):
        self.train_df = train_df
        self.val_df = val_df
        self.cfg = cfg

        # Populated by init_run().
        self.train_loader = None
        self.val_loader = None
        self.loss_fn = None
        self.optimizer = None
        self.model = None

        # Defined up front so that train_epoch / adjust_lr / post_val_hook do
        # not raise AttributeError when they are called outside run().
        self.epoch = 0

    def init_model(self):
        """Create ``self.model``; must be provided by subclasses."""
        raise NotImplementedError

    def init_run(self):
        """Build the model, both data loaders, the optimizer and the loss."""
        self.init_model()

        train_ds = self.cfg.dataset_train_cls(self.train_df, self.cfg.train_trf)
        val_ds = self.cfg.dataset_val_cls(self.val_df, self.cfg.val_trf)

        # Page-locked host memory only helps host-to-GPU copies; requesting it
        # on a CPU-only run just triggers a DataLoader warning.
        pin_memory = torch.device(self.cfg.device).type == "cuda"

        self.train_loader = DataLoader(
            train_ds,
            batch_size=self.cfg.train_bs,
            shuffle=True,
            num_workers=self.cfg.n_workers,
            pin_memory=pin_memory,
        )
        self.val_loader = DataLoader(
            val_ds,
            batch_size=self.cfg.val_bs,
            shuffle=False,
            num_workers=self.cfg.n_workers,
            pin_memory=pin_memory,
        )

        self.optimizer = self.cfg.optimizer_cls(
            self.model.parameters(),
            lr=self.cfg.lr,
            weight_decay=self.cfg.wd,
        )
        self.loss_fn = self.cfg.loss_fn

    def adjust_lr(self):
        """Multiply every param group's learning rate by 0.1 on milestone epochs."""
        if self.cfg.lr_drop_milestones and self.epoch in self.cfg.lr_drop_milestones:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] *= 0.1

    def run(self, n_epochs: int | None = None):
        """Initialise the run, then alternate training and validation epochs.

        Args:
            n_epochs: number of epochs; falls back to ``cfg.num_epochs`` when
                ``None``.
        """
        # Everything (model, loaders, optimizer) is rebuilt on every call.
        self.init_run()
        if n_epochs is None:
            n_epochs = self.cfg.num_epochs

        # The loop variable is the instance attribute itself, so the hooks
        # below can read the current epoch.
        for self.epoch in range(n_epochs):
            self.model.train()
            train_loss = self.train_epoch()
            self.model.eval()
            val_out = self.val_epoch()
            self.post_val_hook(train_loss, val_out)

    def train_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            The mean per-batch training loss (0.0 for an empty loader).
        """
        pbar = tqdm(self.train_loader, desc=f"[{self.epoch}] Train")
        running_loss = 0.0
        self.adjust_lr()

        for i, batch in enumerate(pbar):
            self.optimizer.zero_grad()
            loss, _ = self.pass_batch(batch)
            loss.backward()
            self.optimizer.step()

            running_loss += float(loss.item())
            pbar.set_postfix({"loss": running_loss / (i + 1)})

        # max(1, ...) guards against division by zero on an empty loader.
        return running_loss / max(1, len(self.train_loader))

    def val_epoch(self):
        """Run validation; must be provided by subclasses."""
        raise NotImplementedError

    def post_val_hook(self, train_loss, val_out):
        """Print the epoch's training loss and ``val_out["val_loss"]``."""
        print("=" * 50)
        print(f"[{self.epoch}] --> Train loss: {train_loss:.4f}")
        print(f"[{self.epoch}] --> Val loss: {val_out['val_loss']:.4f}")
        print("=" * 50)

    def pass_batch(self, batch):
        """Forward one ``{"image", "mask"}`` batch and compute the loss.

        Returns:
            ``(loss, logits)``; ``loss`` is a tensor, still attached to the
            graph when gradients are enabled.
        """
        img = batch["image"].to(self.cfg.device)
        mask = batch["mask"].to(self.cfg.device)

        logits = self.model(img)
        # Drop axis 1 of the mask before handing it to the loss as targets.
        target = mask.squeeze(1).long()
        loss = self.loss_fn(logits, target)
        return loss, logits


class SegmentationTrainer2D(BaseTrainer):
    """Trainer that fits ``SimpleSegmentationNet`` on slice batches."""

    def init_model(self):
        """Instantiate the network from the config and move it to ``cfg.device``."""
        net = SimpleSegmentationNet(
            in_channels=self.cfg.in_channels,
            num_classes=self.cfg.num_classes,
        )
        self.model = net.to(self.cfg.device)

    @torch.no_grad()
    def val_epoch(self):
        """Evaluate the loss over the validation loader without gradients.

        Returns:
            ``{"val_loss": mean per-batch loss}`` (0.0 for an empty loader).
        """
        loss_sum = 0.0
        progress = tqdm(self.val_loader, desc=f"[{self.epoch}] Val", leave=False)
        for n_seen, batch in enumerate(progress, start=1):
            batch_loss, _ = self.pass_batch(batch)
            loss_sum += float(batch_loss.item())
            progress.set_postfix({"loss": loss_sum / n_seen})

        n_batches = max(1, len(self.val_loader))
        return {"val_loss": loss_sum / n_batches}


class SegmentationTrainer3D(SegmentationTrainer2D):
    """2D trainer that can additionally validate on whole scans.

    Volume batches (dicts with a ``"scan"`` key) are evaluated slice by slice
    with the 2D model. Slice batches are handled exactly as in the parent
    class.
    """

    @torch.no_grad()
    def val_epoch(self):
        """Evaluate the loss over the validation loader without gradients.

        The inherited implementation indexes ``batch["image"]``, which volume
        batches do not have, so each batch is dispatched on its keys here.

        Returns:
            ``{"val_loss": mean per-batch loss}`` (0.0 for an empty loader).
        """
        running_loss = 0.0
        pbar = tqdm(self.val_loader, desc=f"[{self.epoch}] Val", leave=False)
        for i, batch in enumerate(pbar):
            if "scan" in batch:
                batch_loss, _, _ = self.pass_eval_batch(batch, compute_loss=True)
            else:
                loss, _ = self.pass_batch(batch)
                batch_loss = float(loss.item())
            running_loss += batch_loss
            pbar.set_postfix({"loss": running_loss / (i + 1)})
        return {"val_loss": running_loss / max(1, len(self.val_loader))}

    # Evaluation only: without no_grad every slice's graph would be kept alive
    # through the writes into ``logits_vol``.
    @torch.no_grad()
    def pass_eval_batch(self, batch, compute_loss: bool = False):
        """Run the 2D model over every slice of a batch of scans.

        Args:
            batch: dict with ``"scan"`` (unpacked as batch, slices, channels,
                height, width) and optionally ``"mask"``.
            compute_loss: when true, average the per-slice loss over slices.

        Returns:
            ``(loss, logits_vol, target)``. ``loss`` is a float, or ``None``
            when ``compute_loss`` is false; it is ``0.0`` when a loss was
            requested but the batch carries no mask. ``logits_vol`` has shape
            ``(batch, num_classes, slices, height, width)``. ``target`` is the
            mask moved to ``cfg.device``, or ``None``.
        """
        scan = batch["scan"].to(self.cfg.device)
        target = batch.get("mask")
        if target is not None:
            target = target.to(self.cfg.device)

        batch_size, n_slices, _, height, width = scan.shape
        logits_vol = torch.zeros(
            batch_size, self.cfg.num_classes, n_slices, height, width,
            device=self.cfg.device,
        )

        total_loss = 0.0
        for slice_idx in range(n_slices):
            x_s = scan[:, slice_idx, ...]
            logits_s = self.model(x_s)
            logits_vol[:, :, slice_idx, :, :] = logits_s

            if compute_loss and target is not None:
                y_s = target[:, slice_idx, ...]
                loss_s = self.loss_fn(logits_s, y_s)
                total_loss += float(loss_s.item())

        loss_i = total_loss / max(1, n_slices) if compute_loss else None
        return loss_i, logits_vol, target


## Sanity Check

In [None]:
# Write a handful of random image/mask pairs to disk so the dataset, the model
# and the training loop can be exercised without the real data.
tmp_root = Path("/tmp/seg_sanity")
tmp_root.mkdir(parents=True, exist_ok=True)

samples = []
for idx in range(4):
    img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    mask = (np.random.rand(256, 256) > 0.5).astype(np.uint8)
    img_path = tmp_root / f"img_{idx}.png"
    mask_path = tmp_root / f"mask_{idx}.png"
    cv2.imwrite(str(img_path), img)
    # Store the class indices (0/1) as-is. The dataset feeds the pixel values
    # straight to CrossEntropyLoss as targets, so a 0/255 mask would produce
    # the out-of-range label 255 for a two-class model.
    cv2.imwrite(str(mask_path), mask)
    samples.append({
        "img": img_path,
        "segmask": mask_path,
        "ID": 0,
        "SIDE": "L",
        "VISIT": 0,
        "slice_idx": idx,
    })

# All four slices share one (ID, SIDE, VISIT) key, i.e. they form a single scan.
df = pd.DataFrame(samples)
ds = KneeSegmentationDataset(df, val_trf)
sample = ds[0]
print("Image shape:", sample["image"].shape)
print("Mask shape:", sample["mask"].shape)

# Forward one slice through an untrained model to check the output shape.
model = SimpleSegmentationNet(in_channels=3, num_classes=2).to(Cfg().device)
with torch.no_grad():
    logits = model(sample["image"].unsqueeze(0).to(Cfg().device))
print("Logits shape:", logits.shape)


## Training Loop

In [None]:
# Smoke-test the training loop for a single epoch, passing the same frame `df`
# as both the training and the validation split.
cfg = Cfg()
trainer = SegmentationTrainer2D(df, df, cfg)
trainer.run(n_epochs=1)


## Evaluation / Visualization

In [None]:
# TODO: Add prediction visualization and slice-wise evaluation helpers.
# Example: overlay predictions on input slices or render scan-level summaries.
