# Final Evaluation

In [None]:
from importlib import reload

import surgical_phase_tool.evaluate as eval_module

# Re-execute the evaluation module so on-disk edits are picked up by this
# kernel, then invoke its evaluate_model() entry point.
reload(eval_module)
run_evaluation = eval_module.evaluate_model
run_evaluation()

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from surgical_phase_tool.config import (
    TEST_MANIFEST, BATCH_SIZE, NUM_WORKERS, DEVICE, CHECKPOINT_PATH, SEED
)
from surgical_phase_tool.config_loader import set_global_seed
from surgical_phase_tool.dataset import MultiTaskWindowDataset, PHASE_TO_ID, TOOL_COLUMNS, ID_TO_PHASE
from surgical_phase_tool.models.resnet_multitask import PhaseToolNet
from surgical_phase_tool.hierarchy.phase_tool_mask import build_phase_tool_mask, apply_phase_mask_to_logits
from surgical_phase_tool.metrics import phase_metrics, tool_metrics

# Apply the configured seed before any other object is constructed.
set_global_seed(SEED)

# Test split, built with is_train=False and iterated in a fixed order
# (shuffle=False) so predictions line up with the manifest order.
test_dataset = MultiTaskWindowDataset(TEST_MANIFEST, is_train=False)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

# Instantiate the network without pretrained weights; the weights come from
# the checkpoint loaded just below.
model = PhaseToolNet(backbone_name='resnet18', pretrained=False)
model = model.to(DEVICE)

# NOTE(review): torch.load unpickles the checkpoint file, which can execute
# arbitrary code. Only load trusted checkpoints; consider weights_only=True
# if the installed torch version supports it.
state = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(state)
model.eval()  # inference mode for layers such as dropout / batch-norm

# Phase/tool mask, placed on the same device as the model.
phase_tool_mask = build_phase_tool_mask()
phase_tool_mask = phase_tool_mask.to(DEVICE)

# Run the whole test loader through the model once, accumulating raw logits
# and targets on the CPU so metrics can be computed over the full split.
all_phase_logits, all_phase_targets = [], []
all_tool_logits, all_tool_targets = [], []
with torch.no_grad():
    for frames, phase_target, tool_target in test_loader:
        # Only the inputs have to be on DEVICE. The targets are never used in
        # a device-side computation here, so they are not copied to DEVICE
        # and straight back (the original did a wasted round trip per batch).
        frames = frames.to(DEVICE)
        phase_logits, tool_logits = model(frames)
        all_phase_logits.append(phase_logits.cpu())
        all_phase_targets.append(phase_target.cpu())
        all_tool_logits.append(tool_logits.cpu())
        all_tool_targets.append(tool_target.cpu())

# Collapse the per-batch lists into one tensor each along the batch axis.
all_phase_logits = torch.cat(all_phase_logits, dim=0)
all_phase_targets = torch.cat(all_phase_targets, dim=0)
all_tool_logits = torch.cat(all_tool_logits, dim=0)
all_tool_targets = torch.cat(all_tool_targets, dim=0)

# Phase head: phase_metrics returns three values, unpacked here as overall
# accuracy, per-class accuracy and `cm`.
phase_overall_acc, phase_per_class_acc, cm = phase_metrics(
    all_phase_logits,
    all_phase_targets,
)

# Tool head, scored directly on the raw logits.
tool_stats_no_mask = tool_metrics(all_tool_logits, all_tool_targets, threshold=0.5)

# Tool head again, after adjusting the logits with the phase/tool mask
# weighted by the predicted phase probabilities (hard=False).
phase_probs = all_phase_logits.softmax(dim=-1)
cpu_mask = phase_tool_mask.to(all_tool_logits.device)
masked_logits = apply_phase_mask_to_logits(all_tool_logits, phase_probs, cpu_mask, hard=False)
tool_stats_with_mask = tool_metrics(masked_logits, all_tool_targets, threshold=0.5)

# Headline numbers, shown as the cell's output.
(
    phase_overall_acc,
    tool_stats_no_mask['f1_micro'],
    tool_stats_with_mask['f1_micro'],
)

In [None]:
# Order phase names by their integer id instead of relying on the dict's
# insertion order, so the label at tick k is always the phase with id k.
# NOTE(review): assumes ids are contiguous 0..N-1 and that row/column k of
# `cm` corresponds to phase id k — confirm against phase_metrics.
phase_items = sorted(PHASE_TO_ID.items(), key=lambda item: item[1])
phases = [name for name, _ in phase_items]

