# Model and Hierarchy

In [None]:
from surgical_phase_tool.config_loader import load_config
from surgical_phase_tool.models.resnet_multitask import PhaseToolNet

# Build the phase/tool multitask network using the backbone named in the
# project configuration. pretrained=False presumably avoids loading
# pretrained backbone weights (TODO confirm against PhaseToolNet); here the
# model is only instantiated for inspection.
cfg = load_config()
backbone_name = cfg['model']['backbone']
model = PhaseToolNet(backbone_name=backbone_name, pretrained=False)

# Bare trailing expression so Jupyter renders the module's repr as output.
model

In [None]:
import matplotlib.pyplot as plt

from surgical_phase_tool.hierarchy.phase_tool_mask import build_phase_tool_mask
from surgical_phase_tool.dataset import PHASE_TO_ID, TOOL_COLUMNS

# Phase→tool validity mask as a NumPy array for plotting; rows are treated as
# phases and columns as tools. .detach().cpu() keeps the conversion working
# even if the builder returns a tensor on a GPU or one that tracks gradients.
mask = build_phase_tool_mask().detach().cpu().numpy()

# Order phase labels by their id rather than by dict insertion order, so that
# heatmap row i is labelled with the phase whose id is i (assumes ids are
# 0..P-1 — the shape check below catches a count mismatch).
phases = sorted(PHASE_TO_ID, key=PHASE_TO_ID.__getitem__)
tools = list(TOOL_COLUMNS)

# Fail loudly instead of silently attaching the wrong tick labels.
if mask.shape != (len(phases), len(tools)):
    raise ValueError(
        f"mask shape {mask.shape} does not match "
        f"{len(phases)} phases x {len(tools)} tools"
    )

fig, ax = plt.subplots(figsize=(8, 4))
# Pin the colour scale to [0, 1] so an all-valid or all-invalid mask is
# still drawn with consistent colours.
im = ax.imshow(mask, aspect='auto', cmap='Greys', vmin=0.0, vmax=1.0)
ax.set_yticks(range(len(phases)))
ax.set_yticklabels(phases)
ax.set_xticks(range(len(tools)))
ax.set_xticklabels(tools, rotation=45, ha='right')
ax.set_xlabel('Tools')
ax.set_ylabel('Phases')
ax.set_title('Phase→tool validity mask')
fig.colorbar(im, ax=ax, label='validity')
fig.tight_layout()
plt.show()

```python
# (P, K) phase→tool mask as a torch tensor with the same dtype/device as the
# logits (the cell above converts it to NumPy only for plotting).
mask = build_phase_tool_mask().to(tool_logits)

# Hard masking (training): the target phase selects one mask row. Assuming
# the mask is binary, valid tools keep their logits and invalid tools are
# pushed to a large negative value (sigmoid ≈ 0).
phase_ids = phase_target.argmax(dim=-1)            # (B,)
sample_mask = mask[phase_ids]                      # (B, K)
large_neg = -1e4  # still finite in fp16 (max ≈ 6.5e4)
masked_logits = tool_logits * sample_mask + large_neg * (1 - sample_mask)

# Soft masking (evaluation): the predicted phase is uncertain, so each tool's
# validity is a probability in [0, 1]. Applying it additively in logit space
# would suppress almost every tool (validity 0.99 still adds -100), so weight
# the tool probabilities by the validity instead.
phase_probs = phase_logits.softmax(dim=-1)         # (B, P)
sample_mask = phase_probs @ mask                   # (B, K)
probs_with_hierarchy = tool_logits.sigmoid() * sample_mask
```