In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
from tqdm import tqdm
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Get repository root directory
# Try multiple methods to find the repo root
cwd = Path.cwd()
if (cwd / 'data').exists():
    REPO_ROOT = cwd
elif (cwd.parent / 'data').exists():
    REPO_ROOT = cwd.parent
else:
    # Fallback: assume we're in dev_notebooks and go up one level
    REPO_ROOT = cwd.parent

import imports

Using device: cuda


In [2]:
# Dynamic expet settings:

# 1. Train on all tools, test on all tools
# 2. Train on some tools, test on rest of tools
# 3. Mixture of tools seen/unseen

In [3]:
# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------
DATA_ROOT = REPO_ROOT / "data" / "openi"
LABELS_DIR = DATA_ROOT / "labels"
IMAGES_DIR = DATA_ROOT / "image"
PRED_DIR   = DATA_ROOT / "predictions"

# ------------------------------------------------------------
# Labels (tasks)
# ------------------------------------------------------------
label_names = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "Lung Lesion", "Fracture", "Lung Opacity", "Enlarged Cardiomediastinum"
]
num_tasks = len(label_names)

# ------------------------------------------------------------
# Tool registry
# ------------------------------------------------------------
registry_all = imports.scan_prediction_files(str(PRED_DIR))

# Example split: train on non-resnet tools
train_tools = [t for t in registry_all["train"]]

train_registry = {t: registry_all["train"][t] for t in train_tools}
val_registry   = {t: registry_all["val"][t]   for t in train_tools}
test_registry  = {t: registry_all["test"][t]  for t in train_tools}


# ------------------------------------------------------------
# Datasets
# ------------------------------------------------------------
train_dataset_full = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Train.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=train_registry,
    label_names=label_names,
    transform=None,  # assume tensor conversion inside dataset
)

val_dataset = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Valid.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=val_registry,
    label_names=label_names,
    transform=None,
)

te_dataset = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Test.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=test_registry,
    label_names=label_names,
    transform=None,
)


In [4]:
ctx_mgr = imports.ContextManager(
    dataset=train_dataset_full,
    context_fraction=0.1,      # 10% context
    examples_per_tool=32,      # B_t
)

train_dataset = ctx_mgr.routing_dataset()


In [5]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

test_loader = DataLoader(
    te_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)


In [None]:
model = imports.DySTANceRouter(
    num_tasks=num_tasks,
    vocab_size=1000,   # dummy vocab size for now
    hidden_dim=256,
).to(device)

# criterion = imports.DySTANceLoss(
#     surrogate_type="logistic",
#     lambda_entropy=0.05,
# )

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-5,,
    weight_decay=1e-4,
)


def build_context_tensors(ctx_mgr, task_idx, model, device):
    """
    Builds task-conditional context tensors for all tools.

    Returns:
        ctx_img_feat : [M, C, Dx] on `device`
        ctx_gt       : [M, C]     on `device`
        ctx_pred     : [M, C]     on `device`
    """
    ctx_img_feats = []
    ctx_gts = []
    ctx_preds = []

    M = ctx_mgr.dataset.M # number of tools
    C = ctx_mgr.examples_per_tool # number of examples

    for tool_idx in range(M):
        ctx = ctx_mgr.sample_context(tool_idx, task_idx)

        if ctx is None:
            # No valid context for this tool-task pair
            ctx_img_feats.append(
                torch.zeros(C, model.img_dim, device=device)
            )
            ctx_gts.append(
                torch.zeros(C, device=device)
            )
            ctx_preds.append(
                torch.zeros(C, device=device)
            )
        else:
            imgs, gt, preds = ctx

            imgs = imgs.to(device)
            gt = gt.to(device)
            preds = preds.to(device)

            with torch.no_grad():
                feats = model.extract_img_feat(imgs)  # [C, Dx]

            ctx_img_feats.append(feats)
            ctx_gts.append(gt)
            ctx_preds.append(preds)

    return (
        torch.stack(ctx_img_feats, dim=0),  # [M, C, Dx]
        torch.stack(ctx_gts, dim=0),        # [M, C]
        torch.stack(ctx_preds, dim=0),      # [M, C]
    )

In [7]:
# def train_one_epoch(model, loader, ctx_mgr, optimizer, criterion):
#     model.train()

#     total_loss = 0.0
#     total_batches = 0

#     for batch in tqdm(loader, desc="Training"):
#         images = batch["image"].to(device)        # [B, 3, H, W]
#         gt_all = batch["gt"].to(device)           # [B, L]
#         preds_all = batch["tool_preds"].to(device)  # [B, M, L]
#         mask_all = batch["tool_mask"].to(device)    # [B, M, L]

#         B = images.size(0)

#         # ------------------------------------------------------------
#         # 1) Sample a task uniformly
#         # ------------------------------------------------------------
#         task_idx = random.randint(0, num_tasks - 1)
#         task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)

#         # Task-conditional slices
#         gt = gt_all[:, task_idx]                 # [B]
#         tool_preds = preds_all[:, :, task_idx]   # [B, M]
#         tool_mask  = mask_all[:, :, task_idx]    # [B, M]

#         # ------------------------------------------------------------
#         # 2) Build context for this task
#         # ------------------------------------------------------------
#         ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(
#             ctx_mgr, task_idx, device
#         )

#         # ------------------------------------------------------------
#         # 3) Forward pass
#         # ------------------------------------------------------------
#         scores = model(
#             images=images,
#             text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text
#             task_idx=task_ids,
#             tool_preds=tool_preds,
#             ctx_img_feat=ctx_img_feat,
#             ctx_gt=ctx_gt,
#             ctx_pred=ctx_pred,
#             tool_mask=tool_mask,
#         )

#         # ------------------------------------------------------------
#         # 4) Compute costs (classification task)
#         # c_E = 1 - confidence on true label
#         # ------------------------------------------------------------
#         tool_costs = 1.0 - tool_preds  # [B, M]

#         # ------------------------------------------------------------
#         # 5) Loss + backward
#         # ------------------------------------------------------------
#         loss, logs = criterion(scores, tool_costs, tool_mask)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         total_batches += 1

#     return total_loss / max(1, total_batches)


In [8]:
from typing import Optional
import torch
from tqdm import tqdm

# helper for empty context
def _empty_context(M, model, device):
    Dx = model.img_dim if hasattr(model, "img_dim") else 512
    return (
        torch.zeros((M, 0, Dx), device=device),
        torch.zeros((M, 0), device=device),
        torch.zeros((M, 0), device=device),
    )