print('Overall phase accuracy:', phase_overall_acc)
print()
print('Per-class phase accuracy:')
for p, idx in phase_items:
    print(f'{idx} - {p}: {phase_per_class_acc[idx]:.4f}')

# Confusion matrix heat map: true phase on the y axis, predicted on the x axis.
plt.figure(figsize=(5, 4))
plt.imshow(cm, interpolation='nearest', cmap='Blues')
plt.title('Phase confusion matrix')
plt.colorbar()
tick_marks = np.arange(len(phases))
plt.xticks(tick_marks, phases, rotation=45, ha='right')
plt.yticks(tick_marks, phases)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.tight_layout()
plt.show()

In [None]:
# Aggregate tool micro-F1 with and without the phase/tool hierarchy mask.
print('Tool micro-F1 without hierarchy:', tool_stats_no_mask['f1_micro'])
print('Tool micro-F1 with hierarchy   :', tool_stats_with_mask['f1_micro'])

# Per-tool comparison. The arrow in the header was mojibake, and the per-tool
# print statement was cut off mid-call, which made the whole cell a
# SyntaxError. The line format below is reconstructed from the header.
print('Per-tool F1 (no mask → with mask):')
for i, tool in enumerate(TOOL_COLUMNS):
    f1_no = tool_stats_no_mask['f1_per'][i]
    f1_with = tool_stats_with_mask['f1_per'][i]
    print(f'  {tool}: {f1_no:.3f} → {f1_with:.3f}')

# Two-bar comparison of tool micro-F1 without and with the hierarchy mask.
labels = ['no hierarchy', 'with hierarchy']
values = [tool_stats_no_mask['f1_micro'], tool_stats_with_mask['f1_micro']]

fig, ax = plt.subplots(figsize=(4, 4))
ax.bar(labels, values, color=['grey', 'green'])
ax.set_ylim(0.0, 1.0)
ax.set_ylabel('Tool micro-F1')

# Write each score just above its bar.
for bar_index, score in enumerate(values):
    ax.text(bar_index, score + 0.02, f'{score:.3f}', ha='center')

fig.tight_layout()
plt.show()

In [None]:
# Qualitative check: for the first few test samples, print the true vs.
# predicted phase and the five highest-scoring tools before and after the
# phase/tool mask is applied.
num_examples = 3
examples = []
# Clamp to the dataset size so a small test split does not raise IndexError.
for i in range(min(num_examples, len(test_dataset))):
    frames, phase_target, tool_target = test_dataset[i]
    # Add a leading batch dimension of 1 before feeding the model.
    examples.append((frames.unsqueeze(0), phase_target, tool_target))

for idx, (frames, phase_target, tool_target) in enumerate(examples):
    with torch.no_grad():
        phase_logits, tool_logits = model(frames.to(DEVICE))
        phase_probs = phase_logits.softmax(dim=-1)
        tool_probs = tool_logits.sigmoid()
        masked_logits_ex = apply_phase_mask_to_logits(tool_logits, phase_probs, phase_tool_mask, hard=False)
        masked_probs = masked_logits_ex.sigmoid()

    # NOTE(review): the label format is not visible here. A one-hot / score
    # vector is reduced with argmax (the original behavior). A single-element
    # label is treated as the class index itself, because argmax over one
    # element is always 0 and would report the wrong phase. Confirm against
    # MultiTaskWindowDataset.
    target_tensor = torch.as_tensor(phase_target)
    if target_tensor.numel() == 1:
        true_phase_id = int(target_tensor.item())
    else:
        true_phase_id = target_tensor.argmax().item()
    pred_phase_id = phase_probs.argmax(dim=-1).item()
    print(f'Example {idx}: true={ID_TO_PHASE[true_phase_id]}, pred={ID_TO_PHASE[pred_phase_id]}')

    probs = tool_probs.squeeze(0).cpu().numpy()
    mprobs = masked_probs.squeeze(0).cpu().numpy()
    # Same top-5 listing for both score vectors, highest probability first.
    for heading, scores in (('  Top tools (unmasked):', probs), ('  Top tools (masked):', mprobs)):
        print(heading)
        for i in scores.argsort()[::-1][:5]:
            print(f'    {TOOL_COLUMNS[i]}: {scores[i]:.3f}')
    print()