# Final Project Notebook

This notebook implements incremental subtasks for the knee cartilage segmentation project.


## Subtask 1: Recreate slice caching and data indexing

**Goal:** Convert 3D volumes into cached 2D slices and build a slice-level index.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
# Basic imports
from pathlib import Path
import random

import numpy as np
import pandas as pd
import nibabel as nib
import cv2

SEED = 42
# Seed Python's and NumPy's global RNGs with the same value.
for seed_rng in (random.seed, np.random.seed):
    seed_rng(SEED)

# Switch off OpenCV's OpenCL backend and set its thread count to 0
# (presumably to avoid clashes with multi-process data loading — confirm).
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


In [None]:
# Configuration
# Point base_dir at the extracted dataset directory in Colab/Drive.
# Volumes are expected to be named <PATIENT>_<VISIT>_<SIDE>_img.nii(.gz)
base_dir = Path("/content/knee_dataset")

# Root of the 2D slice cache, with one sub-folder for images and one for masks
dataset_local_dir = Path("/content/dataset_slices")
images_dir, masks_dir = (dataset_local_dir / sub_dir for sub_dir in ("images", "masks"))
for cache_dir in (images_dir, masks_dir):
    cache_dir.mkdir(parents=True, exist_ok=True)

# CSV file listing every cached slice
slice_index_path = dataset_local_dir / "slice_index.csv"


In [None]:
def vis_slice(img, lp=0, hp=99.9):
    """Normalize a slice to uint8 for PNG storage.

    Intensities are rescaled linearly so that the ``lp``-th percentile maps
    to 0 and the ``hp``-th percentile maps to 255; values outside that window
    are clipped.

    Args:
        img: Array of intensities (any numeric dtype).
        lp: Lower percentile (0-100) mapped to black.
        hp: Upper percentile (0-100) mapped to white.

    Returns:
        ``np.uint8`` array with the same shape as ``img``. A slice with no
        usable dynamic range (constant values, or a NaN percentile) is
        returned as all zeros instead of dividing by zero.
    """
    img_float = img.astype(np.float32)
    low = np.percentile(img_float, lp)
    high = np.percentile(img_float, hp)
    span = high - low
    # `not span > 0` is also true when span is NaN, so both the flat-slice and
    # the NaN case skip the division that would otherwise yield NaN/inf.
    if not span > 0:
        return np.zeros(img_float.shape, dtype=np.uint8)
    img_norm = np.clip((img_float - low) / span, 0, 1)
    return (img_norm * 255).astype(np.uint8)


def orient_slice(slice_2d):
    """Reorient a slice to match the assignment 3 visualizations.

    Applies a 270-degree counter-clockwise rotation (``k=-1`` is the same
    rotation as ``k=3``) followed by a left/right mirror. For a 2D input the
    net effect equals a transpose.
    """
    rotated = np.rot90(slice_2d, k=-1)
    # Reversing the second axis is the left/right flip.
    return rotated[:, ::-1]


In [None]:
# Build volume-level dataframe from the dataset directory
volume_rows = []
# Path.glob yields entries in arbitrary filesystem order; sorting keeps the
# volume order (and everything derived from it) stable across runs.
for img_path in sorted(base_dir.glob("*_img.nii*")):
    name_parts = img_path.name.split("_")
    if len(name_parts) != 4:
        # Same exception type as a failed tuple-unpack, but it names the file.
        raise ValueError(
            f"Unexpected filename {img_path.name!r}; "
            "expected <PATIENT>_<VISIT>_<SIDE>_img.nii(.gz)"
        )
    patient, visit, side, _ = name_parts
    volume_rows.append({
        "img": img_path,
        # The mask sits next to the image, with "_img" swapped for "_mask".
        "segmask": img_path.parent / img_path.name.replace("_img", "_mask"),
        "ID": patient,
        "VISIT": visit,
        "SIDE": side,
    })

# Explicit columns keep the schema intact even when no volume was found.
vol_df = pd.DataFrame(volume_rows, columns=["img", "segmask", "ID", "VISIT", "SIDE"])
vol_df.head()


In [None]:
# Slice caching + indexing (skip if cache already exists)
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp


