In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
#
# Downloads the 'matinmo/cataract-101' Kaggle dataset with kagglehub and keeps the
# returned local path in `matinmo_cataract_101_path`.
# NOTE(review): no other cell reads `matinmo_cataract_101_path`; the configuration
# cell defines DATA_ROOT on its own, so the data must be where DATA_ROOT points.
import kagglehub
matinmo_cataract_101_path = kagglehub.dataset_download('matinmo/cataract-101')

print('Data source import complete.')


# Cataract-101 Surgical Phase Recognition Prototype

This notebook builds a lightweight clip-based classifier for surgical phase recognition on the Cataract-101 dataset. A pretrained ResNet50 backbone encodes each frame of a short clip, the per-frame features are averaged over time, and a linear head classifies the clip into a surgical phase. To keep experiments fast on Kaggle, only a configurable subset of the videos is sampled.


In [None]:
#cell 2
import os
import math
import random
import copy
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
import pandas as pd
from PIL import Image
from collections import Counter

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from einops import rearrange

from sklearn.metrics import confusion_matrix, classification_report

import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from IPython.display import display

# Configuration
# Dataset and output locations default to the Kaggle mounts. To run elsewhere, set the
# CATARACT101_ROOT environment variable to the folder that contains videos/,
# phases.csv, videos.csv and annotations.csv, and CATARACT101_OUTPUT to a writable
# directory for checkpoints.
DATA_ROOT = Path(os.environ.get("CATARACT101_ROOT", "/kaggle/input/cataract-101/cataract-101"))
VIDEOS_DIR = DATA_ROOT / "videos"
PHASE_FILE = DATA_ROOT / "phases.csv"
VIDEO_META_FILE = DATA_ROOT / "videos.csv"
ANNOTATION_FILE = DATA_ROOT / "annotations.csv"
OUTPUT_DIR = Path(os.environ.get("CATARACT101_OUTPUT", "/kaggle/working/output"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Reproducibility and video sub-sampling (keeps Kaggle runs short).
SEED = 1337
SAMPLE_VIDEOS = 8          # max training videos; a falsy value also removes the validation cap
SAMPLE_FRACTION = None     # e.g. 0.3 keeps 30% of the training split; overrides TRAIN_VIDEO_LIMIT
TRAIN_VIDEO_LIMIT = SAMPLE_VIDEOS
VAL_VIDEO_LIMIT = max(2, SAMPLE_VIDEOS // 4) if SAMPLE_VIDEOS else None

# Clip construction.
CLIP_LEN = 16
FRAME_STEP = 2            # temporal stride between frames inside a clip
CLIP_STRIDE = max(1, CLIP_LEN // 2)  # stride between consecutive clips
MAX_CLIPS_PER_VIDEO = None            # optionally cap number of clips per video
FRAME_SAMPLING_STEP = 1              # expand annotations into dense frame indices

# Optimisation.
BATCH_SIZE = 2
VAL_BATCH_SIZE = 2
NUM_WORKERS = 2
NUM_EPOCHS = 3
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
GRAD_ACCUM_STEPS = 1
USE_AMP = torch.cuda.is_available()   # mixed precision only makes sense on CUDA
FREEZE_BACKBONE = False
TRAINABLE_BACKBONE_LAYERS = ("layer4",)  # stay trainable when FREEZE_BACKBONE is True
DROPOUT = 0.2

# Inference / checkpointing.
INFERENCE_BATCH_SIZE = 4
INFERENCE_CLIP_STRIDE = CLIP_STRIDE
SAVE_BEST_MODEL_PATH = OUTPUT_DIR / "best_clip_resnet50.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pd.options.display.max_columns = 50


def set_seed(seed: int) -> None:
    """Seed every RNG the notebook touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU generator and all
    CUDA generators, then disables cuDNN autotuning so kernel choice is repeatable.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(SEED)
# Echo the resolved runtime settings, one per line.
print(
    f"Using device: {device}",
    f"Data root: {DATA_ROOT}",
    f"Videos directory: {VIDEOS_DIR}",
    sep="\n",
)


In [None]:
#cell 3
# Phase vocabulary: raw phase IDs -> contiguous class indices and "<id>: <meaning>" names.
phase_df = pd.read_csv(PHASE_FILE, sep=';').astype({"Phase": int, "Meaning": str})
phase_df["PhaseName"] = phase_df["Phase"].astype(str) + ": " + phase_df["Meaning"]
phase_to_name = dict(zip(phase_df["Phase"], phase_df["PhaseName"]))
phase_to_idx = {int(pid): class_idx for class_idx, pid in enumerate(sorted(phase_to_name))}
PHASE_NAMES = [phase_to_name[pid] for pid in phase_to_idx]
IDX_TO_PHASE = dict(enumerate(PHASE_NAMES))
print(f"Detected {len(PHASE_NAMES)} phases: {PHASE_NAMES}")
display(phase_df)

# Per-video metadata; unparseable FPS values fall back to 25.
video_meta_df = (
    pd.read_csv(VIDEO_META_FILE, sep=';')
    .astype({"VideoID": int, "Frames": int, "Surgeon": int, "Experience": int})
    .assign(FPS=lambda tbl: pd.to_numeric(tbl["FPS"], errors='coerce').fillna(25.0))
)
print(f"Video metadata entries: {len(video_meta_df)}")
display(video_meta_df.head())

# Phase-transition annotations, mapped onto class indices; rows with unknown phases are dropped.
annotation_df = (
    pd.read_csv(ANNOTATION_FILE, sep=';')
    .astype({"VideoID": int, "FrameNo": int, "Phase": int})
    .assign(PhaseIdx=lambda tbl: tbl["Phase"].map(phase_to_idx))
)
if annotation_df["PhaseIdx"].isna().any():
    missing_rows = annotation_df.loc[annotation_df["PhaseIdx"].isna()].head()
    print("[WARN] Some annotations reference unknown phases; previewing the first few rows:")
    display(missing_rows)
    annotation_df = annotation_df.dropna(subset=["PhaseIdx"]).copy()
annotation_df["PhaseIdx"] = annotation_df["PhaseIdx"].astype(int)
print(f"Annotation rows: {len(annotation_df):,}")
display(annotation_df.head())

phase_counts = annotation_df["PhaseIdx"].value_counts().sort_index()
phase_count_df = pd.DataFrame({
    "phase": [IDX_TO_PHASE[class_idx] for class_idx in phase_counts.index],
    "count": phase_counts.values,
})
display(phase_count_df)

# Cross-check annotations against the media files and metadata on disk.
video_files = sorted(VIDEOS_DIR.glob("case_*.mp4"))
print(f"Discovered {len(video_files)} video files under {VIDEOS_DIR}.")
if video_files:
    preview_df = pd.DataFrame(
        {"video_path": [str(path.relative_to(DATA_ROOT)) for path in video_files[:5]]}
    )
    display(preview_df)

missing_meta = sorted(set(annotation_df["VideoID"]).difference(video_meta_df["VideoID"]))
missing_files = [
    vid
    for vid in annotation_df["VideoID"].unique()
    if not (VIDEOS_DIR / f"case_{vid}.mp4").exists()
]
print(f"Videos in annotations without metadata: {len(missing_meta)}")
print(f"Videos with missing files: {len(missing_files)}")


In [None]:
#cell 4
@dataclass
class VideoRecord:
    """One annotated surgery video, as assembled by ``build_video_records``."""

    video_id: str  # string form of the numeric VideoID; used as the lookup key in ClipDataset.video_map
    video_numeric_id: int  # VideoID as an int; records are sorted by it
    media_path: Path  # <videos_dir>/case_<VideoID>.mp4
    total_frames: int  # frame count taken from the metadata "Frames" column
    fps: float  # metadata FPS, or 25.0 when missing
    frame_numbers: np.ndarray  # ascending frame indices expanded from the phase segments
    phase_indices: np.ndarray  # class index for each entry of frame_numbers (same length)
    segments: List[Tuple[int, int, int]]  # (start_frame, end_frame, phase_idx); end_frame is inclusive
    annotation: pd.DataFrame  # this video's annotation rows, sorted by FrameNo


@dataclass
class ClipSample:
    """Index entry for one clip enumerated by ``ClipDataset``."""

    video_id: str  # key into ClipDataset.video_map (matches VideoRecord.video_id)
    frame_indices: np.ndarray  # absolute frame numbers to decode for this clip
    label: int  # majority phase index over the clip's frames


def build_video_records(
    annotations: pd.DataFrame,
    video_meta: pd.DataFrame,
    videos_dir: Path,
    frame_sampling_step: int = 1,
) -> List[VideoRecord]:
    """Build one ``VideoRecord`` per annotated video that has metadata and a media file.

    ``annotations`` holds phase *transitions*: each row says that ``PhaseIdx`` starts at
    ``FrameNo``. Consecutive transitions become inclusive ``(start, end, phase_idx)``
    segments that run up to the next transition; the last one runs to ``Frames - 1``.
    Frames before the first transition inherit its phase. Each segment is then
    expanded into frame indices spaced ``frame_sampling_step`` apart.

    Args:
        annotations: rows with ``VideoID``, ``FrameNo`` and ``PhaseIdx`` columns.
        video_meta: rows with ``VideoID``, ``Frames`` and optionally ``FPS`` columns.
        videos_dir: directory containing ``case_<VideoID>.mp4`` files.
        frame_sampling_step: spacing between expanded frame indices; values below 1
            are treated as 1.

    Returns:
        Records sorted by numeric video ID. A video is skipped, with a ``[WARN]`` line,
        if it has no metadata, no media file, a non-positive frame count, or no
        transitions inside the video.
    """
    step = max(1, int(frame_sampling_step))
    records: List[VideoRecord] = []
    meta_index: Dict[int, pd.Series] = {int(row.VideoID): row for _, row in video_meta.iterrows()}
    for video_id, group in annotations.groupby("VideoID"):
        video_numeric = int(video_id)
        meta = meta_index.get(video_numeric)
        if meta is None:
            print(f"[WARN] VideoID {video_numeric} has annotations but no metadata; skipping.")
            continue
        media_path = videos_dir / f"case_{video_numeric}.mp4"
        if not media_path.exists():
            print(f"[WARN] Media file {media_path} not found; skipping video {video_numeric}.")
            continue
        total_frames = int(meta["Frames"])
        if total_frames <= 0:
            print(f"[WARN] Video {video_numeric} reports {total_frames} frames; skipping.")
            continue
        raw_fps = meta.get("FPS", np.nan)
        fps = 25.0 if pd.isna(raw_fps) else float(raw_fps)

        # Stable sort keeps the file order among rows that share a FrameNo.
        group_sorted = group.sort_values("FrameNo", kind="mergesort").reset_index(drop=True)
        # Several transitions on the same frame would create duplicate frame entries;
        # the last one listed wins. Transitions at/after the final frame label nothing
        # and would otherwise yield frame indices outside the video.
        transitions = group_sorted.drop_duplicates(subset="FrameNo", keep="last")
        beyond_end = transitions["FrameNo"] >= total_frames
        if beyond_end.any():
            print(
                f"[WARN] Video {video_numeric}: ignoring {int(beyond_end.sum())} annotation(s) "
                f"at or after frame {total_frames}."
            )
            transitions = transitions[~beyond_end]
        start_frames = transitions["FrameNo"].to_numpy(dtype=np.int32)
        phase_indices = transitions["PhaseIdx"].to_numpy(dtype=np.int32)
        if start_frames.size == 0:
            print(f"[WARN] Video {video_numeric} has no frame transitions; skipping.")
            continue
        if start_frames[0] > 0:
            # Back-fill the untagged lead-in with the first annotated phase.
            start_frames = np.insert(start_frames, 0, 0)
            phase_indices = np.insert(phase_indices, 0, phase_indices[0])
        # Each segment ends one frame before the next transition; the last ends at the final frame.
        segment_ends = np.append(start_frames[1:] - 1, total_frames - 1)

        segments: List[Tuple[int, int, int]] = []
        sampled_frames: List[np.ndarray] = []
        sampled_labels: List[np.ndarray] = []
        for start, end, phase_idx in zip(start_frames, segment_ends, phase_indices):
            start = int(max(0, start))
            end = int(min(end, total_frames - 1))
            if end < start:
                end = start
            segments.append((start, end, int(phase_idx)))
            indices = np.arange(start, end + 1, step, dtype=np.int32)
            if indices.size == 0:
                indices = np.array([start], dtype=np.int32)
            sampled_frames.append(indices)
            sampled_labels.append(np.full(indices.shape, int(phase_idx), dtype=np.int32))
        frame_numbers = np.concatenate(sampled_frames)
        label_array = np.concatenate(sampled_labels)
        order = np.argsort(frame_numbers, kind="stable")
        records.append(
            VideoRecord(
                video_id=str(video_numeric),
                video_numeric_id=video_numeric,
                media_path=media_path,
                total_frames=total_frames,
                fps=fps,
                frame_numbers=frame_numbers[order],
                phase_indices=label_array[order],
                segments=segments,
                annotation=group_sorted.copy(),
            )
        )
    records.sort(key=lambda rec: rec.video_numeric_id)
    print(f"Prepared metadata for {len(records)} videos.")
    return records


def subset_records(
    records: Sequence[VideoRecord],
    limit: Optional[int] = None,
    fraction: Optional[float] = None,
    seed: int = SEED,
) -> List[VideoRecord]:
    """Pick a reproducible random subset of ``records``, returned in their original order.

    ``fraction`` (when given) takes precedence over ``limit`` and always keeps at least
    one record. When both are ``None``, every record is returned.
    """
    if not records:
        return []
    total = len(records)
    order = list(range(total))
    random.Random(seed).shuffle(order)
    keep = limit if fraction is None else max(1, int(total * fraction))
    chosen = order if keep is None else order[:keep]
    return [records[pos] for pos in sorted(chosen)]


def split_train_val(
    records: Sequence[VideoRecord],
    val_ratio: float = 0.2,
    seed: int = SEED,
) -> Tuple[List[VideoRecord], List[VideoRecord]]:
    """Shuffle videos reproducibly and split them into ``(train, val)`` lists.

    At least one video goes to validation. If that would leave training empty
    (e.g. a single video), everything is given to training instead.
    """
    pool = list(records)
    if not pool:
        return [], []
    random.Random(seed).shuffle(pool)
    if len(pool) > 1:
        n_val = max(1, int(len(pool) * val_ratio))
    else:
        n_val = 1
    val_part, train_part = pool[:n_val], pool[n_val:]
    if not train_part:
        train_part, val_part = val_part, train_part
    print(f"Split into {len(train_part)} train and {len(val_part)} validation videos.")
    return train_part, val_part


class ClipDataset(Dataset):
    """Map-style dataset of fixed-length clips cut from annotated videos.

    Each item is ``(clip, label)``. ``clip`` is a ``(T, C, H, W)`` tensor of ``clip_len``
    transformed frames, and ``label`` is the majority phase index over those frames.

    Args:
        records: videos to cut clips from.
        transform: per-frame transform applied to a PIL RGB image; must return a tensor.
        clip_len: number of frames per clip.
        frame_step: step, in entries of ``record.frame_numbers``, between a clip's frames.
        clip_stride: step, in entries of ``record.frame_numbers``, between clip starts.
        max_clips_per_video: optional cap on clips per video (``None`` or 0 = no cap).
    """

    # Forward gaps of up to this many frames are decoded and discarded with grab()
    # instead of seeking. A CAP_PROP_POS_FRAMES seek can force the decoder to restart
    # from an earlier keyframe, so small gaps are usually cheaper to decode through.
    _MAX_GRAB_GAP = 32

    def __init__(
        self,
        records: Sequence[VideoRecord],
        transform: transforms.Compose,
        clip_len: int = 16,
        frame_step: int = 1,
        clip_stride: int = 8,
        max_clips_per_video: Optional[int] = None,
    ) -> None:
        self.records = list(records)
        self.transform = transform
        self.clip_len = clip_len
        self.frame_step = max(1, frame_step)
        self.clip_stride = max(1, clip_stride)
        self.max_clips_per_video = max_clips_per_video
        self.samples: List[ClipSample] = []
        self.video_map: Dict[str, VideoRecord] = {rec.video_id: rec for rec in self.records}
        self._build_index()

    def _build_index(self) -> None:
        """Enumerate clip windows for every record into ``self.samples``.

        Windows are positioned over entries of ``record.frame_numbers``, not over raw
        frame numbers. Starts advance by ``clip_stride``, and the last valid start is
        appended so the tail of each video is covered.
        """
        for rec in self.records:
            frames = rec.frame_numbers
            labels = rec.phase_indices
            if len(frames) < self.clip_len:
                continue
            # One past the last valid start position.
            max_start = len(frames) - (self.clip_len - 1) * self.frame_step
            if max_start <= 0:
                continue
            start_positions = list(range(0, max_start, self.clip_stride))
            final_candidate = max_start - 1
            if start_positions:
                if final_candidate > start_positions[-1]:
                    start_positions.append(final_candidate)
            else:
                start_positions = [0]
            clip_count = 0
            for start in start_positions:
                stop = start + self.clip_len * self.frame_step
                indices = frames[start:stop:self.frame_step]
                if len(indices) < self.clip_len:
                    continue
                label_slice = labels[start:stop:self.frame_step]
                if len(label_slice) < self.clip_len:
                    continue
                majority_label = Counter(label_slice).most_common(1)[0][0]
                self.samples.append(
                    ClipSample(
                        video_id=rec.video_id,
                        frame_indices=indices.astype(np.int64),
                        label=int(majority_label),
                    )
                )
                clip_count += 1
                if self.max_clips_per_video and clip_count >= self.max_clips_per_video:
                    break
        print(f"Built {len(self.samples)} clip samples from {len(self.records)} videos.")

    def __len__(self) -> int:
        return len(self.samples)

    def _load_clip_tensor(self, rec: VideoRecord, frame_indices: np.ndarray) -> torch.Tensor:
        """Decode ``frame_indices`` from ``rec.media_path`` and stack them into ``(T, C, H, W)``.

        Short forward gaps between requested frames are bridged with ``grab()``; any
        other jump seeks via ``CAP_PROP_POS_FRAMES``. If decoding stops early, the last
        decoded frame is repeated to keep the clip length. If nothing decodes,
        ``RuntimeError`` is raised. The capture is released even if ``transform`` raises.
        """
        cap = cv2.VideoCapture(str(rec.media_path))
        if not cap.isOpened():
            raise RuntimeError(f"Unable to open video at {rec.media_path}")
        frames: List[torch.Tensor] = []
        next_pos: Optional[int] = None  # frame the decoder will return next, when known
        try:
            for raw_idx in frame_indices:
                idx = int(raw_idx)
                gap = None if next_pos is None else idx - next_pos
                if gap is not None and 0 <= gap <= self._MAX_GRAB_GAP:
                    for _ in range(gap):
                        cap.grab()
                else:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                success, frame = cap.read()
                if not success:
                    break
                next_pos = idx + 1
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(Image.fromarray(frame_rgb)))
        finally:
            cap.release()
        if len(frames) != len(frame_indices):
            if frames:
                while len(frames) < len(frame_indices):
                    frames.append(frames[-1])
            else:
                raise RuntimeError(f"Failed to read any frames for video {rec.video_id}")
        clip_tensor = torch.stack(frames, dim=0)  # (T, C, H, W)
        return clip_tensor

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        sample = self.samples[idx]
        record = self.video_map[sample.video_id]
        clip_tensor = self._load_clip_tensor(record, sample.frame_indices)
        return clip_tensor, sample.label


def build_ground_truth_timeline(record: VideoRecord) -> np.ndarray:
    """Expand ``record.segments`` into one phase index per frame (length ``total_frames``).

    Segment ends are inclusive and clipped to the video. Frames before the first
    labelled frame take its label, and frames after the last labelled frame take that
    one's. Interior gaps (if any) stay ``-1``.
    """
    n_frames = record.total_frames
    timeline = np.full(n_frames, fill_value=-1, dtype=np.int32)
    last_frame = n_frames - 1
    for seg_start, seg_end, phase in record.segments:
        lo = int(max(0, seg_start))
        hi = int(min(seg_end, last_frame))
        timeline[lo : hi + 1] = phase
    labelled = np.flatnonzero(timeline != -1)
    if labelled.size and labelled.size < n_frames:
        first, last = labelled[0], labelled[-1]
        timeline[:first] = timeline[first]
        timeline[last + 1 :] = timeline[last]
    return timeline


In [None]:
# cell 5
video_records = build_video_records(
    annotations=annotation_df,
    video_meta=video_meta_df,
    videos_dir=VIDEOS_DIR,
    frame_sampling_step=FRAME_SAMPLING_STEP,
)

# Split by video (not by clip) so no surgery feeds both sides, then sub-sample each split.
train_records, val_records = split_train_val(video_records, val_ratio=0.2, seed=SEED)
train_records = subset_records(train_records, limit=TRAIN_VIDEO_LIMIT, fraction=SAMPLE_FRACTION, seed=SEED)
val_records = subset_records(val_records, limit=VAL_VIDEO_LIMIT, fraction=None, seed=SEED)

print(f"Detected phases ({len(PHASE_NAMES)}): {PHASE_NAMES}")
print(f"Using {len(train_records)} videos for training and {len(val_records)} for validation.")
print(f"Train videos (sample): {[rec.video_id for rec in train_records[:5]]}")
print(f"Validation videos (sample): {[rec.video_id for rec in val_records[:5]]}")

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Both pipelines first squash frames to 256x256. An aspect-preserving Resize(256) at
# validation time would show the model differently proportioned frames, and for
# non-square frames a differently cropped field of view, than it saw in training.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

train_dataset = ClipDataset(
    records=train_records,
    transform=train_transform,
    clip_len=CLIP_LEN,
    frame_step=FRAME_STEP,
    clip_stride=CLIP_STRIDE,
    max_clips_per_video=MAX_CLIPS_PER_VIDEO,
)

val_dataset = ClipDataset(
    records=val_records,
    transform=val_transform,
    clip_len=CLIP_LEN,
    frame_step=FRAME_STEP,
    clip_stride=CLIP_STRIDE,
    max_clips_per_video=MAX_CLIPS_PER_VIDEO,
)

print(f"Train clips: {len(train_dataset)}, Validation clips: {len(val_dataset)}")

# Pinned host memory only helps (and only avoids a warning) when CUDA is present.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)

if len(train_dataset) > 0:
    sample_clip, sample_label = train_dataset[0]
    print(f"Sample clip shape: {sample_clip.shape}, label index: {sample_label} ({PHASE_NAMES[sample_label]})")
else:
    print("Warning: training dataset is empty; adjust sampling parameters or verify annotations.")


In [None]:
# cell 6
class ClipResNet(nn.Module):
    """Clip classifier: shared ResNet50 frame encoder, temporal mean pooling, linear head.

    Input clips have shape ``(B, T, C, H, W)``; the output is ``(B, num_classes)`` logits.
    With ``freeze_backbone`` set, every backbone parameter is frozen except those in the
    ResNet submodules named by ``trainable_layers``.
    """

    def __init__(
        self,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.2,
        freeze_backbone: bool = False,
        trainable_layers: Sequence[str] = ("layer4",),
    ) -> None:
        super().__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)
        if freeze_backbone:
            resnet.requires_grad_(False)
            for layer_name in trainable_layers:
                unfrozen = getattr(resnet, layer_name, None)
                if unfrozen is not None:
                    unfrozen.requires_grad_(True)
        self.feature_dim = resnet.fc.in_features
        # Every child except the final fc layer, so the output is pooled features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        self.head = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 5:
            raise ValueError(f"Expected input shape (B, T, C, H, W) but received {tuple(x.shape)}")
        batch, steps = x.shape[:2]
        # Fold time into the batch so the 2-D backbone sees independent frames.
        frame_batch = rearrange(x, "b t c h w -> (b t) c h w")
        frame_feats = self.backbone(frame_batch).flatten(start_dim=1)
        frame_feats = rearrange(frame_feats, "(b t) d -> b t d", b=batch, t=steps)
        return self.head(self.dropout(frame_feats.mean(dim=1)))


model = ClipResNet(
    num_classes=len(PHASE_NAMES),
    pretrained=True,
    dropout=DROPOUT,
    freeze_backbone=FREEZE_BACKBONE,
    trainable_layers=TRAINABLE_BACKBONE_LAYERS,
).to(device)

criterion = nn.CrossEntropyLoss()
# Only parameters that are still trainable (see FREEZE_BACKBONE) are optimised.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, NUM_EPOCHS))
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# The full ResNet50 repr runs to hundreds of lines; print a compact summary instead.
print(f"Model: ClipResNet with {len(PHASE_NAMES)} output classes on {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


In [None]:
#cell 7
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
    grad_accum_steps: int = 1,
) -> Dict[str, float]:
    """Train ``model`` for one pass over ``dataloader``.

    Gradients are accumulated over ``grad_accum_steps`` batches per optimizer step. A
    trailing partial group at the end of the epoch is still applied instead of being
    thrown away by the next ``zero_grad``. Mixed precision is used only when ``scaler``
    is provided and enabled.

    Args:
        model: network mapping a clip batch to class logits.
        dataloader: yields ``(clips, targets)`` batches.
        criterion: loss function; assumed to average over the batch (the reported loss
            is re-weighted by batch size on that assumption).
        optimizer: optimizer over the trainable parameters.
        device: device the batches are moved to.
        epoch: epoch number, used only in the progress-bar label.
        scaler: optional AMP gradient scaler.
        grad_accum_steps: batches per optimizer step (values below 1 act as 1).

    Returns:
        ``{"loss": mean per-sample training loss, "acc": fraction of correct clips}``.
    """
    import contextlib

    model.train()
    accum_steps = max(1, int(grad_accum_steps))
    use_amp = scaler is not None and scaler.is_enabled()
    device_type = torch.device(device).type
    num_batches = len(dataloader)
    loss_sum = 0.0
    correct = 0
    seen = 0
    optimizer.zero_grad()
    pbar = tqdm(dataloader, desc=f"Epoch {epoch} [train]", leave=False)
    for step, (clips, targets) in enumerate(pbar):
        clips = clips.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        autocast_ctx = torch.autocast(device_type=device_type) if use_amp else contextlib.nullcontext()
        with autocast_ctx:
            outputs = model(clips)
            loss = criterion(outputs, targets)
        # Scale so the accumulated gradient matches one step over the combined batches.
        scaled_loss = loss / accum_steps
        if use_amp:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % accum_steps == 0 or is_last_batch:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        seen += batch_size
        pbar.set_postfix(loss=loss_sum / max(1, seen), acc=correct / max(1, seen))
    return {
        "loss": loss_sum / max(1, seen),
        "acc": correct / max(1, seen),
    }


@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, np.ndarray]:
    model.eval()
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    all_preds: List[torch.Tensor] = []
    all_targets: List[torch.Tensor] = []
    for clips, targets in tqdm(dataloader, desc="Validation", leave=False):
        clips = clips.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=False):
            outputs = model(clips)
            loss = criterion(outputs, targets)
        running_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_correct += (preds == targets).sum().item()
        running_total += targets.size(0)
        all_preds.append(preds.cpu())
        all_targets.append(targets.cpu())
    if all_preds:
        preds_tensor = torch.cat(all_preds).numpy()
        targets_tensor = torch.cat(all_targets).numpy()
    else:
        preds_tensor = np.array([])
        targets_tensor = np.array([])
    return {
        "loss": running_loss / max(1, len(dataloader)),
        "acc": running_correct / max(1, running_total),
        "preds": preds_tensor,
        "targets": targets_tensor,
    }


In [None]:
#cell 8
if len(train_dataset) == 0 or len(val_dataset) == 0:
    raise RuntimeError("Training or validation dataset is empty. Adjust sampling parameters or verify the metadata before proceeding.")

history: List[Dict[str, float]] = []
# The best weights are kept as a detached CPU copy. This avoids doubling GPU memory, and
# the saved checkpoint loads on CPU-only machines. None means "no epoch completed".
best_state: Optional[Dict[str, torch.Tensor]] = None
best_metrics = None
best_acc = -math.inf

for epoch in range(1, NUM_EPOCHS + 1):
    train_metrics = train_one_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        device,
        epoch=epoch,
        scaler=scaler,
        grad_accum_steps=GRAD_ACCUM_STEPS,
    )
    val_metrics = evaluate(model, val_loader, criterion, device)
    scheduler.step()
    history.append({
        "epoch": epoch,
        "train_loss": train_metrics["loss"],
        "train_acc": train_metrics["acc"],
        "val_loss": val_metrics["loss"],
        "val_acc": val_metrics["acc"],
    })
    print(
        f"Epoch {epoch}: "
        f"train_loss={train_metrics['loss']:.4f}, train_acc={train_metrics['acc']:.3f}, "
        f"val_loss={val_metrics['loss']:.4f}, val_acc={val_metrics['acc']:.3f}"
    )
    current_acc = val_metrics["acc"]
    # ">=" lets a later epoch with tied accuracy replace an earlier one.
    if current_acc >= best_acc:
        best_acc = current_acc
        best_state = {
            name: value.detach().cpu().clone() if torch.is_tensor(value) else copy.deepcopy(value)
            for name, value in model.state_dict().items()
        }
        best_metrics = val_metrics
        torch.save(best_state, SAVE_BEST_MODEL_PATH)
        print(f"  -> New best model saved to {SAVE_BEST_MODEL_PATH}")

history_df = pd.DataFrame(history)
display(history_df)

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Loaded best model with validation accuracy {best_acc:.3f}")
else:
    print("Warning: best model state was not captured; using final epoch weights.")


In [None]:
#cell 9
def load_clip_from_video(
    video_path: Path,
    frame_indices: Sequence[int],
    transform: transforms.Compose,
    max_grab_gap: int = 32,
) -> torch.Tensor:
    """Decode ``frame_indices`` from ``video_path`` and stack them as ``(T, C, H, W)``.

    Each frame is converted BGR -> RGB, wrapped as a PIL image and passed through
    ``transform``. Forward gaps of at most ``max_grab_gap`` frames are bridged with
    ``grab()`` instead of a ``CAP_PROP_POS_FRAMES`` seek; a seek can make the decoder
    restart from an earlier keyframe. If decoding stops early, the last decoded frame
    is repeated to keep the clip length.

    Raises:
        RuntimeError: if the video cannot be opened or no requested frame decodes.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Unable to open {video_path}")
    frames: List[torch.Tensor] = []
    next_pos: Optional[int] = None  # frame the decoder will return next, when known
    try:
        for raw_idx in frame_indices:
            idx = int(raw_idx)
            gap = None if next_pos is None else idx - next_pos
            if gap is not None and 0 <= gap <= max_grab_gap:
                for _ in range(gap):
                    cap.grab()
            else:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            success, frame = cap.read()
            if not success:
                break
            next_pos = idx + 1
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(transform(Image.fromarray(frame_rgb)))
    finally:
        # Release even if the transform raises, so file handles are not leaked.
        cap.release()
    if len(frames) != len(frame_indices):
        if frames:
            while len(frames) < len(frame_indices):
                frames.append(frames[-1])
        else:
            raise RuntimeError(f"Failed to read requested frames from {video_path}")
    return torch.stack(frames, dim=0)


def _clip_start_positions(total_frames: int, clip_span: int, clip_stride: int) -> List[int]:
    """Sliding-window start frames; one extra window is aligned to the end of the video."""
    if total_frames < clip_span:
        return [0]
    last_start = total_frames - clip_span
    starts = list(range(0, last_start + 1, max(1, clip_stride)))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def _fill_uncovered_frames(frame_probs: np.ndarray, covered: np.ndarray) -> np.ndarray:
    """Copy each uncovered frame's row from the nearest earlier covered frame (else the first covered one)."""
    if covered.all() or not covered.any():
        return frame_probs
    positions = np.arange(covered.size)
    source = np.maximum.accumulate(np.where(covered, positions, -1))
    source[source < 0] = positions[covered][0]
    return frame_probs[source]


@torch.no_grad()
def predict_video_timeline(
    model: nn.Module,
    record: VideoRecord,
    transform: transforms.Compose,
    clip_len: int,
    frame_step: int,
    clip_stride: int,
    device: torch.device,
    batch_size: int = 4,
) -> Dict[str, np.ndarray]:
    """Predict a per-frame phase timeline for one video with sliding-window clips.

    Clips of ``clip_len`` frames (``frame_step`` apart) start every ``clip_stride`` raw
    frames, plus one final clip aligned to the end of the video. Each clip's softmax
    probabilities are credited to every frame in its temporal span, including frames
    between the sampled ones. Per-frame probabilities are the mean over the clips
    covering that frame. Frames no clip covers (possible when ``clip_stride`` exceeds
    the span) copy the nearest earlier covered frame.

    The frame count and FPS come from the decoder when available, else from ``record``.

    Returns:
        Dict with keys:
        - ``frame_labels``: argmax of ``frame_probs``.
        - ``frame_confidence``: max of ``frame_probs``, in [0, 1].
        - ``frame_probs``: mean per-frame probabilities.
        - ``frame_votes``: summed probabilities.
        - ``clip_probs``, ``clip_preds``, ``clip_starts``, ``fps``, ``total_frames``.

    Raises:
        RuntimeError: if the frame count cannot be determined or no clip is produced.
    """
    model.eval()
    video_path = record.media_path
    total_frames = record.total_frames
    fps = record.fps if record.fps else 25.0
    cap = cv2.VideoCapture(str(video_path))
    try:
        if cap.isOpened():
            total_frames_cap = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps_cap = cap.get(cv2.CAP_PROP_FPS)
            if total_frames_cap > 0:
                total_frames = total_frames_cap
            if fps_cap and fps_cap > 0:
                fps = fps_cap
    finally:
        cap.release()
    if total_frames <= 0:
        raise RuntimeError(f"Unable to determine total frame count for {video_path}")

    clip_span = (clip_len - 1) * frame_step + 1
    clip_starts = _clip_start_positions(total_frames, clip_span, clip_stride)
    clip_indices = [
        list(range(start, min(total_frames, start + clip_span), frame_step)) for start in clip_starts
    ]

    effective_batch = max(1, batch_size)
    logits_batches: List[torch.Tensor] = []
    buffer: List[torch.Tensor] = []
    for indices in clip_indices:
        buffer.append(load_clip_from_video(video_path, indices, transform))
        if len(buffer) == effective_batch:
            logits_batches.append(model(torch.stack(buffer, dim=0).to(device, non_blocking=True)).cpu())
            buffer.clear()
    if buffer:
        logits_batches.append(model(torch.stack(buffer, dim=0).to(device, non_blocking=True)).cpu())
        buffer.clear()

    if not logits_batches:
        raise RuntimeError("No clips were generated for inference.")
    logits = torch.cat(logits_batches, dim=0)
    probs = torch.softmax(logits, dim=1).numpy()
    preds = probs.argmax(axis=1)

    # Credit each clip to its whole window. Voting only on the sampled frames would
    # leave frames between samples (e.g. every odd frame when frame_step=2) with zero
    # votes, so they would be labelled class 0 with zero confidence.
    frame_votes = np.zeros((total_frames, probs.shape[1]), dtype=np.float32)
    vote_counts = np.zeros(total_frames, dtype=np.int64)
    for start, prob_vec in zip(clip_starts, probs):
        stop = min(total_frames, start + clip_span)
        frame_votes[start:stop] += prob_vec
        vote_counts[start:stop] += 1
    covered = vote_counts > 0
    frame_probs = np.zeros_like(frame_votes)
    frame_probs[covered] = frame_votes[covered] / vote_counts[covered][:, None]
    frame_probs = _fill_uncovered_frames(frame_probs, covered)
    frame_labels = frame_probs.argmax(axis=1)
    frame_confidence = frame_probs.max(axis=1)

    return {
        "frame_labels": frame_labels,
        "frame_confidence": frame_confidence,
        "frame_probs": frame_probs,
        "frame_votes": frame_votes,
        "clip_probs": probs,
        "clip_preds": preds,
        "clip_starts": np.array(clip_starts),
        "fps": fps,
        "total_frames": total_frames,
    }


# Timeline demo: first validation video, falling back to the first training video.
timeline_outputs = None
if val_records:
    sample_record = val_records[0]
elif train_records:
    sample_record = train_records[0]
else:
    sample_record = None

if sample_record is not None:
    timeline_outputs = predict_video_timeline(
        model=model,
        record=sample_record,
        transform=val_transform,
        clip_len=CLIP_LEN,
        frame_step=FRAME_STEP,
        clip_stride=INFERENCE_CLIP_STRIDE,
        device=device,
        batch_size=INFERENCE_BATCH_SIZE,
    )
    print(f"Inference complete for video '{sample_record.video_id}' with {timeline_outputs['total_frames']} frames.")
else:
    print("No video records available for inference demo.")


In [None]:
#cell 10
class_ids = list(range(len(PHASE_NAMES)))

if best_metrics is not None and best_metrics["preds"].size > 0:
    cm = confusion_matrix(
        best_metrics["targets"],
        best_metrics["preds"],
        labels=class_ids,
    )
    cm_df = pd.DataFrame(cm, index=PHASE_NAMES, columns=PHASE_NAMES)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues", ax=ax)
    ax.set(title="Validation Confusion Matrix", ylabel="True Phase", xlabel="Predicted Phase")
    fig.tight_layout()
    plt.show()

    # `labels` must be passed alongside `target_names`; otherwise sklearn raises when a
    # phase never appears in the validation targets or predictions.
    report = classification_report(
        best_metrics["targets"],
        best_metrics["preds"],
        labels=class_ids,
        target_names=PHASE_NAMES,
        digits=3,
        zero_division=0,
    )
    print(report)
else:
    print("Confusion matrix is unavailable; validation predictions were empty.")

if sample_record is not None and timeline_outputs is not None:
    total_frames = int(timeline_outputs["total_frames"])
    pred_labels = timeline_outputs["frame_labels"]
    pred_conf = timeline_outputs["frame_confidence"]
    gt_timeline = build_ground_truth_timeline(sample_record)
    # Predictions use the decoder's frame count while the ground truth uses the
    # metadata count; align the GT to the prediction length so both share one x-axis.
    if gt_timeline.size >= total_frames:
        gt_timeline = gt_timeline[:total_frames]
    else:
        pad_value = gt_timeline[-1] if gt_timeline.size else -1
        gt_timeline = np.concatenate([
            gt_timeline,
            np.full(total_frames - gt_timeline.size, pad_value, dtype=gt_timeline.dtype),
        ])
    frame_axis = np.arange(total_frames)
    time_axis = frame_axis / max(1e-6, timeline_outputs["fps"])

    fig, ax = plt.subplots(figsize=(16, 4))
    ax.step(time_axis, gt_timeline, where="post", label="Ground truth", linewidth=2, alpha=0.8)
    ax.step(time_axis, pred_labels, where="post", label="Predicted", linewidth=1.5, alpha=0.8)
    ax.set_yticks(class_ids)
    ax.set_yticklabels(PHASE_NAMES)
    ax.set(xlabel="Time (s)", ylabel="Phase", title=f"Surgical phase timeline — {sample_record.video_id}")
    ax.legend(loc="upper right")
    fig.tight_layout()
    plt.show()

    # Confidence values are not guaranteed to stay within [0, 1] (they may be summed
    # clip votes), so widen the y-range when needed instead of clipping the curve.
    conf_top = max(1.05, float(np.max(pred_conf)) * 1.05) if len(pred_conf) else 1.05
    fig, ax = plt.subplots(figsize=(16, 2))
    ax.plot(time_axis, pred_conf, label="Prediction confidence")
    ax.set(xlabel="Time (s)", ylabel="Confidence", ylim=(0, conf_top), title="Frame-level confidence")
    fig.tight_layout()
    plt.show()
else:
    print("Timeline visualization is unavailable; ensure inference ran successfully.")