def train_one_epoch_all_tasks(
    model,
    loader,
    ctx_mgr,
    optimizer,
    criterion,  # expects (router_logits, tool_probs, gt, validity_mask) -> (loss, info)
    num_tasks: int,
    device: torch.device,
    task_weights: Optional[torch.Tensor] = None,
    resample_per_batch: bool = False,
):
    """
    Train one epoch: compute routing loss for EVERY task per batch, average across tasks,
    one optimizer step per batch.

    Assumptions:
      - batch["tool_preds"][:, :, t] is P(y=1) for each tool on task t (binary)
      - batch["gt"][:, t] in {0,1}
      - criterion implements Eq.(7)-style comp-sum surrogate consistently with those semantics

    Features:
      - Context tensors are cached per task within an epoch for efficiency and rebuilt each epoch to prevent cross-epoch leakage.
    """
    model.train()
    total_loss = 0.0
    total_batches = 0

    # Task weights
    if task_weights is None:
        task_weights = torch.ones((num_tasks,), dtype=torch.float32, device=device)
    else:
        task_weights = task_weights.to(device).float()

    # Cache contexts per task (huge speedup). If resample_per_batch=True we
    # will ignore this cache and rebuild contexts each batch.
    ctx_cache = {}

    for batch in tqdm(loader, desc="Training"):
        images    = batch["image"].to(device)       # [B,3,H,W]
        gt_all    = batch["gt"].to(device)          # [B,T]
        probs_all = batch["tool_preds"].to(device)  # [B,M,T]  (P(y=1))
        mask_all  = batch["tool_mask"].to(device)   # [B,M,T]  (0/1)

        B = images.size(0)
        M = probs_all.size(1)

        loss_sum_tasks = 0.0
        weight_sum = 0.0

        for t in range(num_tasks):
            # Task slices
            gt        = gt_all[:, t]           # [B]
            tool_probs = probs_all[:, :, t]    # [B,M]
            tool_mask  = mask_all[:, :, t]     # [B,M]

            # Keep only samples with at least one valid tool
            valid_counts = tool_mask.sum(dim=1)          # [B] -- in this batch, how many tools are valid?
            valid_samples = valid_counts > 0             # [B] bool -- are there any valid tools in this batch?
            if valid_samples.sum().item() == 0:          # skip if no valid tools in batch
                continue

            idxs = torch.nonzero(valid_samples, as_tuple=False).squeeze(1)
            images_v = images[idxs]
            gt_v = gt[idxs]
            tool_probs_v = tool_probs[idxs]
            tool_mask_v  = tool_mask[idxs]
            Bv = idxs.numel()

            # === CONTEXT EMBEDDING: reuse epoch cache OR resample per batch (streamlined) ===
            use_cache = (not resample_per_batch) and (t in ctx_cache)
            if use_cache:
                ctx_img_feat, ctx_gt, ctx_pred = ctx_cache[t]
            else:
                # build (either because we are resampling each batch, or first time this epoch)
                ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, model, device=device)

                # normalize missing context to zero-shaped placeholders
                if ctx_img_feat is None:
                    ctx_img_feat, ctx_gt, ctx_pred = _empty_context(M, model, device)
                else:
                    # make sure everything is on the right device (idempotent)
                    ctx_img_feat = ctx_img_feat.to(device)
                    ctx_gt = ctx_gt.to(device)
                    ctx_pred = ctx_pred.to(device)

                # only store in epoch cache when using epoch-caching mode
                if not resample_per_batch:
                    ctx_cache[t] = (ctx_img_feat, ctx_gt, ctx_pred)


            # Forward (router logits over tools)
            task_ids = torch.full((Bv,), t, device=device, dtype=torch.long)
            router_logits = model(
                images=images_v,
                text_tokens=torch.zeros((Bv, 1), dtype=torch.long, device=device),
                task_idx=task_ids,
                tool_preds=tool_probs_v,   # pass tool probs if model uses them as features
                ctx_img_feat=ctx_img_feat,
                ctx_gt=ctx_gt,
                ctx_pred=ctx_pred,
                tool_mask=tool_mask_v,
            )  # [Bv, M]

            # Compute loss (NO manual tool_costs here)
            loss_t, logs = criterion(
                router_logits=router_logits,
                tool_probs=tool_probs_v,
                gt=gt_v,
                validity_mask=tool_mask_v,
            )

            w_t = task_weights[t]
            loss_sum_tasks = loss_sum_tasks + (w_t * loss_t)
            weight_sum = weight_sum + w_t

        if weight_sum.item() == 0:
            continue

        loss_batch = loss_sum_tasks / weight_sum

        optimizer.zero_grad(set_to_none=True)
        loss_batch.backward()
        optimizer.step()

        total_loss += float(loss_batch.detach().cpu().item())
        total_batches += 1

    return total_loss / max(1, total_batches)


In [9]:
# @torch.no_grad()
# def evaluate(model, loader, ctx_mgr):
#     """
#     Evaluates DySTANce router on:
#       - average regret
#       - per-task accuracy (binary)

#     Task is sampled per batch (same as training).
#     """
#     model.eval()

#     total_regret = 0.0
#     total_samples = 0
#     correct_total = 0

#     for batch in tqdm(loader, desc="Validation"):
#         images = batch["image"].to(device)
#         gt_all = batch["gt"].to(device)                # [B, L]
#         preds_all = batch["tool_preds"].to(device)     # [B, M, L]
#         mask_all = batch["tool_mask"].to(device)       # [B, M, L]

#         B = images.size(0)

#         # ------------------------------------------------------------
#         # Sample a task (same protocol as training)
#         # ------------------------------------------------------------
#         task_idx = random.randint(0, num_tasks - 1)
#         task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)

#         gt = gt_all[:, task_idx]                       # [B]
#         tool_preds = preds_all[:, :, task_idx]         # [B, M]
#         tool_mask  = mask_all[:, :, task_idx]          # [B, M]

#         # ------------------------------------------------------------
#         # Build task-conditional context
#         # ------------------------------------------------------------
#         ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(
#             ctx_mgr, task_idx, device
#         )

#         # ------------------------------------------------------------
#         # Forward pass
#         # ------------------------------------------------------------
#         scores = model(
#             images,
#             torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text
#             task_ids,
#             tool_preds,
#             ctx_img_feat,
#             ctx_gt,
#             ctx_pred,
#             tool_mask,
#         )

#         # ------------------------------------------------------------
#         # Routing decision
#         # ------------------------------------------------------------
#         chosen = scores.argmax(dim=1)  # [B]

#         # ------------------------------------------------------------
#         # Regret computation
#         # ------------------------------------------------------------
#         costs = 1.0 - tool_preds       # [B, M]

#         chosen_cost = costs[torch.arange(B), chosen]
#         oracle_cost = costs.masked_fill(tool_mask == 0, 1e9).min(dim=1).values

#         total_regret += (chosen_cost - oracle_cost).sum().item()

#         # ------------------------------------------------------------
#         # Accuracy computation (binary)
#         # ------------------------------------------------------------
#         chosen_preds = tool_preds[torch.arange(B), chosen]  # [B]
#         chosen_labels = (chosen_preds >= 0.5).long()

#         correct_total += (chosen_labels == gt).sum().item()
#         total_samples += B

#     avg_regret = total_regret / max(1, total_samples)
#     accuracy = correct_total / max(1, total_samples)

#     return {
#         "avg_regret": avg_regret,
#         "accuracy": accuracy,
#         "num_samples": total_samples,
#     }


In [10]:
# @torch.no_grad()
# def evaluate_all_tasks(model, loader, ctx_mgr):
#     """
#     Evaluate DySTANce over EVERY task per batch (no random single-task sampling).

#     Returns a dict with:
#       - avg_regret      : average regret across all considered (sample,task) pairs
#       - accuracy        : overall accuracy (router chosen tool vs GT)
#       - per_task_acc    : list of per-task accuracies
#       - random_acc      : random-router baseline accuracy
#       - upper_bound_acc : oracle upper-ceiling accuracy
#       - num_pairs       : number of considered (sample,task) pairs
#     """
#     model.eval()

#     # Accumulators
#     total_regret = 0.0
#     total_correct = 0
#     total_pairs = 0

#     # Per-task accumulators
#     per_task_correct = [0 for _ in range(num_tasks)]
#     per_task_pairs = [0 for _ in range(num_tasks)]

#     # Baselines
#     random_correct = 0
#     upper_correct = 0

#     # Cache contexts per task to avoid repeated builds
#     ctx_cache = {}

#     for batch in tqdm(loader, desc="Validation"):
#         images = batch["image"].to(device)            # [B, C, H, W]
#         gt_all  = batch["gt"].to(device)             # [B, L]
#         preds_all = batch["tool_preds"].to(device)   # [B, M, L]
#         mask_all  = batch["tool_mask"].to(device)    # [B, M, L]

#         B = images.shape[0]

#         # For each task, evaluate all samples in batch (but skip samples with no valid tool)
#         for t in range(num_tasks):
#             # Task slices
#             gt = gt_all[:, t]                       # [B]
#             tool_preds = preds_all[:, :, t]         # [B, M]
#             tool_mask  = mask_all[:, :, t]          # [B, M] {0,1}

#             # Determine which samples in this batch have at least one valid tool
#             valid_tool_counts = tool_mask.sum(dim=1)        # [B]
#             valid_samples_mask = (valid_tool_counts > 0)   # [B] bool

#             if valid_samples_mask.sum().item() == 0:
#                 # no sample in this batch has any valid tool for this task -> skip
#                 continue

#             # Build / fetch cached context for task t (move to device)
#             if t in ctx_cache:
#                 ctx_img_feat, ctx_gt, ctx_pred = ctx_cache[t]
#             else:
#                 ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, device=device)
#                 # move to device (build_context_tensors may return CPU tensors)
#                 if ctx_img_feat is not None:
#                     ctx_img_feat = ctx_img_feat.to(device)
#                     ctx_gt = ctx_gt.to(device)
#                     ctx_pred = ctx_pred.to(device)
#                 else:
#                     # Create zero-sized contexts if none exist (ANP should handle B_t=0)
#                     M = preds_all.shape[1]
#                     Dx = model.img_dim if hasattr(model, "img_dim") else 512
#                     ctx_img_feat = torch.zeros((M, 0, Dx), device=device)
#                     ctx_gt       = torch.zeros((M, 0), device=device)
#                     ctx_pred     = torch.zeros((M, 0), device=device)

#                 ctx_cache[t] = (ctx_img_feat, ctx_gt, ctx_pred)

#             # Build task ids for the batch
#             task_ids = torch.full((B,), t, dtype=torch.long, device=device)