def process_volume(row_dict, images_dir, masks_dir):
    """Cache every slice along axis 0 of one volume as image/mask PNG pairs.

    Args:
        row_dict: Mapping with keys ``img`` and ``segmask`` (paths to the NIfTI
            image and mask) plus ``ID``, ``VISIT`` and ``SIDE``.
        images_dir: Directory that receives the normalized image PNGs.
        masks_dir: Directory that receives the mask PNGs.

    Returns:
        A list with one dict per slice (``ID``, ``VISIT``, ``SIDE``,
        ``slice_idx``, ``img``, ``segmask``), where the last two are the
        written file paths as strings.

    Raises:
        ValueError: If the image and mask volumes have different shapes.
        OSError: If OpenCV reports that a PNG could not be written.
    """
    images_dir = Path(images_dir)
    masks_dir = Path(masks_dir)

    img_data = nib.load(row_dict["img"]).get_fdata()
    mask_data = nib.load(row_dict["segmask"]).get_fdata()

    # Mismatched volumes would otherwise be cached as misaligned pairs.
    if img_data.shape != mask_data.shape:
        raise ValueError(
            f"Image/mask shape mismatch for {row_dict['img']}: "
            f"{img_data.shape} vs {mask_data.shape}"
        )

    stem = f"{row_dict['ID']}_{row_dict['VISIT']}_{row_dict['SIDE']}"

    records = []
    for slice_idx in range(img_data.shape[0]):
        # orient_slice returns a strided view; hand OpenCV contiguous buffers.
        img_slice = np.ascontiguousarray(orient_slice(vis_slice(img_data[slice_idx, :, :])))
        # Masks are stored as raw uint8 values, not rescaled for display.
        mask_slice = np.ascontiguousarray(
            orient_slice(mask_data[slice_idx, :, :]).astype(np.uint8)
        )

        # Image and mask share a file name and differ only by directory.
        file_name = f"{stem}_slice{slice_idx:03d}.png"
        img_path = images_dir / file_name
        mask_path = masks_dir / file_name

        for out_path, pixels in ((img_path, img_slice), (mask_path, mask_slice)):
            # cv2.imwrite signals failure via its return value; without this
            # check a missing file would still be listed in the index.
            if not cv2.imwrite(str(out_path), pixels):
                raise OSError(f"Failed to write slice PNG: {out_path}")

        records.append({
            "ID": row_dict["ID"],
            "VISIT": row_dict["VISIT"],
            "SIDE": row_dict["SIDE"],
            "slice_idx": slice_idx,
            "img": str(img_path),
            "segmask": str(mask_path),
        })

    return records


if slice_index_path.exists():
    # Read identifiers as strings: otherwise numeric-looking values (e.g. a
    # visit of "00") are parsed as integers and no longer match the values of
    # a freshly built index.
    slice_ds = pd.read_csv(
        slice_index_path,
        dtype={"ID": str, "VISIT": str, "SIDE": str},
    )
else:
    rows = vol_df.to_dict(orient="records")
    all_records = []
    max_workers = min(4, mp.cpu_count())

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_volume, row, images_dir, masks_dir) for row in rows]
        for fut in as_completed(futures):
            # fut.result() re-raises any exception from the worker process.
            all_records.extend(fut.result())

    index_columns = ["ID", "VISIT", "SIDE", "slice_idx", "img", "segmask"]
    # as_completed yields in completion order, which varies between runs;
    # sorting makes the index (and any split derived from it) reproducible.
    slice_ds = (
        pd.DataFrame(all_records, columns=index_columns)
        .sort_values(["ID", "VISIT", "SIDE", "slice_idx"], kind="stable")
        .reset_index(drop=True)
    )
    # Do not persist an empty index: it would be picked up as a valid cache on
    # the next run and hide the fact that no volumes were processed.
    if len(slice_ds) > 0:
        slice_ds.to_csv(slice_index_path, index=False)

slice_ds.head()


### Sanity check 1
Verify cached files exist and the index is populated.

In [None]:
# The index must list at least one slice.
n_cached = len(slice_ds)
assert n_cached > 0, "Slice index is empty."

# Spot-check that the first indexed image/mask pair exists on disk.
example = slice_ds.iloc[0]
assert Path(example.img).exists(), "Cached image not found."
assert Path(example.segmask).exists(), "Cached mask not found."

print(f"Cached {n_cached} slices.")


## Subtask 2: Patient-aware split logic

**Goal:** Create train/val/test splits without patient leakage.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
from sklearn.model_selection import train_test_split


