# Full MS-TCN Prototype Notebook

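Contains cells for preprocessing (building phase segments and per-frame label arrays), ResNet feature extraction from the videos, the dataset, the model (a single-stage, MS-TCN-style dilated temporal convolutional network), training, and evaluation.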

**Paths**: change `BASE_DIR` and `WORK_DIR` in the config cell to match your environment (Kaggle defaults provided).

## Requirements

Install packages if missing:

```bash
pip install torch torchvision numpy pandas scikit-learn pyyaml opencv-python tqdm nbformat
```


In [None]:
# === Configuration ===
import os

import torch

BASE_DIR = "/kaggle/input/cataract-101/cataract-101"   # where annotations.csv and videos/ live (Kaggle)
WORK_DIR = "/kaggle/working/cataract-101-generated"   # where we will write segments, features, outputs
FPS = 25                        # used to convert time-based annotations into frame indices
FEATURE_BACKBONE = "resnet50"   # resnet18 or resnet50
PRETRAINED_BACKBONE = True
BATCH_FRAME = 64                # frames per batch for feature extraction
IMG_SIZE = 224                  # frames are resized to IMG_SIZE x IMG_SIZE before the backbone
SEQ_LEN = 200                   # clip length (frames) for training and inference windows
SAMPLE_VIDEOS = 30              # None to use all videos
NUM_CLASSES = None              # set later after reading phases / annotations
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PRINT_DEBUG = True

# Output tree: WORK_DIR itself plus the sub-folders the later cells write into.
os.makedirs(WORK_DIR, exist_ok=True)
for _subdir in ("features", "labels_npy", "checkpoints"):
    os.makedirs(os.path.join(WORK_DIR, _subdir), exist_ok=True)
print('Config:', BASE_DIR, '->', WORK_DIR, 'device=', DEVICE)

In [None]:
# Imports
import os, sys, math, time, shutil, json, random
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
print('imports ok')

In [None]:
# === STEP 1: Build segments_filled.csv and per-video label arrays ===
from pathlib import Path


def _read_csv_any(path):
    """Read a CSV file that may be comma- or semicolon-delimited.

    A semicolon-delimited file parsed with the default comma separator comes
    back as a single column whose header still contains ';'. In that case the
    file is re-read with ``sep=';'``; any other file is returned as parsed.
    """
    df = pd.read_csv(path)
    if df.shape[1] == 1 and ';' in str(df.columns[0]):
        df = pd.read_csv(path, sep=';')
    return df


ann_path = Path(BASE_DIR) / "annotations.csv"
phases_path = Path(BASE_DIR) / "phases.csv"
videos_dir = Path(BASE_DIR) / "videos"
if not ann_path.exists():
    raise FileNotFoundError(f"annotations.csv not found at {ann_path}")

ann = _read_csv_any(ann_path)
print('annotations columns:', list(ann.columns))

# load phases: phase name (as str) -> integer id
phase_name_to_id = {}
if phases_path.exists():
    phases_df = _read_csv_any(phases_path)
    if phases_df.shape[1] == 2:
        # assumes the first column holds the id and the second the phase name
        id_col, name_col = phases_df.columns[0], phases_df.columns[1]
        for _, r in phases_df.iterrows():
            phase_name_to_id[str(r[name_col])] = int(r[id_col])
    else:
        # otherwise: the last column is the name and the row index label is the id
        for i, r in phases_df.iterrows():
            phase_name_to_id[str(r.iloc[-1])] = i
    print('Loaded phases mapping from phases.csv:', phase_name_to_id)
else:
    print('No phases.csv found; will infer phase ids from annotations.')

def find_col(df, keys):
    """Return the first column of ``df`` whose lower-cased name contains one of ``keys``.

    Keys are tried in priority order, and for each key the columns are scanned
    left to right; the first hit wins. Returns None when nothing matches.
    """
    matches = (col for key in keys for col in df.columns if key in col.lower())
    return next(matches, None)

col_video = find_col(ann, ['video','video_id','videoid','vid','file','filename'])
col_phase = find_col(ann, ['phase','label','phase_id','phaseid','class','action','phase_name'])
col_start = find_col(ann, ['start_frame','startframe','start_f','start'])
col_end = find_col(ann, ['end_frame','endframe','end_f','end'])
col_frame = find_col(ann, ['frame','frame_id','frameid'])
col_start_time = find_col(ann, ['start_time','startsec','start_s','startseconds'])
col_end_time = find_col(ann, ['end_time','endsec','end_s','endseconds'])

print('Detected columns:', dict(video=col_video, phase=col_phase, start=col_start, end=col_end, frame=col_frame,
                               start_time=col_start_time, end_time=col_end_time))

if col_video is None:
    raise ValueError('Could not detect a video id column in annotations.csv. Inspect columns and adapt script.')


def _is_time_column(name):
    """Return True if a column name suggests its values are seconds rather than frame indices."""
    lowered = str(name).lower()
    return 'time' in lowered or 'sec' in lowered


def _to_frame(value, in_seconds):
    """Convert one start/end annotation value to an integer frame index, or None if unparsable.

    'H:M:S' / 'M:S' strings are always interpreted as times and multiplied by FPS.
    Numeric values are multiplied by FPS when ``in_seconds`` is true, and are
    otherwise rounded and used directly as frame indices.
    """
    if pd.isnull(value):
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        text = str(value).strip()
        if ':' not in text:
            return None
        try:
            parts = [float(p) for p in text.split(':')]
        except ValueError:
            return None
        secs = 0.0
        for part in parts:  # H:M:S -> (H*60 + M)*60 + S ; M:S -> M*60 + S
            secs = secs * 60 + part
        return int(round(secs * FPS))
    if not np.isfinite(number):
        return None
    return int(round(number * FPS)) if in_seconds else int(round(number))


def _resolve_phase_id(value):
    """Map a raw phase value to an integer id, registering unseen names in phase_name_to_id.

    Lookup order: a known name (by ``str(value)``), then the value itself as an
    int, otherwise a fresh id appended to the mapping.
    """
    key = str(value)
    if key in phase_name_to_id:
        return phase_name_to_id[key]
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        pass
    # max+1 rather than len(): a 1-based phases.csv mapping would otherwise collide with the new id.
    new_id = max(phase_name_to_id.values(), default=-1) + 1
    phase_name_to_id[key] = new_id
    return new_id


segments = []
if col_start is not None and col_end is not None:
    # The unit is decided per column from its name, not from each value's magnitude,
    # so small frame indices are never mistaken for seconds.
    start_in_seconds = _is_time_column(col_start)
    end_in_seconds = _is_time_column(col_end)
    for _, r in ann.iterrows():
        sfrm = _to_frame(r[col_start], start_in_seconds)
        efrm = _to_frame(r[col_end], end_in_seconds)
        if sfrm is None or efrm is None:
            continue
        pid = _resolve_phase_id(r[col_phase] if col_phase is not None else 0)
        segments.append({'video_id': str(r[col_video]), 'start_frame': int(sfrm), 'end_frame': int(efrm), 'phase_id': int(pid)})
elif col_frame is not None and col_phase is not None:
    # Per-frame format: runs of consecutive frame numbers with the same phase become one segment.
    # NOTE(review): any gap in the frame numbers starts a new segment and leaves the gap frames
    # unlabelled (-1). If the file lists only phase-transition frames (one row per phase start)
    # rather than every frame, most frames end up unlabelled; confirm the format and, if so,
    # fill each phase forward to the next annotated frame instead.
    for vid, g in ann.groupby(col_video):
        g_sorted = g.sort_values(by=col_frame)
        frames = g_sorted[col_frame].astype(int).values
        if len(frames) == 0:
            continue
        phase_ids = [_resolve_phase_id(v) for v in g_sorted[col_phase].values]
        run_start, run_phase, prev = frames[0], phase_ids[0], frames[0]
        for f, pid in zip(frames[1:], phase_ids[1:]):
            if pid != run_phase or f != prev + 1:
                segments.append({'video_id': str(vid), 'start_frame': int(run_start), 'end_frame': int(prev), 'phase_id': int(run_phase)})
                run_start, run_phase = f, pid
            prev = f
        segments.append({'video_id': str(vid), 'start_frame': int(run_start), 'end_frame': int(prev), 'phase_id': int(run_phase)})
else:
    raise ValueError('Could not detect usable annotation format in annotations.csv. Inspect columns and adapt script.')

if not segments:
    raise ValueError('No segments could be parsed from annotations.csv. Inspect columns and adapt script.')

seg_df = pd.DataFrame(segments)
seg_df = seg_df[seg_df['end_frame'] >= seg_df['start_frame']].reset_index(drop=True)
seg_out = Path(WORK_DIR) / 'segments_filled.csv'
seg_df.to_csv(seg_out, index=False)
print('Saved', seg_out, 'segments:', len(seg_df), 'videos:', seg_df.video_id.nunique())

# build per-video label arrays: frame index -> phase id, -1 where no segment covers the frame
labels_dir = Path(WORK_DIR) / 'labels_npy'
labels_dir.mkdir(parents=True, exist_ok=True)
for vid, g in seg_df.groupby('video_id'):
    max_frame = int(g['end_frame'].max()) + 1
    labels = np.full((max_frame,), -1, dtype=np.int32)
    for _, r in g.iterrows():
        # clamp: a negative start would otherwise be read as an index from the end of the array
        s = max(0, int(r['start_frame'])); e = int(r['end_frame'])
        labels[s:e+1] = int(r['phase_id'])
    np.save(labels_dir / f'{vid}.npy', labels)
print('Saved labels_npy for', len(list(labels_dir.glob('*.npy'))), 'videos at', labels_dir)

if globals().get('NUM_CLASSES', None) is None:
    used_ids = sorted(seg_df['phase_id'].unique().tolist())
    NUM = int(max(used_ids) + 1) if len(used_ids)>0 else 1
    globals()['NUM_CLASSES'] = NUM
print('NUM_CLASSES set to', globals()['NUM_CLASSES'])

In [None]:
# === STEP 2: ResNet feature extraction from videos -> features/{video_id}.npy ===
from pathlib import Path

videos_root = Path(BASE_DIR) / 'videos'
videos_csv = Path(BASE_DIR) / 'videos.csv'

# Optional explicit video id -> filename mapping from videos.csv.
videoid_to_file = {}
if videos_csv.exists():
    vids_df = pd.read_csv(videos_csv)
    # A ';'-delimited file parsed with the default separator collapses to one column; re-read it.
    if vids_df.shape[1] == 1 and ';' in str(vids_df.columns[0]):
        vids_df = pd.read_csv(videos_csv, sep=';')
    cols = [str(c).lower() for c in vids_df.columns]
    # First column that looks like an id / a file reference; fall back to the first / last column.
    vidcol = next((vids_df.columns[i] for i, c in enumerate(cols)
                   if 'video' in c or 'id' in c or 'name' in c), vids_df.columns[0])
    filecol = next((vids_df.columns[i] for i, c in enumerate(cols)
                    if 'file' in c or 'path' in c or 'video' in c), vids_df.columns[-1])
    for _, r in vids_df.iterrows():
        videoid_to_file[str(r[vidcol])] = str(r[filecol])
    if PRINT_DEBUG:
        print('Loaded videos.csv mapping for', len(videoid_to_file), 'entries')
else:
    print('No videos.csv found; will try to infer filenames from videos folder')

# Every file in videos/ is addressable by its stem.
video_files = {}
if videos_root.is_dir():
    for p in videos_root.iterdir():
        if p.is_file():
            video_files[p.stem] = str(p)
else:
    print('Videos folder not found at', videos_root)

all_video_ids = sorted([p.stem for p in Path(WORK_DIR).joinpath('labels_npy').glob('*.npy')])
print('label videos:', all_video_ids[:5], '... total', len(all_video_ids))

for vid in all_video_ids:
    if vid in videoid_to_file:
        path = videos_root / videoid_to_file[vid]
        if path.exists():
            video_files[vid] = str(path)
    if vid not in video_files:
        for ext in ['.mp4', '.avi', '.mov', '.mkv']:
            cand = videos_root / (vid + ext)
            if cand.exists():
                video_files[vid] = str(cand)
                break

# Count only labelled videos that resolved to a file; video_files also holds unlabelled files.
n_matched = sum(vid in video_files for vid in all_video_ids)
print('Found video files for', n_matched, 'videos (of', len(all_video_ids), 'labelled videos)')

# build backbone: torchvision ResNet with the final fc layer removed (output = pooled features)
# NOTE: newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
print('Building backbone:', FEATURE_BACKBONE, 'pretrained=', PRETRAINED_BACKBONE)
if FEATURE_BACKBONE.lower().startswith('resnet'):
    if FEATURE_BACKBONE=='resnet50':
        backbone = models.resnet50(pretrained=PRETRAINED_BACKBONE)
    else:
        backbone = models.resnet18(pretrained=PRETRAINED_BACKBONE)
    modules = list(backbone.children())[:-1]
    backbone = nn.Sequential(*modules)
    backbone.eval()
    backbone.to(DEVICE)
    feat_dim = 2048 if FEATURE_BACKBONE=='resnet50' else 512
else:
    raise ValueError('Only resnet backbones supported in this notebook')
print('Backbone ready, feat_dim=', feat_dim)

# RGB uint8 frame -> normalized tensor (ImageNet mean/std)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

def _read_frame_batches(cap, batch_frames, max_frames=None):
    """Yield lists of up to ``batch_frames`` RGB frames read from an opened cv2.VideoCapture.

    Stops at end of stream, or once ``max_frames`` frames have been read when it is not None.
    """
    batch = []
    n_read = 0
    while max_frames is None or n_read < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        n_read += 1
        if len(batch) == batch_frames:
            yield batch
            batch = []
    if batch:
        yield batch


def extract_features_from_video(video_path, out_path, label_len=None, batch_frames=BATCH_FRAME):
    """Run the global ``backbone`` over every frame of a video and save the features with np.save.

    Frames are decoded and encoded batch by batch, so memory use is bounded by
    ``batch_frames`` rather than by the video length. When ``label_len`` is given,
    at most that many frames are read, so the features never outnumber the labels.

    Args:
        video_path: path of the video file to decode with OpenCV.
        out_path: destination for a (num_frames, feat_dim) array.
        label_len: optional cap on the number of frames processed.
        batch_frames: frames per backbone forward pass.

    Returns:
        True if features were written, False if the video could not be opened or had no frames.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print('Failed to open', video_path); return False
    feats = []
    try:
        with torch.no_grad():
            for batch in _read_frame_batches(cap, max(1, int(batch_frames)), label_len):
                tensor_batch = torch.stack([transform(f) for f in batch]).to(DEVICE)
                out = backbone(tensor_batch)
                # backbone ends in global pooling -> (B, C, 1, 1); flatten to (B, C)
                feats.append(out.reshape(out.size(0), -1).cpu().numpy())
    finally:
        cap.release()
    if not feats:
        print('No frames in', video_path); return False
    np.save(out_path, np.concatenate(feats, axis=0))
    return True

features_dir = Path(WORK_DIR) / 'features'
features_dir.mkdir(parents=True, exist_ok=True)
# First SAMPLE_VIDEOS labelled videos (sorted by id), or all of them when SAMPLE_VIDEOS is None.
selected_ids = all_video_ids if SAMPLE_VIDEOS is None else all_video_ids[:SAMPLE_VIDEOS]
count_done = 0
for vid in tqdm(selected_ids):
    outp = features_dir / f'{vid}.npy'
    if outp.exists():
        # already extracted in an earlier run
        count_done += 1
        continue
    if vid not in video_files:
        print('No source video file found for', vid, '; skipping feature extraction')
        continue
    # cap extraction at the label length so features and labels line up frame by frame
    label_len = np.load(Path(WORK_DIR) / 'labels_npy' / f'{vid}.npy').shape[0]
    if extract_features_from_video(video_files[vid], outp, label_len=label_len, batch_frames=BATCH_FRAME):
        count_done += 1
print('Feature extraction done for', count_done, 'videos. Saved to', features_dir)

In [None]:
# === STEP 3: Dataset, model, training & eval ===
# Dataset
class ClipDataset(Dataset):
    """Fixed-length clips of pre-extracted frame features with per-frame phase labels.

    Labels are rebuilt from ``segments_filled.csv`` (frames covered by no segment
    are -1). Only videos that have a ``features/{video_id}.npy`` file are used,
    so every indexed clip can be loaded. Videos are shuffled with ``seed`` and
    truncated to ``sample_videos``. Clips are non-overlapping windows of
    ``seq_len`` frames within the shorter of the label and feature sequences;
    videos shorter than ``seq_len`` contribute no clips.

    Args:
        work_dir: directory containing ``features/``, ``labels_npy/`` and ``segments_filled.csv``.
        seq_len: clip length in frames.
        sample_videos: keep at most this many videos after shuffling (None keeps all).
        split: stored as ``self.split``; it does not by itself select videos.
        seed: seed for the video shuffle.
    """

    def __init__(self, work_dir, seq_len=SEQ_LEN, sample_videos=SAMPLE_VIDEOS, split='train', seed=42):
        self.work_dir = Path(work_dir)
        self.features_dir = self.work_dir / 'features'
        self.labels_dir = self.work_dir / 'labels_npy'
        self.segments_csv = self.work_dir / 'segments_filled.csv'
        self.seq_len = seq_len
        self.sample_videos = sample_videos
        self.split = split
        self.seed = seed

        # Read ids as strings so they match the .npy file names exactly (e.g. '001' stays '001').
        seg_df = pd.read_csv(self.segments_csv, dtype={'video_id': str})
        self.video_labels = {}
        for vid, g in seg_df.groupby('video_id'):
            max_frame = int(g['end_frame'].max()) + 1
            labels = np.full((max_frame,), -1, dtype=np.int32)
            for _, r in g.iterrows():
                s = int(r['start_frame']); e = int(r['end_frame'])
                labels[s:e+1] = int(r['phase_id'])
            self.video_labels[str(vid)] = labels

        # Only videos whose features exist; mmap reads just the header to get the length.
        self.feature_lengths = {}
        for vid in self.video_labels:
            feat_path = self.features_dir / f'{vid}.npy'
            if feat_path.exists():
                self.feature_lengths[vid] = int(np.load(feat_path, mmap_mode='r').shape[0])

        self.videos = sorted(self.feature_lengths)
        # private RNG: same permutation as seeding the global `random`, without the side effect
        random.Random(self.seed).shuffle(self.videos)
        if self.sample_videos is not None:
            self.videos = self.videos[:self.sample_videos]

        self.index = []
        for vid in self.videos:
            n = min(len(self.video_labels[vid]), self.feature_lengths[vid])
            for s in range(0, n - self.seq_len + 1, self.seq_len):
                self.index.append((vid, s))

    def __len__(self): return len(self.index)

    def __getitem__(self, idx):
        vid, s = self.index[idx]
        feat_path = self.features_dir / f'{vid}.npy'
        if not feat_path.exists():
            raise FileNotFoundError(f'Missing features for {vid} at {feat_path}')
        # mmap + copy of the slice: avoids reading the whole video's features for one clip
        feats = np.load(feat_path, mmap_mode='r')
        clip_feats = np.array(feats[s:s+self.seq_len], dtype=np.float32)
        clip_lbls = self.video_labels[vid][s:s+self.seq_len]
        return {'video_id': vid, 'start': s, 'feats': torch.from_numpy(clip_feats), 'labels': torch.from_numpy(clip_lbls).long()}

def collate_fn(batch):
    """Merge a list of ClipDataset samples into one batch dict.

    ``feats`` and ``labels`` are stacked along a new leading batch dimension;
    ``video_id`` and ``start`` are kept as plain lists in sample order.
    """
    return {
        'feats': torch.stack([sample['feats'] for sample in batch]),
        'labels': torch.stack([sample['labels'] for sample in batch]),
        'video_id': [sample['video_id'] for sample in batch],
        'start': [sample['start'] for sample in batch],
    }

# Model
class DilatedResidualLayer(nn.Module):
    """Residual temporal block: dilated 3-tap conv -> ReLU -> 1x1 conv -> dropout, plus a skip path.

    The skip path is a 1x1 conv projection when ``in_channels != out_channels``
    and the identity otherwise; the sum goes through a final ReLU. With
    ``padding == dilation`` the temporal length is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__()
        self.conv_dilated = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                                      padding=dilation, dilation=dilation)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.downsample = (nn.Conv1d(in_channels, out_channels, kernel_size=1)
                           if in_channels != out_channels else None)

    def forward(self, x):
        branch = self.dropout(self.conv_1x1(self.relu(self.conv_dilated(x))))
        skip = x if self.downsample is None else self.downsample(x)
        return self.relu(branch + skip)

class SimpleMS_TCN(nn.Module):
    """Single-stage dilated temporal convolutional network for frame-wise classification.

    Pipeline: 1x1 input projection -> ``num_layers`` DilatedResidualLayer blocks
    (dilation cycling 1, 2, 4, 8) -> 1x1 classifier.

    forward(feats): ``feats`` is (batch, time, feat_dim); returns per-frame logits
    of shape (batch, time, num_classes). The output is made contiguous so callers
    can ``.view(B*T, C)`` it directly (a bare permute would make that raise).
    """

    def __init__(self, feat_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        self.input_conv = nn.Conv1d(feat_dim, hidden_dim, kernel_size=1)
        self.tcn = nn.Sequential(*[
            DilatedResidualLayer(hidden_dim, hidden_dim, dilation=2 ** (i % 4))
            for i in range(num_layers)
        ])
        self.classifier = nn.Conv1d(hidden_dim, num_classes, kernel_size=1)

    def forward(self, feats):
        x = self.input_conv(feats.permute(0, 2, 1))   # Conv1d expects (batch, channels, time)
        x = self.tcn(x)
        out = self.classifier(x)
        return out.permute(0, 2, 1).contiguous()      # back to (batch, time, classes)

# Metrics helpers
def compute_frame_metrics(gt, pred, num_classes):
    """Frame-wise accuracy and macro-F1 over labelled frames.

    Frames with ``gt == -1`` are ignored. Macro-F1 averages the per-class F1
    over the classes that occur in the labelled ground truth or predictions,
    the same convention as ``sklearn.metrics.f1_score(average='macro')`` used
    during validation. Classes absent from the video therefore do not count
    as 0.

    Args:
        gt: 1-D array of ground-truth labels (-1 = unlabelled).
        pred: 1-D array of predicted labels, same length as ``gt``.
        num_classes: accepted for backward compatibility; not needed for the computation.

    Returns:
        {'acc': float, 'macro_f1': float}; both are NaN when ``gt`` has no labelled frames.
    """
    gt = np.asarray(gt)
    pred = np.asarray(pred)
    mask = gt != -1
    gt, pred = gt[mask], pred[mask]
    if gt.size == 0:
        return {'acc': float('nan'), 'macro_f1': float('nan')}
    acc = float(np.mean(gt == pred))
    f1_scores = []
    for cls in np.union1d(gt, pred):
        tp = np.count_nonzero((gt == cls) & (pred == cls))
        fp = np.count_nonzero((gt != cls) & (pred == cls))
        fn = np.count_nonzero((gt == cls) & (pred != cls))
        # cls occurs in gt or pred, so the denominator is always positive
        f1_scores.append(2.0 * tp / (2 * tp + fp + fn))
    return {'acc': acc, 'macro_f1': float(np.mean(f1_scores))}

def collapse_sequence(x):
    """Collapse a label sequence into its ordered list of segment labels.

    Entries equal to -1 (unlabelled) are skipped entirely, and consecutive
    repeats are merged, also across skipped -1 entries. Labels are returned as Python ints.
    """
    collapsed = []
    for value in x:
        if value == -1:
            continue
        if collapsed and collapsed[-1] == value:
            continue
        collapsed.append(int(value))
    return collapsed

def levenshtein(a, b):
    """Edit distance between sequences ``a`` and ``b`` (unit-cost insert, delete, substitute).

    Uses a rolling two-row dynamic-programming table; returns a NumPy integer.
    """
    prev_row = np.arange(len(b) + 1, dtype=int)
    for i, item_a in enumerate(a, start=1):
        row = np.empty_like(prev_row)
        row[0] = i
        for j, item_b in enumerate(b, start=1):
            substitute = prev_row[j - 1] + (0 if item_a == item_b else 1)
            row[j] = min(prev_row[j] + 1, row[j - 1] + 1, substitute)
        prev_row = row
    return prev_row[len(b)]

def sequence_edit_score(gt, pred):
    """Segmental edit score in [0, 1] between ground-truth and predicted label sequences.

    Both sequences are collapsed to segment labels (dropping -1), then scored as
    ``1 - levenshtein / max(len_gt, len_pred, 1)``. Returns 0.0 when the ground
    truth has no labelled segments.
    """
    gt_segments = collapse_sequence(gt)
    pred_segments = collapse_sequence(pred)
    if not gt_segments:
        return 0.0
    longest = max(len(gt_segments), len(pred_segments), 1)
    return float(1.0 - levenshtein(gt_segments, pred_segments) / longest)

print('Dataset+Model+Metrics ready')

In [None]:
# === STEP 4: Train (SMOKE RUN) ===
from sklearn.metrics import f1_score

work_dir = WORK_DIR
seq_len = SEQ_LEN
batch_size = 2
epochs = 3
lr = 1e-3
weight_decay = 1e-4
hidden_dim = 64
num_layers = 6
# STEP 1 sets NUM_CLASSES; fall back to 10 if it is missing or still None.
num_classes = globals().get('NUM_CLASSES') or 10

device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
print('Device:', device)

feature_files = sorted(Path(work_dir).joinpath('features').glob('*.npy'))
if not feature_files:
    raise FileNotFoundError(f"No feature files in {Path(work_dir) / 'features'}; run STEP 2 first.")
feat_dim = int(np.load(feature_files[0], mmap_mode='r').shape[1])
print('Detected feat_dim=', feat_dim)

# SAMPLE_VIDEOS may be None ("use all videos"), which `SAMPLE_VIDEOS // 3` alone cannot handle.
n_val_videos = 10 if SAMPLE_VIDEOS is None else max(1, min(10, SAMPLE_VIDEOS // 3 or 1))
train_ds = ClipDataset(work_dir, seq_len=seq_len, sample_videos=SAMPLE_VIDEOS, split='train', seed=42)
val_ds = ClipDataset(work_dir, seq_len=seq_len, sample_videos=n_val_videos, split='val', seed=999)

# Both datasets sample the same pool of videos with different shuffles, so they can share
# videos. Remove validation videos from training so the validation score is held-out.
val_video_set = set(val_ds.videos)
train_ds.videos = [v for v in train_ds.videos if v not in val_video_set]
train_ds.index = [(v, s) for v, s in train_ds.index if v not in val_video_set]
print('Train videos:', len(train_ds.videos), 'Train clips:', len(train_ds))
print('Val videos:', len(val_ds.videos), 'Val clips:', len(val_ds))
if len(train_ds) == 0:
    raise RuntimeError('No training clips left after holding out validation videos; '
                       'extract features for more videos or lower SEQ_LEN.')

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

model = SimpleMS_TCN(feat_dim=feat_dim, hidden_dim=hidden_dim, num_layers=num_layers, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss(ignore_index=-1)   # -1 marks unlabelled frames

best_val = -1.0
for epoch in range(1, epochs+1):
    model.train()
    total_loss = 0.0
    n_steps = 0
    for batch in tqdm(train_loader, desc=f'train-epoch{epoch}'):
        feats = batch['feats'].to(device)
        labels = batch['labels'].to(device)
        # With every target ignored, the mean loss is NaN and would poison the weights; skip.
        if not bool((labels != -1).any()):
            continue
        logits = model(feats)
        B, L, C = logits.shape
        # reshape (not view): the model output may be a non-contiguous permuted tensor
        loss = criterion(logits.reshape(B * L, C), labels.reshape(B * L))
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += float(loss.item())
        n_steps += 1
    avg_loss = total_loss / n_steps if n_steps > 0 else 0.0

    model.eval()
    preds_all = []; labels_all = []
    with torch.no_grad():
        for batch in val_loader:
            feats = batch['feats'].to(device)
            labels = batch['labels'].numpy().reshape(-1)
            logits = model(feats)
            pred = logits.argmax(dim=-1).cpu().numpy().reshape(-1)
            preds_all.append(pred); labels_all.append(labels)
    val_f1 = 0.0; val_acc = 0.0
    if len(preds_all) > 0:
        preds = np.concatenate(preds_all); labels = np.concatenate(labels_all)
        mask = labels != -1
        if mask.sum() > 0:
            val_f1 = f1_score(labels[mask], preds[mask], average='macro', zero_division=0)
            val_acc = float((labels[mask] == preds[mask]).mean())
    print(f'Epoch {epoch}: train_loss={avg_loss:.4f} val_f1={val_f1:.4f} val_acc={val_acc:.4f}')
    # 'config' records the architecture so the checkpoint can be rebuilt without notebook globals.
    ckpt = {'epoch': epoch, 'model_state': model.state_dict(), 'optimizer': optimizer.state_dict(),
            'config': {'feat_dim': feat_dim, 'hidden_dim': hidden_dim,
                       'num_layers': num_layers, 'num_classes': num_classes}}
    ckpt_path = Path(work_dir) / 'checkpoints' / f'ckpt_epoch{epoch}.pth'
    torch.save(ckpt, ckpt_path)
    if val_f1 > best_val:
        best_val = val_f1
        torch.save(ckpt, Path(work_dir) / 'checkpoints' / 'best.pth')

print('Training finished. Best val f1=', best_val)

In [None]:
# === STEP 5: Full-video inference + metrics ===
def load_model(checkpoint_path, device):
    """Rebuild a SimpleMS_TCN and load the weights stored under 'model_state' in a checkpoint.

    Architecture sizes come from the notebook globals ``feat_dim``, ``hidden_dim``,
    ``num_layers`` and ``num_classes`` (set by the training cell), so they must match
    the values the checkpoint was trained with. Returns the model on ``device`` in eval mode.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net = SimpleMS_TCN(
        feat_dim=feat_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_classes=num_classes,
    ).to(device)
    net.load_state_dict(checkpoint['model_state'])
    return net.eval()

def inference_full_video(model, features, device, seq_len=SEQ_LEN, overlap=0.5):
    """Predict a label for every frame of a video by averaging logits over sliding windows.

    Windows of ``seq_len`` frames advance by ``seq_len * (1 - overlap)`` (at least 1).
    A final window ending exactly at the last frame is always added, so every
    frame is covered; previously, trailing frames could get 0/0 = NaN logits and
    silently be predicted as class 0. A video shorter than ``seq_len`` is run as a
    single shorter window. The number of classes is taken from the model output.

    Args:
        model: callable mapping a (1, L, D) float tensor to (1, L, C) logits.
        features: (T, D) array of per-frame features.
        device: torch device to run the model on.
        seq_len: window length in frames.
        overlap: fraction of overlap between consecutive windows.

    Returns:
        (T,) integer array of predicted class ids (empty when T == 0).
    """
    T = features.shape[0]
    if T == 0:
        return np.zeros((0,), dtype=np.int64)
    step = max(1, int(seq_len * (1 - overlap)))
    starts = list(range(0, max(1, T - seq_len + 1), step))
    last_start = max(0, T - seq_len)
    if starts[-1] != last_start:
        starts.append(last_start)
    counts = np.zeros((T,), dtype=np.int32)
    logits_sum = None
    with torch.no_grad():
        for s in starts:
            window = np.ascontiguousarray(features[s:s+seq_len])
            clip = torch.from_numpy(window).unsqueeze(0).to(device).float()
            out = model(clip).cpu().numpy()[0]
            if logits_sum is None:
                logits_sum = np.zeros((T, out.shape[-1]), dtype=np.float32)
            L = out.shape[0]
            logits_sum[s:s+L] += out
            counts[s:s+L] += 1
    logits_avg = logits_sum / counts[:, None]
    return logits_avg.argmax(axis=-1)

def _ckpt_epoch(path):
    """Epoch number from a 'ckpt_epoch{N}.pth' path (-1 if not parseable), for numeric sorting."""
    suffix = path.stem[len('ckpt_epoch'):]
    return int(suffix) if suffix.isdigit() else -1


ckpt_dir = Path(WORK_DIR) / 'checkpoints'
best_ckpt = ckpt_dir / 'best.pth'
if not best_ckpt.exists():
    # Sort numerically: a plain string sort would put ckpt_epoch10 before ckpt_epoch2.
    ckpts = sorted(ckpt_dir.glob('ckpt_epoch*.pth'), key=_ckpt_epoch)
    if ckpts:
        best_ckpt = ckpts[-1]
    else:
        raise FileNotFoundError(f'No checkpoint found in {ckpt_dir}')

# load model
device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
model = load_model(str(best_ckpt), device)
features_dir = Path(WORK_DIR) / 'features'
labels_dir = Path(WORK_DIR) / 'labels_npy'
out_preds_dir = Path(WORK_DIR) / 'preds_npy'
out_preds_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): this evaluates every video that has features, including the videos used for
# training, so the numbers below are not a held-out score.
rows = []
for feat_file in sorted(features_dir.glob('*.npy')):
    vid = feat_file.stem
    feats = np.load(feat_file)
    preds = inference_full_video(model, feats, device, seq_len=SEQ_LEN)
    np.save(out_preds_dir / f'{vid}_pred.npy', preds)
    gt_path = labels_dir / f'{vid}.npy'
    if gt_path.exists():
        gt = np.load(gt_path)
        L = min(len(gt), len(preds))   # compare only the frames both arrays cover
        fm = compute_frame_metrics(gt[:L], preds[:L], num_classes)
        edit = sequence_edit_score(gt[:L], preds[:L])
        rows.append({'video_id': vid, 'acc': fm['acc'], 'macro_f1': fm['macro_f1'], 'edit': edit})
    else:
        rows.append({'video_id': vid, 'acc': None, 'macro_f1': None, 'edit': None})

df = pd.DataFrame(rows)
df.to_csv(Path(WORK_DIR) / 'inference_metrics.csv', index=False)
print('Saved inference metrics to', Path(WORK_DIR) / 'inference_metrics.csv')
if rows:
    # None / NaN entries (videos without labels) are skipped by the mean
    summary = df[['acc', 'macro_f1', 'edit']].apply(pd.to_numeric, errors='coerce').mean()
    print('Mean over videos:', summary.round(4).to_dict())

## Done

Edit the configuration cell, then run the cells in order. Feature extraction (STEP 2) skips any video whose `features/{video_id}.npy` file already exists, so re-running the notebook with previously extracted features does not recompute them.