#             # Forward pass (model expected to handle masking internally)
#             scores = model(
#                 images=images,
#                 text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text (or real tokens if available)
#                 task_idx=task_ids,
#                 tool_preds=tool_preds,
#                 ctx_img_feat=ctx_img_feat,
#                 ctx_gt=ctx_gt,
#                 ctx_pred=ctx_pred,
#                 tool_mask=tool_mask,
#             )  # [B, M]

#             # Ensure invalid tools have very low score (safety)
#             scores = scores.masked_fill(tool_mask == 0, -1e9)

#             # Choose best tool per sample
#             chosen = scores.argmax(dim=1)  # [B] (may point to invalid tool if all -inf; but we filtered such samples)

#             # Consider only samples with at least one valid tool
#             idxs = torch.nonzero(valid_samples_mask, as_tuple=False).squeeze(1)  # [B_valid]
#             if idxs.numel() == 0:
#                 continue

#             # Subset arrays to valid samples
#             chosen_v = chosen[idxs]                          # [B_valid]
#             gt_v = gt[idxs]                                 # [B_valid]
#             preds_v = tool_preds[idxs]                      # [B_valid, M]
#             mask_v  = tool_mask[idxs]                       # [B_valid, M]

#             # Costs: c = 1 - confidence (soft proxy)
#             costs_v = 1.0 - preds_v                         # [B_valid, M]

#             # Chosen cost and oracle cost
#             chosen_cost = costs_v[torch.arange(idxs.numel(), device=device), chosen_v]  # [B_valid]

#             # Oracle: min cost among valid tools
#             inf_mask = (~(mask_v.bool())).float() * 1e9
#             costs_for_oracle = costs_v + inf_mask
#             oracle_cost = costs_for_oracle.min(dim=1).values  # [B_valid]

#             total_regret += (chosen_cost - oracle_cost).sum().item()

#             # Accuracy: interpret chosen tool's prediction as binary label (>=0.5)
#             chosen_pred_probs = preds_v[torch.arange(idxs.numel(), device=device), chosen_v]  # [B_valid]
#             chosen_labels = (chosen_pred_probs >= 0.5).long()
#             correct = (chosen_labels == gt_v).sum().item()
#             total_correct += int(correct)
#             total_pairs += int(idxs.numel())

#             # Per-task accounting
#             per_task_correct[t] += int(correct)
#             per_task_pairs[t] += int(idxs.numel())

#             # --- Baselines for these valid samples --------------------------------
#             # Random-router baseline: choose uniformly among valid tools for each sample
#             # We implement this in vectorized-ish manner:
#             Bv, M = preds_v.shape
#             # Create list of valid tool indices per sample
#             # We'll loop here across Bv (Bv is smallish per batch)
#             for i in range(Bv):
#                 valid_indices = torch.nonzero(mask_v[i].bool(), as_tuple=False).squeeze(1)
#                 if valid_indices.numel() == 0:
#                     # Shouldn't happen due to earlier filtering
#                     continue
#                 # Random pick
#                 rnd_idx = int(valid_indices[torch.randint(0, valid_indices.numel(), (1,)).item()].item())
#                 rnd_prob = preds_v[i, rnd_idx].item()
#                 rnd_label = 1 if rnd_prob >= 0.5 else 0
#                 if rnd_label == int(gt_v[i].item()):
#                     random_correct += 1

#                 # Upper-ceiling: any valid tool predicts correctly?
#                 # If any valid tool's binary prediction equals gt, count as correct
#                 # (treat prob>=0.5 as predicting label=1)
#                 valid_probs = preds_v[i, valid_indices]
#                 valid_preds_bin = (valid_probs >= 0.5).long()
#                 if (valid_preds_bin == int(gt_v[i].item())).any():
#                     upper_correct += 1
#             # ----------------------------------------------------------------------

#     # After all batches
#     avg_regret = total_regret / max(1, total_pairs)
#     accuracy = total_correct / max(1, total_pairs)
#     per_task_acc = [ (per_task_correct[t] / per_task_pairs[t]) if per_task_pairs[t] > 0 else None
#                      for t in range(num_tasks) ]

#     random_acc = random_correct / max(1, total_pairs)
#     upper_acc  = upper_correct / max(1, total_pairs)

#     return {
#         "avg_regret": avg_regret,
#         "accuracy": accuracy,
#         "per_task_acc": per_task_acc,
#         "random_acc": random_acc,
#         "upper_acc": upper_acc,
#         "num_pairs": total_pairs,
#     }


In [11]:
# @torch.no_grad()
# def compute_upper_ceiling_accuracy(
#     dataloader,
#     task_idx: int,
#     device="cpu",
# ):
#     """
#     Computes the upper-ceiling accuracy for a single task.

#     Upper-ceiling = fraction of samples for which at least one
#     valid tool predicts the correct label.

#     Args:
#         dataloader : DataLoader yielding OpenIRoutedDataset batches
#         task_idx   : int, which task (label) to evaluate
#         device     : torch device

#     Returns:
#         ceiling_acc : float in [0,1]
#         stats       : dict with additional diagnostics
#     """
#     total = 0
#     count_ceiling = 0
#     count_any_valid = 0

#     for batch in tqdm(dataloader, desc=f"Ceiling task {task_idx}"):
#         gt = batch["gt"][:, task_idx].to(device)            # [B]
#         preds = batch["tool_preds"][:, :, task_idx].to(device)  # [B, M]
#         mask = batch["tool_mask"][:, :, task_idx].to(device)    # [B, M]

#         B, M = preds.shape

#         # Predicted labels per tool (binary classification)
#         pred_labels = (preds >= 0.5).long()  # [B, M]

#         # Ground truth expanded
#         gt_exp = gt.unsqueeze(1).expand(-1, M)  # [B, M]

#         # Correct predictions per tool
#         correct = (pred_labels == gt_exp) & (mask.bool())  # [B, M]

#         # For each sample: does ANY tool get it right?
#         any_correct = correct.any(dim=1)  # [B]

#         # For sanity: does sample have ANY valid tool?
#         any_valid = mask.any(dim=1)       # [B]

#         count_ceiling += any_correct.sum().item()
#         count_any_valid += any_valid.sum().item()
#         total += B

#     ceiling_acc = count_ceiling / max(1, total)

#     stats = {
#         "total_samples": total,
#         "samples_with_any_valid_tool": count_any_valid,
#         "fraction_with_any_valid_tool": count_any_valid / max(1, total),
#     }

#     return ceiling_acc, stats

# for task_idx in range(num_tasks):
#     ceiling_acc, stats = compute_upper_ceiling_accuracy(
#         val_loader,
#         task_idx=task_idx,
#         device=device,
#     )

#     print(f"Upper-ceiling accuracy (task {task_idx}): {ceiling_acc:.4f}")
#     print(stats)


In [12]:
# @torch.no_grad()
# def compute_random_router_accuracy(
#     dataloader,
#     task_idx: int,
#     device="cpu",
#     seed: int = 0,
# ):
#     """
#     Computes accuracy of a random router for a single task.

#     For each sample:
#       - uniformly sample one VALID tool
#       - use its prediction as the output

#     Args:
#         dataloader : DataLoader yielding OpenIRoutedDataset batches
#         task_idx   : int, which task (label) to evaluate
#         device     : torch device
#         seed       : random seed for reproducibility

#     Returns:
#         rand_acc : float in [0,1]
#         stats    : dict
#     """
#     rng = torch.Generator(device=device)
#     rng.manual_seed(seed)

#     total = 0
#     correct_total = 0
#     samples_with_valid = 0

#     for batch in tqdm(dataloader, desc=f"Random router task {task_idx}"):
#         gt = batch["gt"][:, task_idx].to(device)             # [B]
#         preds = batch["tool_preds"][:, :, task_idx].to(device)  # [B, M]
#         mask = batch["tool_mask"][:, :, task_idx].to(device)    # [B, M]

#         B, M = preds.shape

#         for i in range(B):
#             valid_tools = torch.nonzero(mask[i], as_tuple=False).squeeze(-1)

#             if valid_tools.numel() == 0:
#                 # No valid tool: cannot predict (count as incorrect)
#                 total += 1
#                 continue

#             samples_with_valid += 1

#             # Uniform random choice among valid tools
#             j = valid_tools[
#                 torch.randint(
#                     low=0,
#                     high=valid_tools.numel(),
#                     size=(1,),
#                     generator=rng,
#                     device=device,
#                 ).item()
#             ]

#             # Binary prediction
#             pred_label = (preds[i, j] >= 0.5).long()

#             correct_total += (pred_label == gt[i]).item()
#             total += 1

#     rand_acc = correct_total / max(1, total)

