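# Training and Validation (Illustrative)

Short, illustrative training run: `PhaseToolNet` (ResNet-18 backbone) is trained for a few epochs on a random ~90/10 train/validation split of `TRAIN_MANIFEST`. Each epoch records train/validation loss, validation phase accuracy, and validation tool micro-F1 with and without the phase→tool mask. The final cells plot these curves and print predictions for a few validation examples.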

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, random_split
import matplotlib.pyplot as plt

from surgical_phase_tool.config import (
    TRAIN_MANIFEST, BATCH_SIZE, NUM_WORKERS, DEVICE, LEARNING_RATE, WEIGHT_DECAY,
    PHASE_LOSS_WEIGHT, TOOL_LOSS_WEIGHT, SEED
)
from surgical_phase_tool.config_loader import set_global_seed
from surgical_phase_tool.dataset import MultiTaskWindowDataset, PHASE_TO_ID, TOOL_COLUMNS, ID_TO_PHASE
from surgical_phase_tool.models.resnet_multitask import PhaseToolNet
from surgical_phase_tool.hierarchy.phase_tool_mask import build_phase_tool_mask, apply_phase_mask_to_logits
from surgical_phase_tool.metrics import phase_metrics, tool_metrics

set_global_seed(SEED)

# --- Train / validation split ---------------------------------------------------
VAL_FRACTION = 0.1  # share of TRAIN_MANIFEST windows held out for validation

full_train_dataset = MultiTaskWindowDataset(TRAIN_MANIFEST, is_train=True)
val_size = max(1, int(VAL_FRACTION * len(full_train_dataset)))
train_size = len(full_train_dataset) - val_size
# Dedicated generator: the split depends only on SEED, not on how much of the global
# torch RNG was consumed before this line (e.g. during dataset construction).
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_split = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

# Validation should not go through training-time augmentation. Build a second view of
# the same manifest with is_train=False and select the held-out indices from it.
# NOTE(review): assumes `is_train` only toggles augmentation and both views enumerate
# the manifest in the same order — confirm in surgical_phase_tool.dataset.
eval_view_dataset = MultiTaskWindowDataset(TRAIN_MANIFEST, is_train=False)
assert len(eval_view_dataset) == len(full_train_dataset), 'train/eval views of TRAIN_MANIFEST differ in length'
val_dataset = Subset(eval_view_dataset, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False)

model = PhaseToolNet(backbone_name='resnet18', pretrained=True).to(DEVICE)

# --- Inverse-frequency phase class weights ----------------------------------------
# Counted on the *training* split only, so the held-out windows do not influence the loss.
num_phases = len(PHASE_TO_ID)
phase_counts = np.zeros(num_phases, dtype=np.int64)
for _, phase_target, _ in DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=NUM_WORKERS):
    # argmax over the last dim gives the phase id (assumes one-hot / soft targets — TODO confirm)
    phase_counts += np.bincount(phase_target.argmax(dim=-1).numpy(), minlength=num_phases)
phase_counts = np.maximum(phase_counts, 1)  # phases absent from the split would otherwise divide by zero
inv_freq = 1.0 / phase_counts
weights = inv_freq / inv_freq.mean()  # normalise so the mean class weight is 1
phase_class_weights = torch.tensor(weights, dtype=torch.float32, device=DEVICE)

phase_criterion = nn.CrossEntropyLoss(weight=phase_class_weights)
tool_criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

phase_tool_mask = build_phase_tool_mask().to(DEVICE)

# Per-epoch metrics filled in by the training cell.
history = {
    'epoch': [],
    'train_loss': [],
    'val_loss': [],
    'val_phase_acc': [],
    'val_tool_micro_no_mask': [],
    'val_tool_micro_with_mask': [],
}
short_epochs = 3  # deliberately short: this notebook is illustrative, not a full training run
len(train_dataset), len(val_dataset), short_epochs

In [None]:
def _multitask_loss(phase_logits, tool_logits, phase_target, tool_target):
    """Combined phase + tool loss, shared by the training and validation passes.

    Keeping a single definition guarantees the reported validation loss is computed
    exactly like the optimised training loss.

    Args:
        phase_logits: phase output of the model, scored with ``phase_criterion``
            (class-weighted cross-entropy).
        tool_logits: raw tool logits from the model.
        phase_target: phase target; ``argmax(dim=-1)`` is used as the class id
            (assumes one-hot / soft labels — TODO confirm).
        tool_target: tool target scored with ``tool_criterion`` (BCE with logits).

    Returns:
        Scalar tensor ``PHASE_LOSS_WEIGHT * phase_loss + TOOL_LOSS_WEIGHT * tool_loss``.
    """
    gt_phase_ids = phase_target.argmax(dim=-1)
    phase_loss = phase_criterion(phase_logits, gt_phase_ids)
    # For the loss, tool logits are masked with the *ground-truth* phase (hard=True),
    # not with the model's predicted phase.
    gt_phase_one_hot = torch.nn.functional.one_hot(gt_phase_ids, num_classes=len(PHASE_TO_ID)).float()
    masked_tool_logits = apply_phase_mask_to_logits(tool_logits, gt_phase_one_hot, phase_tool_mask, hard=True)
    tool_loss = tool_criterion(masked_tool_logits, tool_target)
    return PHASE_LOSS_WEIGHT * phase_loss + TOOL_LOSS_WEIGHT * tool_loss


# Validation metrics are computed on concatenated CPU tensors, so keep a CPU copy of the mask
# (computed once instead of once per epoch).
phase_tool_mask_cpu = phase_tool_mask.cpu()

# NOTE: `model`, `optimizer` and `history` come from the setup cell and are not reset here —
# re-running this cell continues training and appends more rows to `history`.
for epoch in range(1, short_epochs + 1):
    # --- train -------------------------------------------------------------------
    model.train()
    running_loss = 0.0
    for frames, phase_target, tool_target in train_loader:
        frames = frames.to(DEVICE)
        phase_target = phase_target.to(DEVICE)
        tool_target = tool_target.to(DEVICE)
        optimizer.zero_grad()
        phase_logits, tool_logits = model(frames)
        loss = _multitask_loss(phase_logits, tool_logits, phase_target, tool_target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * frames.size(0)  # batch-size weighted, averaged per sample below
    train_loss = running_loss / len(train_loader.dataset)

    # --- validate ----------------------------------------------------------------
    model.eval()
    all_phase_logits, all_phase_targets = [], []
    all_tool_logits, all_tool_targets = [], []
    total_val_loss = 0.0
    with torch.no_grad():
        for frames, phase_target, tool_target in val_loader:
            frames = frames.to(DEVICE)
            phase_target = phase_target.to(DEVICE)
            tool_target = tool_target.to(DEVICE)
            phase_logits, tool_logits = model(frames)
            loss = _multitask_loss(phase_logits, tool_logits, phase_target, tool_target)
            total_val_loss += loss.item() * frames.size(0)
            all_phase_logits.append(phase_logits.cpu())
            all_phase_targets.append(phase_target.cpu())
            all_tool_logits.append(tool_logits.cpu())
            all_tool_targets.append(tool_target.cpu())
    val_loss = total_val_loss / len(val_loader.dataset)

    all_phase_logits = torch.cat(all_phase_logits, dim=0)
    all_phase_targets = torch.cat(all_phase_targets, dim=0)
    all_tool_logits = torch.cat(all_tool_logits, dim=0)
    all_tool_targets = torch.cat(all_tool_targets, dim=0)
    phase_acc, _, _ = phase_metrics(all_phase_logits, all_phase_targets)
    # Tool micro-F1 is reported twice: on raw logits, and on logits masked with the
    # *predicted* phase probabilities (hard=False).
    tool_no_mask = tool_metrics(all_tool_logits, all_tool_targets, threshold=0.5)
    phase_probs = all_phase_logits.softmax(dim=-1)
    masked_logits_pred = apply_phase_mask_to_logits(all_tool_logits, phase_probs, phase_tool_mask_cpu, hard=False)
    tool_with_mask = tool_metrics(masked_logits_pred, all_tool_targets, threshold=0.5)

    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_phase_acc'].append(phase_acc)
    history['val_tool_micro_no_mask'].append(tool_no_mask['f1_micro'])
    history['val_tool_micro_with_mask'].append(tool_with_mask['f1_micro'])

history

In [None]:
# Three panels, one per tracked quantity, built with the explicit Axes API.
epochs = history['epoch']
fig, (ax_loss, ax_acc, ax_f1) = plt.subplots(1, 3, figsize=(10, 3))

# Panel 1: training vs. validation loss.
ax_loss.plot(epochs, history['train_loss'], label='train')
ax_loss.plot(epochs, history['val_loss'], label='val')
ax_loss.set(xlabel='epoch', ylabel='loss', title='Loss')
ax_loss.legend()

# Panel 2: validation phase accuracy.
ax_acc.plot(epochs, history['val_phase_acc'], marker='o')
ax_acc.set(xlabel='epoch', ylabel='accuracy', title='Phase accuracy (val)')

# Panel 3: validation tool micro-F1 without / with the phase->tool mask.
for history_key, curve_label in (
    ('val_tool_micro_no_mask', 'no mask'),
    ('val_tool_micro_with_mask', 'with mask'),
):
    ax_f1.plot(epochs, history[history_key], label=curve_label, marker='o')
ax_f1.set(xlabel='epoch', ylabel='micro-F1', title='Tool micro-F1 (val)')
ax_f1.legend()

fig.tight_layout()
plt.show()

In [None]:
# Qualitative check: for a few validation windows, compare the predicted phase with the
# ground truth and list the top tool probabilities before and after applying the
# phase->tool mask (using the predicted phase probabilities, hard=False).
num_examples = min(3, len(val_dataset))  # the val split can hold as few as one window
top_k = 5


def _print_top_tools(title, probs, k=top_k):
    """Print the ``k`` highest-probability tools from a 1-D array of per-tool probabilities."""
    print(f'  {title}:')
    for tool_idx in probs.argsort()[::-1][:k]:
        print(f'    {TOOL_COLUMNS[tool_idx]}: {probs[tool_idx]:.3f}')


examples = []
for i in range(num_examples):
    frames, phase_target, tool_target = val_dataset[i]
    examples.append((frames.unsqueeze(0), phase_target, tool_target))  # add a batch dimension

# Set inference mode explicitly instead of relying on the state the training cell left behind.
model.eval()
for idx, (frames, phase_target, tool_target) in enumerate(examples):
    with torch.no_grad():
        phase_logits, tool_logits = model(frames.to(DEVICE))
        phase_probs = phase_logits.softmax(dim=-1)
        tool_probs = tool_logits.sigmoid()
        masked_logits_ex = apply_phase_mask_to_logits(tool_logits, phase_probs, phase_tool_mask, hard=False)
        masked_probs = masked_logits_ex.sigmoid()
    true_phase_id = phase_target.argmax().item()
    pred_phase_id = phase_probs.argmax(dim=-1).item()
    print(f'Example {idx}: true={ID_TO_PHASE[true_phase_id]}, pred={ID_TO_PHASE[pred_phase_id]}')
    _print_top_tools('Top tools (unmasked)', tool_probs.squeeze(0).cpu().numpy())
    _print_top_tools('Top tools (masked)', masked_probs.squeeze(0).cpu().numpy())
    print()