def make_patient_splits(slice_df, test_size=0.2, val_size=0.2, seed=SEED, group_cols=("ID",)):
    """Split a slice index into train/val/test without sharing a patient.

    Slices are grouped by the columns in ``group_cols`` and whole groups are
    assigned to a split. The default groups by ``ID`` only, so both knees and
    all visits of a patient stay in one split; grouping by ``("ID", "SIDE")``
    would let the two knees of one patient land in different splits.

    Args:
        slice_df: Slice-level DataFrame containing the ``group_cols`` columns.
        test_size: Fraction of groups held out for the test split.
        val_size: Fraction of the *remaining* (train+val) groups used for
            validation, i.e. not a fraction of the full dataset.
        seed: Random state passed to both ``train_test_split`` calls.
        group_cols: Column names whose string values, joined by ``"_"``,
            form the grouping key.

    Returns:
        ``(train_df, val_df, test_df, train_keys, val_keys, test_keys)`` where
        the DataFrames have a fresh index and the last three items are sets of
        grouping keys.
    """
    group_cols = list(group_cols)
    # Cast to str so keys compare equal whether the column holds strings or
    # values that a CSV round-trip turned into integers.
    patients = slice_df[group_cols[0]].astype(str)
    for col in group_cols[1:]:
        patients = patients + "_" + slice_df[col].astype(str)

    # Sort the keys so the split depends only on `seed`, not on row order.
    unique_patients = sorted(patients.unique())

    trainval_patients, test_patients = train_test_split(
        unique_patients,
        test_size=test_size,
        random_state=seed,
        shuffle=True,
    )

    train_patients, val_patients = train_test_split(
        trainval_patients,
        test_size=val_size,
        random_state=seed,
        shuffle=True,
    )

    train_df = slice_df[patients.isin(train_patients)].reset_index(drop=True)
    val_df = slice_df[patients.isin(val_patients)].reset_index(drop=True)
    test_df = slice_df[patients.isin(test_patients)].reset_index(drop=True)

    return train_df, val_df, test_df, set(train_patients), set(val_patients), set(test_patients)


train_df, val_df, test_df, train_patients, val_patients, test_patients = make_patient_splits(slice_ds)

# Report group and slice counts per split; labels are padded to align output.
for split_label, split_patients, split_df in (
    ("Train", train_patients, train_df),
    ("Val  ", val_patients, val_df),
    ("Test ", test_patients, test_df),
):
    print(f"{split_label} patients: {len(split_patients)} → {len(split_df)} slices")


### Sanity check 2
Ensure no patient overlap across splits.

In [None]:
# Every pair of splits must have no grouping key in common.
for keys_a, keys_b, problem in (
    (train_patients, val_patients, "Train/val patient leakage detected."),
    (train_patients, test_patients, "Train/test patient leakage detected."),
    (val_patients, test_patients, "Val/test patient leakage detected."),
):
    assert keys_a.isdisjoint(keys_b), problem

print("Patient splits are disjoint.")


## Subtask 3: Minimal batch sanity check

**Goal:** Load a batch and confirm shapes + mask alignment.

**Files modified/created:** `project/final_project.ipynb`


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


class SliceDataset(Dataset):
    """Dataset over cached slice PNGs listed in a slice-level DataFrame.

    Each item is a dict with:
        ``image``: float tensor of shape (3, H, W), scaled by 1/255.
        ``mask``: long tensor of shape (H, W) with the stored mask values.
    """

    def __init__(self, df):
        """Store ``df`` (needs ``img`` and ``segmask`` path columns) with a fresh 0..N-1 index."""
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of slices."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        row = self.df.iloc[idx]

        # cv2.imread returns None instead of raising when a file is missing or
        # unreadable; fail here with the offending path rather than later.
        img = cv2.imread(str(row.img))
        if img is None:
            raise FileNotFoundError(f"Could not read image: {row.img}")
        mask = cv2.imread(str(row.segmask), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {row.segmask}")

        # Default imread gives a 3-channel BGR array; convert to RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # HWC uint8 -> CHW float
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()
        return {"image": img, "mask": mask}


train_ds = SliceDataset(train_df)
# Seed the shuffle so a fresh "Restart & Run All" draws the same sanity batch;
# torch's RNG is not otherwise seeded in this cell.
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(SEED),
)

# Pull a single batch and compare image and mask spatial sizes.
batch = next(iter(train_loader))
images = batch["image"]
masks = batch["mask"]

print("Image batch shape:", images.shape)
print("Mask batch shape:", masks.shape)
assert images.shape[-2:] == masks.shape[-2:], "Image/mask spatial dimensions do not match."


### Sanity check 3
Visualize a single image/mask pair for alignment.

In [None]:
# First sample of the batch: image back to HWC for matplotlib, mask as-is.
img_np = images[0].permute(1, 2, 0).numpy()
mask_np = masks[0].numpy()

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
panels = (
    (img_np, "Image", {}),
    (mask_np, "Mask", {"cmap": "viridis"}),
)
# Draw the image and mask side by side with identical axis treatment.
for ax, (pixels, panel_title, imshow_kwargs) in zip(axes, panels):
    ax.imshow(pixels, **imshow_kwargs)
    ax.set_title(panel_title)
    ax.axis("off")

plt.tight_layout()
plt.show()