#     stats = {
#         "total_samples": total,
#         "samples_with_valid_tool": samples_with_valid,
#         "fraction_with_valid_tool": samples_with_valid / max(1, total),
#     }

#     return rand_acc, stats


# for task_idx in range(num_tasks):
#     rand_acc, stats = compute_random_router_accuracy(
#         val_loader,
#         task_idx=task_idx,
#         device=device,
#     )
#     print(f"Random router accuracy (task {task_idx}): {rand_acc:.4f}")
#     print(stats)


In [13]:
criterion = imports.DySTANceCompSumEq7Loss(
    surrogate_type="logistic",
    lambda_entropy=0.00,
    cost_centering="none",          # or "min" if you want your workaround
    clamp_negative_weights=False,   # optionally True for stability
)

In [14]:
from tqdm import tqdm
import torch
from imports import costs_from_probs_binary

@torch.no_grad()
def evaluate_all_tasks(model, loader, ctx_mgr, num_tasks: int, device: torch.device):
    """
    Evaluate routing via argmax over router scores.

    Assumptions:
      - tool_preds[b, m, t] is P(y=1) for tool m on task t
      - gt_all[b, t] in {0,1}
      - tool_mask[b, m, t] in {0,1}

    Returns:
      - avg_regret : mean(chosen_cost - oracle_cost) using label-aware expected 0-1 costs
      - accuracy   : fraction where chosen tool’s thresholded label matches gt
      - per_task_acc, random_acc, upper_acc, num_pairs
    """
    model.eval()

    total_regret = 0.0
    total_correct = 0
    total_pairs = 0

    per_task_correct = [0 for _ in range(num_tasks)]
    per_task_pairs = [0 for _ in range(num_tasks)]

    random_correct = 0
    upper_correct = 0

    ctx_cache = {}

    for batch in tqdm(loader, desc="Validation"):
        images = batch["image"].to(device)            # [B,C,H,W]
        gt_all = batch["gt"].to(device)               # [B,T]
        preds_all = batch["tool_preds"].to(device)    # [B,M,T] = P(y=1)
        mask_all = batch["tool_mask"].to(device)      # [B,M,T] in {0,1}

        B = images.shape[0]
        M = preds_all.shape[1]

        for t in range(num_tasks):
            gt = gt_all[:, t]                         # [B]
            tool_probs = preds_all[:, :, t]           # [B,M]
            tool_mask = mask_all[:, :, t]             # [B,M]

            valid_counts = tool_mask.sum(dim=1)       # [B]
            valid_samples = valid_counts > 0          # [B] bool
            if valid_samples.sum().item() == 0:
                continue

            # Context cache
            if t in ctx_cache:
                ctx_img_feat, ctx_gt, ctx_pred = ctx_cache[t]
            else:
                ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, model, device=device)
                if ctx_img_feat is not None:
                    ctx_img_feat = ctx_img_feat.to(device)
                    ctx_gt = ctx_gt.to(device)
                    ctx_pred = ctx_pred.to(device)
                else:
                    Dx = model.img_dim if hasattr(model, "img_dim") else 512
                    ctx_img_feat = torch.zeros((M, 0, Dx), device=device)
                    ctx_gt       = torch.zeros((M, 0), device=device)
                    ctx_pred     = torch.zeros((M, 0), device=device)
                ctx_cache[t] = (ctx_img_feat, ctx_gt, ctx_pred)

            task_ids = torch.full((B,), t, dtype=torch.long, device=device)

            scores = model(
                images=images,
                text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),
                task_idx=task_ids,
                tool_preds=tool_probs,
                ctx_img_feat=ctx_img_feat,
                ctx_gt=ctx_gt,
                ctx_pred=ctx_pred,
                tool_mask=tool_mask,
            )  # [B,M]

            # Mask invalid tools
            scores = scores.masked_fill(tool_mask == 0, -1e9)

            # Argmax tool choice
            chosen = scores.argmax(dim=1)  # [B]

            # Keep only samples with at least one valid tool
            idxs = torch.nonzero(valid_samples, as_tuple=False).squeeze(1)
            if idxs.numel() == 0:
                continue

            chosen_v = chosen[idxs]           # [Bv]
            gt_v = gt[idxs]                   # [Bv]
            probs_v = tool_probs[idxs]        # [Bv,M]
            mask_v = tool_mask[idxs]          # [Bv,M]
            Bv = idxs.numel()

            # --------- Regret/oracle costs (label-aware, in [0,1]) ----------
            costs_v = costs_from_probs_binary(probs_v, gt_v)          # [Bv,M]
            costs_v = costs_v.masked_fill(mask_v == 0, 1e9)           # invalid -> huge

            chosen_cost = costs_v[torch.arange(Bv, device=device), chosen_v]  # [Bv]
            oracle_cost = costs_v.min(dim=1).values                              # [Bv]

            total_regret += (chosen_cost - oracle_cost).sum().item()

            # --------- Accuracy (thresholded label of chosen tool) ----------
            chosen_probs = probs_v[torch.arange(Bv, device=device), chosen_v]   # [Bv]
            chosen_labels = (chosen_probs >= 0.5).long()
            correct = (chosen_labels == gt_v).sum().item()

            total_correct += int(correct)
            total_pairs += int(Bv)
            per_task_correct[t] += int(correct)
            per_task_pairs[t] += int(Bv)

            # --------- Random baseline (uniform among valid tools) ----------
            # multinomial over the {0,1} mask gives uniform among valid entries
            sampled = torch.multinomial(mask_v.float(), num_samples=1).squeeze(1)  # [Bv]
            rnd_probs = probs_v[torch.arange(Bv, device=device), sampled]
            rnd_labels = (rnd_probs >= 0.5).long()
            random_correct += int((rnd_labels == gt_v).sum().item())

            # --------- Upper bound accuracy: any valid tool predicts correctly ----------
            preds_bin = (probs_v >= 0.5).long()                     # [Bv,M]
            gt_expand = gt_v.unsqueeze(1).expand_as(preds_bin)      # [Bv,M]
            any_correct = ((preds_bin == gt_expand) & (mask_v.bool())).any(dim=1)
            upper_correct += int(any_correct.sum().item())

    avg_regret = total_regret / max(1, total_pairs)
    accuracy = total_correct / max(1, total_pairs)
    per_task_acc = [
        (per_task_correct[t] / per_task_pairs[t]) if per_task_pairs[t] > 0 else None
        for t in range(num_tasks)
    ]
    random_acc = random_correct / max(1, total_pairs)
    upper_acc = upper_correct / max(1, total_pairs)

    return {
        "avg_regret": avg_regret,
        "accuracy": accuracy,
        "per_task_acc": per_task_acc,
        "random_acc": random_acc,
        "upper_acc": upper_acc,
        "num_pairs": total_pairs,
    }


In [15]:
# loader = train_loader
# task_weights = None
# resample_per_batch = False
# ########################################################
# model.train()
# total_loss = 0.0
# total_batches = 0

# # Task weights
# if task_weights is None:
#     task_weights = torch.ones((num_tasks,), dtype=torch.float32, device=device) # [T]
# else:
#     task_weights = task_weights.to(device).float()

# # Cache contexts per task (huge speedup)
# ctx_cache = {}

# for batch in tqdm(loader, desc="Training"):
#     images    = batch["image"].to(device)       # [B,3,H,W]
#     gt_all    = batch["gt"].to(device)          # [B,T]
#     probs_all = batch["tool_preds"].to(device)  # [B,M,T]  (P(y=1))
#     mask_all  = batch["tool_mask"].to(device)   # [B,M,T]  (0/1)

#     B = images.size(0)
#     M = probs_all.size(1)

#     loss_sum_tasks = 0.0
#     weight_sum = 0.0

#     for t in range(num_tasks): # loop over all tasks
#         # get Task slices
#         gt        = gt_all[:, t]           # [B]
#         tool_probs = probs_all[:, :, t]    # [B,M] -- some will be 0.5 (invalid)
#         tool_mask  = mask_all[:, :, t]     # [B,M]

#         # Keep only samples with at least one valid tool
#         valid_counts = tool_mask.sum(dim=1)          # [B] -- in this batch, how many tools are valid?
#         valid_samples = valid_counts > 0             # [B] bool -- are there any valid tools in this batch?
#         if valid_samples.sum().item() == 0:          # skip if no valid tools in batch
#             continue

#         # get valid samples from batch 
#         idxs = torch.nonzero(valid_samples, as_tuple=False).squeeze(1) # [Bv] -- get indices of samples with valid tools
#         images_v = images[idxs]
#         gt_v = gt[idxs]
#         tool_probs_v = tool_probs[idxs]
#         tool_mask_v  = tool_mask[idxs]
#         Bv = idxs.numel()

