# MS-TCN Prototype Training Notebook

This notebook implements a minimal MS-TCN-style training pipeline for a quick prototype on the `cataract-101` dataset.

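It expects per-video pre-extracted features at `cataract-101/features/{video_id}.npy` (one row per frame) and a `cataract-101/segments_filled.csv` file with the columns `video_id`, `start_frame`, `end_frame` (inclusive) and `phase_id`.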

**What this notebook contains**:
- Configuration cell you can edit
- Dataset loader for fixed-length clips from pre-extracted features
- Lightweight single-stage MS-TCN model
- Training loop (run for a short smoke test)
- Single-video inference example

Run cells sequentially. The notebook **does not** run full training automatically; you control when to start training.


## Requirements

Install required packages if you haven't already:

```bash
pip install torch torchvision numpy pandas scikit-learn pyyaml opencv-python tqdm
```


In [None]:
# === Configuration ===
# torch is imported here (it is imported again in the imports cell) only to
# choose the default training device.
import torch

cfg = {
    "data": {
        "dataset_root": "cataract-101",
        "features_dir": "features",
        "segments_csv": "segments_filled.csv",
        "sample_size": 20,      # number of videos to use for quick prototype (set None to use all)
        "seq_len": 200,
        "fps": 25
    },
    "model": {
        "feat_dim": 512,       # set to your feature dim
        "hidden_dim": 64,
        "num_layers": 6,
        "num_classes": 10      # overwrite with true number of phases
    },
    "training": {
        "epochs": 3,           # small number for smoke test
        "batch_size": 2,
        "lr": 1e-3,
        "weight_decay": 1e-4,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "seed": 42
    },
    "io": {
        "out_dir": "exp_ms_tcn_sample",
        "checkpoint_freq": 1
    }
}
print('Config ready - please edit values above to match your setup.')

In [None]:
# Imports
import os, random, yaml, math, time
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
print('imports ok')

In [None]:
# Dataset class (expects features/{video_id}.npy and segments_filled.csv)
class ClipDataset(Dataset):
    """Fixed-length clips of pre-extracted features with per-frame phase labels.

    Labels come from ``{dataset_root}/{segments_csv}``, which must provide the
    columns ``video_id``, ``start_frame``, ``end_frame`` (inclusive) and
    ``phase_id``. Frames not covered by any segment are labelled -1, the value
    ignored by the loss and the metrics. Features are read from
    ``{dataset_root}/{features_dir}/{video_id}.npy`` and indexed by the same
    frame index as the labels. Videos whose feature file is missing are skipped
    with a printed warning.

    The train/val split is made at the *video* level, so no video contributes
    clips to both splits. The eligible videos are shuffled with
    ``cfg['training']['seed']`` and truncated to ``cfg['data']['sample_size']``
    (if not None). The first ``cfg['data'].get('val_fraction', 0.2)`` of them
    form the validation split; at least one video goes to each split when there
    are two or more.

    Each video is cut into non-overlapping windows of ``seq_len`` frames. A
    video shorter than ``seq_len`` yields a single window whose missing frames
    are zero-padded (features) and labelled -1.

    Args:
        cfg: configuration dict laid out like the config cell.
        split: ``'train'``, ``'val'`` (aliases ``'valid'``/``'validation'``) or
            ``'all'`` for every sampled video.

    Raises:
        ValueError: if ``split`` is not one of the values above.
    """

    _VAL_SPLITS = ('val', 'valid', 'validation')

    def __init__(self, cfg, split='train'):
        if split not in ('train', 'all') + self._VAL_SPLITS:
            raise ValueError(f"Unknown split {split!r}; expected 'train', 'val' or 'all'")
        data_cfg = cfg['data']
        self.root = data_cfg['dataset_root']
        self.features_dir = os.path.join(self.root, data_cfg['features_dir'])
        self.seq_len = data_cfg['seq_len']
        self.sample_size = data_cfg.get('sample_size', None)
        self.segments_csv = os.path.join(self.root, data_cfg['segments_csv'])
        self.split = split

        self.video_labels = self._load_labels(self.segments_csv)

        # Keep only videos that have features. The usable length is the shorter
        # of the label and feature arrays, so a clip never reads past either.
        self.video_feats = {}
        self._num_frames = {}
        missing = 0
        for vid, labels in self.video_labels.items():
            feat_path = os.path.join(self.features_dir, f"{vid}.npy")
            if not os.path.exists(feat_path):
                missing += 1
                continue
            # mmap_mode='r' only reads the header to get the frame count
            n_feat = np.load(feat_path, mmap_mode='r').shape[0]
            self.video_feats[vid] = feat_path
            self._num_frames[vid] = min(n_feat, len(labels))
        if missing:
            print(f"Warning: skipping {missing} video(s) with no features in {self.features_dir}")

        self.videos = self._select_videos(sorted(self.video_feats), cfg, split)

        # index entries are (video_id, start_frame) windows
        self.index = []
        for vid in self.videos:
            n = self._num_frames[vid]
            if n == 0:
                continue
            if n < self.seq_len:
                self.index.append((vid, 0))  # one padded window
            else:
                self.index.extend((vid, s) for s in range(0, n - self.seq_len + 1, self.seq_len))

    @staticmethod
    def _load_labels(segments_csv):
        """Build a per-frame label array (-1 = unlabelled) for each video in the CSV."""
        seg_df = pd.read_csv(segments_csv)
        video_labels = {}
        for vid, g in seg_df.groupby('video_id'):
            labels = np.full(int(g['end_frame'].max()) + 1, -1, dtype=np.int64)
            for s, e, phase in zip(g['start_frame'], g['end_frame'], g['phase_id']):
                labels[int(s):int(e) + 1] = int(phase)
            if labels.max() >= 0:  # skip videos without any labelled frame
                video_labels[vid] = labels
        return video_labels

    @classmethod
    def _select_videos(cls, videos, cfg, split):
        """Shuffle deterministically, apply sample_size, then keep this split's share."""
        data_cfg = cfg['data']
        # A private RNG avoids reseeding the global `random` module.
        random.Random(cfg['training']['seed']).shuffle(videos)
        sample_size = data_cfg.get('sample_size', None)
        if sample_size is not None:
            videos = videos[:sample_size]
        if split == 'all':
            return videos
        val_fraction = float(data_cfg.get('val_fraction', 0.2))
        if len(videos) < 2 or val_fraction <= 0:
            n_val = 0
        else:
            n_val = min(max(int(round(len(videos) * val_fraction)), 1), len(videos) - 1)
        return videos[:n_val] if split in cls._VAL_SPLITS else videos[n_val:]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        vid, s = self.index[idx]
        e = min(s + self.seq_len, self._num_frames[vid])
        feats = np.load(self.video_feats[vid], mmap_mode='r')  # avoid loading the whole video
        clip_feats = np.zeros((self.seq_len,) + feats.shape[1:], dtype=np.float32)
        clip_feats[: e - s] = feats[s:e]
        clip_lbls = np.full(self.seq_len, -1, dtype=np.int64)  # padding is ignored by the loss
        clip_lbls[: e - s] = self.video_labels[vid][s:e]
        return {
            'video_id': vid,
            'start': s,
            'feats': torch.from_numpy(clip_feats),
            'labels': torch.from_numpy(clip_lbls)
        }

def collate_fn(batch):
    """Merge clip dicts into one batch.

    ``feats`` and ``labels`` are stacked along a new leading batch dimension;
    ``video_id`` and ``start`` are kept as plain lists in batch order.
    """
    merged = {key: torch.stack([item[key] for item in batch], dim=0) for key in ('feats', 'labels')}
    merged.update({key: [item[key] for item in batch] for key in ('video_id', 'start')})
    return merged

print('Dataset class defined')

In [None]:
# Simple single-stage MS-TCN-like model
class DilatedResidualLayer(nn.Module):
    """Residual block: dilated 3-tap conv -> ReLU -> 1x1 conv -> dropout, plus a skip path.

    The temporal length is preserved because padding equals the dilation. When
    ``in_channels != out_channels`` the skip path is projected with a 1x1 conv.
    A ReLU is applied after the residual sum.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__()
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        # 1x1 projection so the skip connection matches out_channels
        self.downsample = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        skip = x if self.downsample is None else self.downsample(x)
        branch = self.dropout(self.conv_1x1(self.relu(self.conv_dilated(x))))
        return self.relu(branch + skip)

class SimpleMS_TCN(nn.Module):
    """Single-stage MS-TCN-style frame classifier.

    The network is a 1x1 input projection, then ``num_layers``
    DilatedResidualLayers with dilations cycling through 1, 2, 4, 8, then a 1x1
    classifier. Input features are (B, T, D); output logits are
    (B, T, num_classes).
    """

    def __init__(self, feat_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        self.input_conv = nn.Conv1d(feat_dim, hidden_dim, kernel_size=1)
        self.tcn = nn.Sequential(*[
            DilatedResidualLayer(hidden_dim, hidden_dim, dilation=2 ** (layer_idx % 4))
            for layer_idx in range(num_layers)
        ])
        self.classifier = nn.Conv1d(hidden_dim, num_classes, kernel_size=1)

    def forward(self, feats):
        hidden = self.tcn(self.input_conv(feats.transpose(1, 2)))  # (B, hidden, T)
        return self.classifier(hidden).transpose(1, 2)  # (B, T, C)

print('Model defined')

In [None]:
# Metrics helpers
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def compute_frame_metrics(gt, pred, num_classes):
    """Frame-level accuracy, macro-F1 and per-class precision/recall/F1.

    Frames whose ground truth is -1 are ignored. Per-class scores are reported
    for every class id in ``range(num_classes)``; undefined scores are 0. The
    macro-F1 averages only the classes that occur in the (masked) ``gt`` or
    ``pred``. Phases absent from a video therefore do not drag the score down
    as F1 = 0, consistent with ``f1_score(average='macro')``.

    Args:
        gt: 1-D sequence of ground-truth class ids (-1 = unlabelled).
        pred: 1-D sequence of predicted class ids, same length as ``gt``.
        num_classes: number of phase classes.

    Returns:
        dict with ``'acc'``, ``'macro_f1'`` and ``'p_i'``/``'r_i'``/``'f1_i'``
        for each class i. All values are NaN when no frame is labelled.
    """
    gt = np.asarray(gt)
    pred = np.asarray(pred)
    mask = gt != -1
    gt, pred = gt[mask], pred[mask]
    class_ids = list(range(num_classes))
    if gt.size == 0:
        nan = float('nan')
        res = {'acc': nan, 'macro_f1': nan}
        for prefix in ('p', 'r', 'f1'):
            res.update({f'{prefix}_{i}': nan for i in class_ids})
        return res
    acc = accuracy_score(gt, pred)
    p, r, f1, _ = precision_recall_fscore_support(gt, pred, labels=class_ids, zero_division=0)
    present = np.isin(class_ids, np.union1d(gt, pred))
    macro_f1 = float(f1[present].mean()) if present.any() else 0.0
    res = {'acc': float(acc), 'macro_f1': macro_f1}
    res.update({f'p_{i}': float(p[i]) for i in range(len(p))})
    res.update({f'r_{i}': float(r[i]) for i in range(len(r))})
    res.update({f'f1_{i}': float(f1[i]) for i in range(len(f1))})
    return res

def collapse_sequence(x):
    """Collapse a frame-label sequence into its ordered list of segment labels.

    Entries equal to -1 are dropped, and consecutive repeats of a label are
    merged. For example, [0, 0, -1, 0, 2, 2] becomes [0, 2]. Labels are
    returned as Python ints.
    """
    segments = []
    for label in x:
        if label != -1 and (not segments or segments[-1] != label):
            segments.append(int(label))
    return segments

def levenshtein(a, b):
    """Edit distance between sequences ``a`` and ``b``.

    Insertions, deletions and substitutions each cost 1. The full
    dynamic-programming table is built with NumPy.
    """
    rows, cols = len(a) + 1, len(b) + 1
    table = np.zeros((rows, cols), dtype=int)
    table[:, 0] = np.arange(rows)
    table[0, :] = np.arange(cols)
    for i, a_item in enumerate(a, start=1):
        for j, b_item in enumerate(b, start=1):
            table[i, j] = min(
                table[i - 1, j] + 1,                                 # deletion
                table[i, j - 1] + 1,                                 # insertion
                table[i - 1, j - 1] + (0 if a_item == b_item else 1),  # substitution
            )
    return table[-1, -1]

def sequence_edit_score(gt, pred):
    """Segmental edit score in [0, 1].

    The score is 1 minus the Levenshtein distance between the collapsed gt and
    pred segment sequences, divided by the longer of the two. Returns 0.0 when
    ``gt`` has no labelled frame.
    """
    gt_segments = collapse_sequence(gt)
    pred_segments = collapse_sequence(pred)
    if not gt_segments:
        return 0.0
    longest = max(len(gt_segments), len(pred_segments), 1)
    return float(1.0 - levenshtein(gt_segments, pred_segments) / longest)

print('Metrics ready')

In [None]:
# Training and validation functions
import torch.optim as optim

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (including all CUDA devices if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is expected to return (B, T, C) logits for (B, T, D) features.
    Frames labelled -1 are ignored by the loss. Batches in which every frame is
    -1 are skipped, because CrossEntropyLoss with mean reduction would return
    NaN for them. Returns NaN when no batch contributed (e.g. an empty loader).
    """
    model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    total_loss, n_batches = 0.0, 0
    for batch in tqdm(loader, desc='train'):
        if not bool((batch['labels'] != -1).any()):
            continue
        feats = batch['feats'].to(device)
        labels = batch['labels'].to(device)
        logits = model(feats)
        num_classes = logits.shape[-1]
        # reshape, not view: the model output is a permuted (non-contiguous)
        # tensor, on which .view(B*T, C) fails for batch sizes > 1.
        loss = criterion(logits.reshape(-1, num_classes), labels.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
        n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')

def validate(model, loader, device, num_classes):
    """Compute frame-level macro-F1 and accuracy over every clip in ``loader``.

    Frames labelled -1 are ignored. Macro-F1 uses scikit-learn's default class
    set (classes occurring in the labels or the predictions), so
    ``num_classes`` is currently unused. It is kept for interface compatibility.

    Returns:
        dict with ``'f1_macro'`` and ``'acc'`` as floats. Both are NaN when the
        loader yields no labelled frame (e.g. an empty validation split).
    """
    model.eval()
    preds_all, labels_all = [], []
    with torch.no_grad():
        for batch in loader:
            labels_all.append(batch['labels'].numpy().reshape(-1))
            logits = model(batch['feats'].to(device))
            preds_all.append(logits.argmax(dim=-1).cpu().numpy().reshape(-1))
    nan_result = {'f1_macro': float('nan'), 'acc': float('nan')}
    if not labels_all:
        return nan_result
    preds = np.concatenate(preds_all)
    labels = np.concatenate(labels_all)
    mask = labels != -1
    if not mask.any():
        return nan_result
    labels, preds = labels[mask], preds[mask]
    from sklearn.metrics import f1_score
    f1_macro = f1_score(labels, preds, average='macro', zero_division=0)
    acc = (labels == preds).mean()
    return {'f1_macro': float(f1_macro), 'acc': float(acc)}

print('Train/val functions defined')

In [None]:
# Training runner (SMOKE-RUN)
# Edit cfg above before running. This cell will run training for cfg['training']['epochs'] epochs.
set_seed(cfg['training']['seed'])
out_dir = cfg['io']['out_dir']
os.makedirs(out_dir, exist_ok=True)
# Save the exact config next to the checkpoints so the run can be reproduced.
with open(os.path.join(out_dir, 'config.yaml'), 'w', encoding='utf-8') as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
device = torch.device(cfg['training']['device'] if torch.cuda.is_available() else 'cpu')

# Prepare datasets
train_ds = ClipDataset(cfg, split='train')
val_ds = ClipDataset(cfg, split='val')
print(f'Found {len(train_ds.videos)} videos (sampled), {len(train_ds)} training clips.')
print(f'Validation: {len(val_ds.videos)} videos, {len(val_ds)} clips.')
if len(train_ds) == 0:
    raise RuntimeError('No training clips found: check dataset paths, feature files and seq_len.')

num_classes = cfg['model']['num_classes']
feat_dim = cfg['model']['feat_dim']

# Fail early with a readable message instead of an out-of-bounds target error
# (or a CUDA device-side assert) inside CrossEntropyLoss.
max_label = max(int(lbl.max()) for lbl in train_ds.video_labels.values())
if max_label >= num_classes:
    raise ValueError(
        f"Found phase_id {max_label} but cfg['model']['num_classes'] = {num_classes}; "
        f"set num_classes to at least {max_label + 1}."
    )

train_loader = DataLoader(train_ds, batch_size=cfg['training']['batch_size'], shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=cfg['training']['batch_size'], shuffle=False, collate_fn=collate_fn)

model = SimpleMS_TCN(feat_dim=feat_dim, hidden_dim=cfg['model']['hidden_dim'],
                     num_layers=cfg['model']['num_layers'], num_classes=num_classes).to(device)
optimizer = optim.AdamW(model.parameters(), lr=cfg['training']['lr'], weight_decay=cfg['training']['weight_decay'])

epochs = cfg['training']['epochs']
ckpt_freq = cfg['io'].get('checkpoint_freq', 1)
if len(val_ds) == 0:
    print('Warning: no validation clips; val metrics will be NaN and best.pth will not be written.')

best_val = -1.0
for epoch in range(1, epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    if len(val_ds) > 0:
        metrics = validate(model, val_loader, device, num_classes)
    else:
        metrics = {'f1_macro': float('nan'), 'acc': float('nan')}
    print(f"Epoch {epoch} train_loss={train_loss:.4f} val_f1={metrics['f1_macro']:.4f} val_acc={metrics['acc']:.4f}")
    # Honour io.checkpoint_freq (0/None disables periodic saves); always keep the last epoch.
    if (ckpt_freq and epoch % ckpt_freq == 0) or epoch == epochs:
        ckpt_path = os.path.join(out_dir, f"ckpt_epoch{epoch}.pth")
        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optimizer': optimizer.state_dict()}, ckpt_path)
    # NaN never compares greater, so an epoch without valid val frames is never "best".
    if metrics['f1_macro'] > best_val:
        best_val = metrics['f1_macro']
        torch.save({'epoch': epoch, 'model_state': model.state_dict()}, os.path.join(out_dir, 'best.pth'))

print('Training complete (smoke-run). Check', cfg['io']['out_dir'])

In [None]:
# Single-video inference example (uses best.pth if available)
def load_model_from_ckpt(cfg, device, ckpt_path=None):
    """Build a SimpleMS_TCN from ``cfg['model']`` and load weights into it.

    ``ckpt_path`` defaults to ``{cfg['io']['out_dir']}/best.pth``. If that file
    does not exist, a message is printed and the freshly initialised model is
    returned. The model is placed on ``device`` and switched to eval mode.
    """
    model_cfg = cfg['model']
    model = SimpleMS_TCN(
        feat_dim=model_cfg['feat_dim'],
        hidden_dim=model_cfg['hidden_dim'],
        num_layers=model_cfg['num_layers'],
        num_classes=model_cfg['num_classes'],
    ).to(device)
    path = ckpt_path if ckpt_path is not None else os.path.join(cfg['io']['out_dir'], 'best.pth')
    if not os.path.exists(path):
        print('No checkpoint found at', path)
    else:
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        print('Loaded checkpoint', path)
    model.eval()
    return model

def inference_full_video(model, features, device, seq_len=200, overlap=0.5):
    """Predict a phase for every frame of a whole video using sliding windows.

    Windows of ``seq_len`` frames start every ``seq_len * (1 - overlap)``
    frames, clamped to the range [1, seq_len]. If the regular grid stops short
    of the last frame, an extra window aligned to the end of the video is
    added, so every frame is covered. Logits of overlapping windows are
    averaged before the argmax. A video shorter than ``seq_len`` is processed
    as a single window of its own length.

    Args:
        model: module mapping (1, t, D) float features to (1, t, C) logits and
            exposing ``classifier.out_channels`` (= C), like SimpleMS_TCN.
        features: (T, D) array of per-frame features.
        device: device the model lives on.
        seq_len: window length in frames.
        overlap: fraction of overlap between consecutive windows, in [0, 1).

    Returns:
        (T,) integer array of predicted class ids (empty when T == 0).
    """
    model.eval()
    T = features.shape[0]
    num_classes = model.classifier.out_channels
    if T == 0:
        return np.zeros((0,), dtype=np.int64)
    step = min(max(1, int(seq_len * (1 - overlap))), seq_len)  # never leave gaps between windows
    starts = list(range(0, max(T - seq_len, 0) + 1, step))
    if starts[-1] + seq_len < T:
        starts.append(T - seq_len)  # cover the tail the regular grid misses
    counts = np.zeros((T,), dtype=np.int32)
    logits_sum = np.zeros((T, num_classes), dtype=np.float32)
    with torch.no_grad():
        for s in starts:
            clip = torch.from_numpy(features[s:s + seq_len]).unsqueeze(0).to(device).float()
            out = model(clip).cpu().numpy()[0]
            logits_sum[s:s + seq_len] += out
            counts[s:s + seq_len] += 1
    logits_avg = logits_sum / counts[:, None]  # every frame has count >= 1
    preds = logits_avg.argmax(axis=-1)
    return preds

# Example usage:
# model = load_model_from_ckpt(cfg, device)
# feats = np.load('cataract-101/features/<video_id>.npy')
# preds = inference_full_video(model, feats, device, seq_len=cfg['data']['seq_len'])
# print('Pred shape', preds.shape)
print('Inference cell ready')

## Notes & next steps

- Edit the config cell at the top to match your dataset paths, feature dim, and number of classes.
- If you don't have pre-extracted features yet, add a feature-extraction step first (e.g. a pretrained ResNet applied to the video frames) that writes one `{video_id}.npy` array (one row per frame) per video into `cataract-101/features/`.
- To move towards a full TeCNO-style model, extend this notebook with multi-stage refinement and temporal-smoothing post-processing.

When ready, run the training cell to perform a smoke training run (3 epochs by default).