#         # === CONTEXT: reuse epoch cache OR resample per batch (streamlined) ===
#         use_cache = (not resample_per_batch) and (t in ctx_cache)
#         if use_cache:
#             ctx_img_feat, ctx_gt, ctx_pred = ctx_cache[t]
#         else:
#             # build (either because we are resampling each batch, or first time this epoch)
#             ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, model, device=device)

#             # normalize missing context to zero-shaped placeholders
#             if ctx_img_feat is None:
#                 ctx_img_feat, ctx_gt, ctx_pred = _empty_context(M, model, device)
#             else:
#                 # make sure everything is on the right device (idempotent)
#                 ctx_img_feat = ctx_img_feat.to(device)
#                 ctx_gt = ctx_gt.to(device)
#                 ctx_pred = ctx_pred.to(device)

#             # only store in epoch cache when using epoch-caching mode
#             if not resample_per_batch:
#                 ctx_cache[t] = (ctx_img_feat, ctx_gt, ctx_pred)


#         # Forward (router logits over tools)
#         task_ids = torch.full((Bv,), t, device=device, dtype=torch.long) 
#         router_logits = model(
#             images=images_v,
#             text_tokens=torch.zeros((Bv, 1), dtype=torch.long, device=device),
#             task_idx=task_ids,
#             tool_preds=tool_probs_v,   # pass tool probs if model uses them as features
#             ctx_img_feat=ctx_img_feat,
#             ctx_gt=ctx_gt,
#             ctx_pred=ctx_pred,
#             tool_mask=tool_mask_v,
#         )  # [Bv, M]
#         break

#         # Compute loss (NO manual tool_costs here)
#         loss_t, logs = criterion(
#             router_logits=router_logits,
#             tool_probs=tool_probs_v,
#             gt=gt_v,
#             validity_mask=tool_mask_v,
#         )

#         w_t = task_weights[t]
#         loss_sum_tasks = loss_sum_tasks + (w_t * loss_t)
#         weight_sum = weight_sum + w_t
#     break

#     # if weight_sum.item() == 0:
#     #     continue

#     # loss_batch = loss_sum_tasks / weight_sum

#     # optimizer.zero_grad(set_to_none=True)
#     # loss_batch.backward()
#     # optimizer.step()

#     # total_loss += float(loss_batch.detach().cpu().item())
#     # total_batches += 1

# # fin = total_loss / max(1, total_batches)

In [16]:
# eps = 1e-8
# import torch.nn.functional as F
# ### loss dev

# B, M = router_logits.shape
# device = router_logits.device

# validity_mask = tool_mask_v.to(device)
# tool_probs = tool_probs_v.to(device)
# gt = gt_v.to(device)

# # ------------
# # 1) Masked softmax -> pi over valid tools
# # ------------
# very_neg = -1e9
# masked_logits = router_logits.masked_fill(validity_mask == 0, very_neg) # --- e^-very_neg = 0
# pi = F.softmax(masked_logits, dim=1)                 # [B,M] -- softmax over the tool logits
# pi = pi * validity_mask                              # zero invalid
# pi = pi / (pi.sum(dim=1, keepdim=True) + eps)   # renormalize -- make sure sum to 1 after masking

# # Effective panel size
# m_eff = validity_mask.sum(dim=1, keepdim=True)       # [B,1] -- sum over valid tools. Asks how many tools are valid for each sample.
# m_eff_clamped = torch.clamp(m_eff, min=1.0)

In [17]:
# cost_centering = "none"
# # ------------
# # 2) Compute costs in [0,1] from tool_probs and gt
# #    c = y*(1-p) + (1-y)*p
# # ------------
# gt_f = gt.float().unsqueeze(1)                       # [B,1]
# tool_costs = gt_f * (1.0 - tool_probs) + (1.0 - gt_f) * tool_probs  # [B,M]
# tool_costs = tool_costs * validity_mask              # invalid -> 0 (masked anyway)

# # Optional per-sample min-centering (heuristic; changes the surrogate)
# if cost_centering == "min":
#     big = 1e9
#     costs_for_min = tool_costs + (1.0 - validity_mask) * big
#     min_cost, _ = costs_for_min.min(dim=1, keepdim=True)
#     min_cost = torch.where(min_cost > big / 2.0, torch.zeros_like(min_cost), min_cost)
#     tool_costs = (tool_costs - min_cost).clamp_min(0.0)  # keep >=0 for valid
# else:
#     min_cost = torch.zeros((B,1), device=device)

In [18]:
# clamp_negative_weights = False
# # ------------
# # 3) Comp-sum weights (variable panel size version)
# #    w_j = sum_{k != j} c_k - m_eff + 2
# # ------------
# sum_costs = tool_costs.sum(dim=1, keepdim=True)       # [B,1]
# w = (sum_costs - tool_costs) - m_eff_clamped + 2.0    # [B,M]
# w = w * validity_mask

# # Optional stabilizer: prevent negative weights (not faithful to Eq.7)
# if clamp_negative_weights:
#     w = torch.clamp(w, min=0.0)

In [None]:
import copy
import math
import torch

# --- early stopping / training loop params ---
num_epochs = 100
patience = 25            # number of epochs with no improvement to wait
min_delta = 1e-4        # minimum improvement in val_regret to count as improvement
verbose = True

# optional: use a scheduler that supports step(metric) like ReduceLROnPlateau
use_scheduler = False
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5, verbose=True)

best_val_regret = float("inf")
best_epoch = -1
best_state = None
epochs_no_improve = 0

for epoch in range(num_epochs):
    train_loss = train_one_epoch_all_tasks(
        model,
        train_loader,
        ctx_mgr,
        optimizer,
        criterion,
        num_tasks=18,
        device=device,
        resample_per_batch=True,
    )

    val_metrics = evaluate_all_tasks(model, val_loader, ctx_mgr, num_tasks=18, device=device)
    val_regret = float(val_metrics["avg_regret"])

    # # Scheduler (optional): if using ReduceLROnPlateau, notify it of validation metric
    # if use_scheduler:
    #     scheduler.step(val_regret)

    improved = (best_val_regret - val_regret) > min_delta

    if improved:
        best_val_regret = val_regret
        best_epoch = epoch
        # keep best model state (deepcopy to be safe)
        best_state = copy.deepcopy(model.state_dict())
        epochs_no_improve = 0
        if verbose:
            print(f"[Epoch {epoch:02d}] val_regret improved -> {val_regret:.6f} (saved best model)")
    else:
        epochs_no_improve += 1
        if verbose:
            print(f"[Epoch {epoch:02d}] no improvement (val_regret {val_regret:.6f}), "
                  f"patience {epochs_no_improve}/{patience}")

    print(
        f"[Epoch {epoch:02d}] Train Loss: {train_loss:.4f} | "
        f"Val Regret: {val_regret:.4f} | Val Acc: {val_metrics['accuracy']:.4f}"
    )

    # Early stopping check
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered (no improvement in {patience} epochs).")
        break

# Restore best weights (if any)
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best model from epoch {best_epoch} with val_regret={best_val_regret:.6f}")
else:
    print("No improvement seen during training; final model kept as-is.")


Training: 100%|██████████| 53/53 [04:56<00:00,  5.60s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.53it/s]


[Epoch 00] val_regret improved -> 0.069494 (saved best model)
[Epoch 00] Train Loss: -315.8013 | Val Regret: 0.0695 | Val Acc: 0.9080


Training: 100%|██████████| 53/53 [05:00<00:00,  5.67s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


[Epoch 01] val_regret improved -> 0.063770 (saved best model)
[Epoch 01] Train Loss: -835.6876 | Val Regret: 0.0638 | Val Acc: 0.9038


Training: 100%|██████████| 53/53 [05:00<00:00,  5.67s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


[Epoch 02] val_regret improved -> 0.061070 (saved best model)
[Epoch 02] Train Loss: -962.0937 | Val Regret: 0.0611 | Val Acc: 0.9104


Training: 100%|██████████| 53/53 [04:59<00:00,  5.65s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.60it/s]


[Epoch 03] no improvement (val_regret 0.065450), patience 1/25
[Epoch 03] Train Loss: -606.0123 | Val Regret: 0.0654 | Val Acc: 0.9030


Training: 100%|██████████| 53/53 [04:57<00:00,  5.62s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.58it/s]


[Epoch 04] no improvement (val_regret 0.066341), patience 2/25
[Epoch 04] Train Loss: -314.8905 | Val Regret: 0.0663 | Val Acc: 0.9116


Training: 100%|██████████| 53/53 [04:56<00:00,  5.59s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


[Epoch 05] val_regret improved -> 0.058023 (saved best model)
[Epoch 05] Train Loss: -327.2323 | Val Regret: 0.0580 | Val Acc: 0.9106


Training: 100%|██████████| 53/53 [04:55<00:00,  5.58s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


[Epoch 06] no improvement (val_regret 0.064881), patience 1/25
[Epoch 06] Train Loss: -811.3403 | Val Regret: 0.0649 | Val Acc: 0.9080


Training: 100%|██████████| 53/53 [04:59<00:00,  5.65s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


[Epoch 07] no improvement (val_regret 0.068046), patience 2/25
[Epoch 07] Train Loss: -1117.5113 | Val Regret: 0.0680 | Val Acc: 0.9064


Training: 100%|██████████| 53/53 [04:54<00:00,  5.56s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.58it/s]


[Epoch 08] val_regret improved -> 0.056754 (saved best model)
[Epoch 08] Train Loss: -1451.2742 | Val Regret: 0.0568 | Val Acc: 0.9120


Training: 100%|██████████| 53/53 [04:58<00:00,  5.63s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


[Epoch 09] no improvement (val_regret 0.075552), patience 1/25
[Epoch 09] Train Loss: -1667.3506 | Val Regret: 0.0756 | Val Acc: 0.8941


Training: 100%|██████████| 53/53 [04:51<00:00,  5.50s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.62it/s]


[Epoch 10] no improvement (val_regret 0.077764), patience 2/25
[Epoch 10] Train Loss: -1415.1578 | Val Regret: 0.0778 | Val Acc: 0.8817


Training: 100%|██████████| 53/53 [04:55<00:00,  5.58s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


[Epoch 11] no improvement (val_regret 0.059901), patience 3/25
[Epoch 11] Train Loss: -1691.2141 | Val Regret: 0.0599 | Val Acc: 0.9122


Training: 100%|██████████| 53/53 [04:53<00:00,  5.54s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.68it/s]


[Epoch 12] no improvement (val_regret 0.067156), patience 4/25
[Epoch 12] Train Loss: -1807.3988 | Val Regret: 0.0672 | Val Acc: 0.8979


Training: 100%|██████████| 53/53 [04:56<00:00,  5.60s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.60it/s]


[Epoch 13] no improvement (val_regret 0.079558), patience 5/25
[Epoch 13] Train Loss: -1754.7618 | Val Regret: 0.0796 | Val Acc: 0.8891


Training: 100%|██████████| 53/53 [05:00<00:00,  5.67s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.56it/s]


[Epoch 14] no improvement (val_regret 0.087507), patience 6/25
[Epoch 14] Train Loss: -1829.2562 | Val Regret: 0.0875 | Val Acc: 0.8582


Training: 100%|██████████| 53/53 [04:55<00:00,  5.57s/it]
Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


[Epoch 15] no improvement (val_regret 0.064986), patience 7/25
[Epoch 15] Train Loss: -1848.0772 | Val Regret: 0.0650 | Val Acc: 0.9057


Training:  68%|██████▊   | 36/53 [03:21<01:35,  5.63s/it]

In [None]:
import torch
import pandas as pd
from collections import defaultdict
from tqdm import tqdm

# re-use helper from earlier fix
def costs_from_probs_binary(preds: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """
    preds: [B, M] probabilities P(y=1)
    gt:    [B] binary {0,1}
    returns costs [B, M] in [0,1] = expected 0-1 error w.r.t. gt
    """
    gt_f = gt.float().unsqueeze(1)  # [B,1]
    return gt_f * (1.0 - preds) + (1.0 - gt_f) * preds

@torch.no_grad()
def inspect_router_choices(model, test_loader, ctx_mgr, num_tasks, device, max_examples_per_tool=5):
    """
    Runs evaluation and returns diagnostics about which tool is chosen and how it performs.

    Returns a dict with:
      - per_task_tool_counts[t][m] = count of times tool m was chosen for task t
      - per_task_tool_correct[t][m] = number of correct predictions when tool m was chosen
      - per_task_tool_regret_sum[t][m] = sum of regret when tool m was chosen
      - per_task_total_pairs[t] = total valid (sample,task) pairs considered
      - overall_df: pandas DataFrame with one row per valid (sample,task):
            columns: ['global_idx','task','chosen_tool','chosen_prob','chosen_label','gt',
                      'chosen_cost','oracle_tool','oracle_cost','regret','mask_valid_tools']
      - examples_by_tool[(t,m)] = list of example dicts (up to max_examples_per_tool) for quick inspection
    """
    model.eval()

    # accumulators
    per_task_tool_counts = [defaultdict(int) for _ in range(num_tasks)]
    per_task_tool_correct = [defaultdict(int) for _ in range(num_tasks)]
    per_task_tool_regret_sum = [defaultdict(float) for _ in range(num_tasks)]
    per_task_total_pairs = [0 for _ in range(num_tasks)]

    rows = []  # to build DataFrame of per-(sample,task) rows
    examples_by_tool = defaultdict(list)  # key = (task, tool) -> list of examples

    global_idx_base = 0  # optional index to identify sample across batches

    for batch in tqdm(test_loader, desc="Inspecting router"):
        images = batch["image"].to(device)            # [B,C,H,W]
        gt_all = batch["gt"].to(device)               # [B,T]
        preds_all = batch["tool_preds"].to(device)    # [B,M,T] = P(y=1)
        mask_all = batch["tool_mask"].to(device)      # [B,M,T]

        B = images.shape[0]
        M = preds_all.shape[1]

        # iterate tasks (same as eval)
        for t in range(num_tasks):
            gt = gt_all[:, t]                       # [B]
            tool_probs = preds_all[:, :, t]         # [B, M]
            tool_mask  = mask_all[:, :, t]          # [B, M]

            # skip if no valid tool for any sample in batch
            valid_counts = tool_mask.sum(dim=1)     # [B]
            valid_samples_mask = (valid_counts > 0) # [B] bool
            if valid_samples_mask.sum().item() == 0:
                continue

            # build context (cache if you already do)
            ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, device=device, model=model)
            if ctx_img_feat is not None:
                ctx_img_feat = ctx_img_feat.to(device)
                ctx_gt = ctx_gt.to(device)
                ctx_pred = ctx_pred.to(device)
            else:
                Dx = model.img_dim if hasattr(model, "img_dim") else 512
                ctx_img_feat = torch.zeros((M, 0, Dx), device=device)
                ctx_gt       = torch.zeros((M, 0), device=device)
                ctx_pred     = torch.zeros((M, 0), device=device)

            # forward in one shot for full batch (model should handle masking)
            task_ids = torch.full((B,), t, dtype=torch.long, device=device)
            scores = model(
                images=images,
                text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),
                task_idx=task_ids,
                tool_preds=tool_probs,
                ctx_img_feat=ctx_img_feat,
                ctx_gt=ctx_gt,
                ctx_pred=ctx_pred,
                tool_mask=tool_mask,
            )  # [B, M]

            # mask invalid
            scores = scores.masked_fill(tool_mask == 0, -1e9)

            # chosen tool by argmax
            chosen = scores.argmax(dim=1)  # [B]

            # Work on valid sample subset
            idxs = torch.nonzero(valid_samples_mask, as_tuple=False).squeeze(1)
            if idxs.numel() == 0:
                continue

            chosen_v = chosen[idxs]                     # [Bv]
            gt_v = gt[idxs]                             # [Bv]
            probs_v = tool_probs[idxs]                  # [Bv, M]
            mask_v = tool_mask[idxs]                    # [Bv, M]
            Bv = idxs.numel()

            # compute costs and oracle
            costs_v = costs_from_probs_binary(probs_v, gt_v)   # [Bv,M]
            costs_for_oracle = costs_v.masked_fill(mask_v == 0, 1e9)
            oracle_costs, oracle_idx = costs_for_oracle.min(dim=1)  # [Bv], [Bv]
            chosen_cost = costs_v[torch.arange(Bv, device=device), chosen_v]  # [Bv]

            # compute regret
            regret = (chosen_cost - oracle_costs)  # [Bv]

            # chosen predicted labels and correctness (threshold 0.5)
            chosen_probs = probs_v[torch.arange(Bv, device=device), chosen_v]  # [Bv]
            chosen_labels = (chosen_probs >= 0.5).long()
            correct_mask = (chosen_labels == gt_v).long()

            # aggregate per-tool stats for this task
            for i in range(Bv):
                tool_idx = int(chosen_v[i].item())
                per_task_tool_counts[t][tool_idx] += 1
                per_task_tool_correct[t][tool_idx] += int(correct_mask[i].item())
                per_task_tool_regret_sum[t][tool_idx] += float(regret[i].item())
                per_task_total_pairs[t] += 1

                # record row for DataFrame (global_idx optional)
                row = {
                    "global_idx": int(global_idx_base + idxs[i].item()),
                    "task": int(t),
                    "chosen_tool": tool_idx,
                    "chosen_prob": float(chosen_probs[i].item()),
                    "chosen_label": int(chosen_labels[i].item()),
                    "gt": int(gt_v[i].item()),
                    "chosen_cost": float(chosen_cost[i].item()),
                    "oracle_tool": int(oracle_idx[i].item()),
                    "oracle_cost": float(oracle_costs[i].item()),
                    "regret": float(regret[i].item()),
                    "valid_tools_mask": mask_v[i].cpu().numpy().tolist(),
                }
                rows.append(row)

                # Save examples per tool for quick inspection (limit)
                key = (t, tool_idx)
                if len(examples_by_tool[key]) < max_examples_per_tool:
                    examples_by_tool[key].append(row)

            global_idx_base += B  # increment by batch size so indices are unique across batches

    # Build DataFrame
    overall_df = pd.DataFrame(rows)

    # Build final per-task per-tool summary (counts, accuracies, avg regrets)
    per_task_summary = {}
    for t in range(num_tasks):
        counts = per_task_tool_counts[t]
        corrects = per_task_tool_correct[t]
        regrets = per_task_tool_regret_sum[t]
        total = per_task_total_pairs[t]
        tool_summary = {}
        for m in sorted(counts.keys()):
            c = counts[m]
            corr = corrects.get(m, 0)
            reg_sum = regrets.get(m, 0.0)
            tool_summary[m] = {
                "count": c,
                "accuracy": corr / c if c > 0 else None,
                "avg_regret": reg_sum / c if c > 0 else None,
                "share": c / max(1, total),
            }
        per_task_summary[t] = {
            "total_pairs": total,
            "tools": tool_summary
        }

    diagnostics = {
        "per_task_summary": per_task_summary,
        "overall_df": overall_df,
        "examples_by_tool": examples_by_tool,
    }
    return diagnostics


In [None]:
diagnostics = inspect_router_choices(model, test_loader, ctx_mgr, num_tasks, device)

# quick print for first task:
t = 1
print("Task", t, "total pairs:", diagnostics["per_task_summary"][t]["total_pairs"])
for m,info in diagnostics["per_task_summary"][t]["tools"].items():
    print(f" tool {m}: count={info['count']}, acc={info['accuracy']:.3f}, avg_regret={info['avg_regret']:.4f}, share={info['share']:.3f}")

# show a few rows
display(diagnostics["overall_df"].head(20))

# inspect example predictions for a particular (task,tool)
print("Examples where task 0 routed to tool 3:")
for ex in diagnostics["examples_by_tool"][(0,3)]:
    print(ex)


Inspecting router: 100%|██████████| 10/10 [00:57<00:00,  5.77s/it]

Task 1 total pairs: 156
 tool 7: count=16, acc=0.938, avg_regret=0.0410, share=0.103
 tool 9: count=32, acc=1.000, avg_regret=0.0108, share=0.205
 tool 12: count=48, acc=0.938, avg_regret=0.0850, share=0.308
 tool 13: count=16, acc=0.875, avg_regret=0.0524, share=0.103
 tool 15: count=32, acc=0.906, avg_regret=0.0757, share=0.205
 tool 17: count=12, acc=1.000, avg_regret=0.0194, share=0.077





Unnamed: 0,global_idx,task,chosen_tool,chosen_prob,chosen_label,gt,chosen_cost,oracle_tool,oracle_cost,regret,valid_tools_mask
0,0,0,1,0.732694,1,0,0.732694,17,0.021161,0.711532,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
1,1,0,1,0.42432,0,0,0.42432,8,0.085591,0.338729,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
2,2,0,1,0.524372,1,0,0.524372,17,0.044154,0.480218,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
3,3,0,1,0.527293,1,0,0.527293,14,0.1005,0.426792,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
4,4,0,1,0.82665,1,0,0.82665,8,0.027821,0.798829,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
5,5,0,1,0.699087,1,1,0.300913,3,0.290848,0.010065,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
6,6,0,1,0.046845,0,0,0.046845,4,0.001795,0.04505,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
7,7,0,1,0.298619,0,0,0.298619,17,0.013327,0.285292,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
8,8,0,1,0.790129,1,0,0.790129,17,0.006881,0.783249,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."
9,9,0,1,0.105902,0,0,0.105902,17,0.016726,0.089177,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, ..."


Examples where task 0 routed to tool 3:
{'global_idx': 576, 'task': 0, 'chosen_tool': 3, 'chosen_prob': 0.30880802869796753, 'chosen_label': 0, 'gt': 0, 'chosen_cost': 0.30880802869796753, 'oracle_tool': 7, 'oracle_cost': 0.009257563389837742, 'regret': 0.29955047369003296, 'valid_tools_mask': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'global_idx': 577, 'task': 0, 'chosen_tool': 3, 'chosen_prob': 0.11553023010492325, 'chosen_label': 0, 'gt': 0, 'chosen_cost': 0.11553023010492325, 'oracle_tool': 17, 'oracle_cost': 0.011242981068789959, 'regret': 0.10428725183010101, 'valid_tools_mask': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'global_idx': 578, 'task': 0, 'chosen_tool': 3, 'chosen_prob': 0.24244846403598785, 'chosen_label': 0, 'gt': 0, 'chosen_cost': 0.24244846403598785, 'oracle_tool': 17, 'oracle_cost': 0.04872521758079529, 'regret': 0.19372324645519257, 'valid_tools_mask': [1.0, 1.0,

In [None]:
df = diagnostics['overall_df']

In [None]:
import torch
import pandas as pd
from tqdm import tqdm

@torch.no_grad()
def routing_dataframe(
    model,
    loader,
    ctx_mgr,
    num_tasks: int,
    device: torch.device,
    image_id_key: str = "image_id",  # change if your batch uses a different key
):
    """
    Returns a pandas DataFrame with one row per (image, task) routing decision.

    Columns:
      - image_id
      - task
      - gt
      - chosen_tool
      - pred_prob
      - pred_label
    """
    model.eval()
    rows = []

    global_image_idx = 0  # fallback if no image_id provided

    for batch in tqdm(loader, desc="Building routing dataframe"):
        images = batch["image"].to(device)           # [B,C,H,W]
        gt_all = batch["gt"].to(device)              # [B,T]
        preds_all = batch["tool_preds"].to(device)   # [B,M,T] = P(y=1)
        mask_all = batch["tool_mask"].to(device)     # [B,M,T]

        # image identifiers
        if image_id_key in batch:
            image_ids = batch[image_id_key]
        else:
            B = images.size(0)
            image_ids = list(range(global_image_idx, global_image_idx + B))
            global_image_idx += B

        B, M, _ = preds_all.shape

        for t in range(num_tasks):
            gt = gt_all[:, t]                 # [B]
            tool_probs = preds_all[:, :, t]   # [B,M]
            tool_mask  = mask_all[:, :, t]    # [B,M]

            # only keep samples with ≥1 valid tool
            valid_mask = tool_mask.sum(dim=1) > 0
            if valid_mask.sum().item() == 0:
                continue

            idxs = torch.nonzero(valid_mask, as_tuple=False).squeeze(1)

            # context (cached per task if you want)
            ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, device=device, model=model)
            if ctx_img_feat is not None:
                ctx_img_feat = ctx_img_feat.to(device)
                ctx_gt = ctx_gt.to(device)
                ctx_pred = ctx_pred.to(device)
            else:
                Dx = model.img_dim if hasattr(model, "img_dim") else 512
                ctx_img_feat = torch.zeros((M, 0, Dx), device=device)
                ctx_gt       = torch.zeros((M, 0), device=device)
                ctx_pred     = torch.zeros((M, 0), device=device)

            # router scores
            task_ids = torch.full((B,), t, dtype=torch.long, device=device)
            scores = model(
                images=images,
                text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),
                task_idx=task_ids,
                tool_preds=tool_probs,
                ctx_img_feat=ctx_img_feat,
                ctx_gt=ctx_gt,
                ctx_pred=ctx_pred,
                tool_mask=tool_mask,
            )  # [B,M]

            scores = scores.masked_fill(tool_mask == 0, -1e9)
            chosen = scores.argmax(dim=1)  # [B]

            for i in idxs.tolist():
                m = int(chosen[i].item())
                p = float(tool_probs[i, m].item())
                y = int(gt[i].item())

                rows.append({
                    "image_id": image_ids[i],
                    "task": int(t),
                    "gt": y,
                    "chosen_tool": m,
                    "pred_prob": p,
                    "pred_label": int(p >= 0.5),
                })

    return pd.DataFrame(rows)


In [None]:
df = routing_dataframe(
    model=model,
    loader=test_loader,
    ctx_mgr=ctx_mgr,
    num_tasks=num_tasks,
    device=device,
    image_id_key="image_id",  # or None if you don’t have one
)

df.head()


Building routing dataframe: 100%|██████████| 10/10 [00:54<00:00,  5.43s/it]


Unnamed: 0,image_id,task,gt,chosen_tool,pred_prob,pred_label
0,0,0,0,10,0.384786,0
1,1,0,0,10,0.389557,0
2,2,0,0,10,0.383806,0
3,3,0,0,10,0.793852,1
4,4,0,0,10,0.262171,0


In [None]:
df['pred_label'] = (df['pred_prob'] >= 0.5).astype(int)
df['pred_correct'] = (df['pred_label'] == df['gt']).astype(int)

In [None]:
df.sort_values(by=['image_id', 'task'])

Unnamed: 0,image_id,task,gt,chosen_tool,pred_prob,pred_label
0,0,0,0,10,0.384786,0
16,0,1,0,4,0.505804,1
32,0,2,0,17,0.140698,0
48,0,3,0,4,0.538294,1
64,0,4,0,14,0.001470,0
...,...,...,...,...,...,...
2759,155,13,0,15,0.000220,0
2771,155,14,0,0,0.079217,0
2783,155,15,0,5,0.503880,1
2795,155,16,1,1,0.177569,0


In [None]:
@torch.no_grad()
def evaluate_random_router_accuracy(
    loader,
    num_tasks,
    device="cpu",
    seed=0,
):
    """
    Computes random-router accuracy, directly comparable to DySTANce Val Acc.

    For each batch:
      - sample a task
      - randomly select a valid tool
      - check binary correctness

    Returns:
        dict with accuracy and sample count
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    correct_total = 0
    total_samples = 0

    for batch in tqdm(loader, desc="Random router eval"):
        gt_all = batch["gt"].to(device)              # [B, L]
        preds_all = batch["tool_preds"].to(device)   # [B, M, L]
        mask_all = batch["tool_mask"].to(device)     # [B, M, L]

        B = gt_all.size(0)

        task_idx = torch.randint(
            0, num_tasks, (1,), generator=rng, device=device
        ).item()

        gt = gt_all[:, task_idx]                     # [B]
        preds = preds_all[:, :, task_idx]            # [B, M]
        mask = mask_all[:, :, task_idx]              # [B, M]

        for i in range(B):
            valid_tools = torch.nonzero(mask[i], as_tuple=False).squeeze(-1)

            if valid_tools.numel() == 0:
                total_samples += 1
                continue

            j = valid_tools[
                torch.randint(
                    0, valid_tools.numel(), (1,),
                    generator=rng, device=device
                ).item()
            ]

            pred_label = (preds[i, j] >= 0.5).long()
            correct_total += (pred_label == gt[i]).item()
            total_samples += 1

    return {
        "accuracy": correct_total / max(1, total_samples),
        "num_samples": total_samples,
    }

res = evaluate_random_router_accuracy(
    val_loader,
    num_tasks=num_tasks,
    device=device,
    seed=0,
)
res

Random router eval: 100%|██████████| 17/17 [00:00<00:00, 84.66it/s]


{'accuracy': 0.75, 'num_samples': 264}

In [None]:
@torch.no_grad()
def evaluate_upper_ceiling_accuracy(
    loader,
    num_tasks,
    device="cpu",
):
    """
    Computes the upper-ceiling accuracy:
    sample is correct if ANY valid tool predicts correctly.

    Directly comparable to DySTANce Val Acc.
    """

    correct_total = 0
    total_samples = 0

    for batch in tqdm(loader, desc="Upper ceiling eval"):
        gt_all = batch["gt"].to(device)              # [B, L]
        preds_all = batch["tool_preds"].to(device)   # [B, M, L]
        mask_all = batch["tool_mask"].to(device)     # [B, M, L]

        B = gt_all.size(0)

        # same protocol: random task per batch
        task_idx = torch.randint(0, num_tasks, (1,), device=device).item()

        gt = gt_all[:, task_idx]                     # [B]
        preds = preds_all[:, :, task_idx]            # [B, M]
        mask = mask_all[:, :, task_idx]              # [B, M]

        pred_labels = (preds >= 0.5).long()          # [B, M]
        gt_exp = gt.unsqueeze(1).expand_as(pred_labels)

        correct = (pred_labels == gt_exp) & mask.bool()

        correct_total += correct.any(dim=1).sum().item()
        total_samples += B

    return {
        "accuracy": correct_total / max(1, total_samples),
        "num_samples": total_samples,
    }

res = evaluate_upper_ceiling_accuracy(
    val_loader,
    num_tasks=num_tasks,
    device=device,
)
res


Upper ceiling eval: 100%|██████████| 17/17 [00:00<00:00, 101.13it/s]


{'accuracy': 0.9734848484848485, 'num_samples': 264}

In [None]:
@torch.no_grad()
def evaluate_single_tool_accuracy(
    dataloader,
    task_idx: int,
    tool_idx: int,
    device="cpu",
):
    """
    Accuracy of a fixed tool for a given task.
    """
    correct = 0
    total = 0

    for batch in dataloader:
        gt = batch["gt"][:, task_idx].to(device)               # [B]
        preds = batch["tool_preds"][:, tool_idx, task_idx].to(device)
        mask = batch["tool_mask"][:, tool_idx, task_idx].to(device)

        # Only count samples where the tool is valid
        valid = mask.bool()
        if valid.sum() == 0:
            continue

        pred_labels = (preds[valid] >= 0.5).long()
        correct += (pred_labels == gt[valid]).sum().item()
        total += valid.sum().item()

    return correct / max(1, total)


In [None]:
def evaluate_best_single_tool_oracle(
    dataloader,
    task_idx: int,
    device="cpu",
):
    """
    Chooses the best tool based on validation performance itself.
    Oracle baseline.
    """
    M = dataloader.dataset.M
    accs = []

    for tool_idx in range(M):
        acc = evaluate_single_tool_accuracy(
            dataloader, task_idx, tool_idx, device
        )
        accs.append(acc)

    best_acc = max(accs)
    best_tool = accs.index(best_acc)

    return {
        "best_accuracy": best_acc,
        "best_tool_idx": best_tool,
        "all_tool_accuracies": accs,
    }


In [None]:
def select_best_single_tool_on_train(
    train_loader,
    task_idx: int,
    device="cpu",
):
    """
    Selects the best tool using TRAIN data only.
    """
    M = train_loader.dataset.dataset.M
    accs = []

    for tool_idx in range(M):
        acc = evaluate_single_tool_accuracy(
            train_loader, task_idx, tool_idx, device
        )
        accs.append(acc)

    best_tool = accs.index(max(accs))
    return best_tool, accs


In [None]:
def evaluate_best_single_tool_train_selected(
    train_loader,
    val_loader,
    task_idx: int,
    device="cpu",
):
    """
    Train-selected fixed-tool baseline.
    """
    best_tool, train_accs = select_best_single_tool_on_train(
        train_loader, task_idx, device
    )

    val_acc = evaluate_single_tool_accuracy(
        val_loader, task_idx, best_tool, device
    )

    return {
        "selected_tool": best_tool,
        "train_acc": train_accs[best_tool],
        "val_acc": val_acc,
        "all_train_accs": train_accs,
    }


In [None]:
task_idx = 2  # Pneumonia

oracle = evaluate_best_single_tool_oracle(
    val_loader, task_idx, device
)

train_sel = evaluate_best_single_tool_train_selected(
    train_loader, val_loader, task_idx, device
)

print("Oracle best single tool:")
print(f"  Val Acc: {oracle['best_accuracy']:.4f}")
print(f"  Tool idx: {oracle['best_tool_idx']}")

print("\nTrain-selected best single tool:")
print(f"  Train Acc: {train_sel['train_acc']:.4f}")
print(f"  Val Acc: {train_sel['val_acc']:.4f}")
print(f"  Tool idx: {train_sel['selected_tool']}")


Oracle best single tool:
  Val Acc: 1.0000
  Tool idx: 8

Train-selected best single tool:
  Train Acc: 0.9929
  Val Acc: 0.9924
  Tool idx: 7
