In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
from tqdm import tqdm
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Get repository root directory
# Try multiple methods to find the repo root
cwd = Path.cwd()
if (cwd / 'data').exists():
    REPO_ROOT = cwd
elif (cwd.parent / 'data').exists():
    REPO_ROOT = cwd.parent
else:
    # Fallback: assume we're in dev_notebooks and go up one level
    REPO_ROOT = cwd.parent

import imports

Using device: cuda


In [2]:
# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------
DATA_ROOT = REPO_ROOT / "data" / "openi"
LABELS_DIR = DATA_ROOT / "labels"
IMAGES_DIR = DATA_ROOT / "image"
PRED_DIR   = DATA_ROOT / "predictions"

# ------------------------------------------------------------
# Labels (tasks)
# ------------------------------------------------------------
label_names = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "Lung Lesion", "Fracture", "Lung Opacity", "Enlarged Cardiomediastinum"
]
num_tasks = len(label_names)

# ------------------------------------------------------------
# Tool registry
# ------------------------------------------------------------
registry_all = imports.scan_prediction_files(str(PRED_DIR))

# Example split: train on non-resnet tools
train_tools = [t for t in registry_all["train"] if "resnet" not in t]

train_registry = {t: registry_all["train"][t] for t in train_tools}
val_registry   = {t: registry_all["val"][t]   for t in train_tools}

# ------------------------------------------------------------
# Datasets
# ------------------------------------------------------------
train_dataset_full = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Train.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=train_registry,
    label_names=label_names,
    transform=None,  # assume tensor conversion inside dataset
)

val_dataset = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Valid.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=val_registry,
    label_names=label_names,
    transform=None,
)



In [8]:
ctx_mgr = imports.ContextManager(
    dataset=train_dataset_full,
    context_fraction=0.1,      # 10% context
    examples_per_tool=64,      # B_t
)

train_dataset = ctx_mgr.routing_dataset()


In [9]:
train_loader = DataLoader(
    train_dataset,
    batch_size=25,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=25,
    shuffle=False,
    num_workers=4,
)


In [19]:
import pandas as pd
pd.DataFrame(val_loader.dataset.records)

Unnamed: 0,id,path,gt
0,2548,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
1,1829,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
2,870,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
3,1795,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
4,3159,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
...,...,...,...
259,3565,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
260,2826,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
261,494,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
262,1234,/home/kell6630/repos/DySTANce/data/openi/image...,"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."


In [63]:
id_arr = []
data_arr = []
for i in val_loader.dataset.records:
    id_arr.append(i['id'])
    data_arr.append(i['gt'])

In [64]:
gt_df = pd.DataFrame(torch.stack(data_arr), columns=label_names, index=id_arr)

In [79]:
gt_df

Unnamed: 0,Atelectasis,Consolidation,Infiltration,Pneumothorax,Edema,Emphysema,Fibrosis,Effusion,Pneumonia,Pleural_Thickening,Cardiomegaly,Nodule,Mass,Hernia,Lung Lesion,Fracture,Lung Opacity,Enlarged Cardiomediastinum
2548,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1829,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
870,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1795,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3159,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3565,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
2826,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
494,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
1234,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [127]:
pred_dfs = []
tool_metrics = {}
for i,k in enumerate(val_loader.dataset.tool_preds):
    tool_name = val_dataset.tool_names[i]
    pred_df = pd.DataFrame(k).T.rename(columns={j: label_names[j] for j in range(len(label_names))}) > 0.5
    tool_acc = (pred_df.loc[gt_df.index].values.flatten() == gt_df.values.flatten()).mean()

    label_wise_acc = {}
    for label_name in label_names:
        label_wise_acc[label_name] = (pred_df.loc[gt_df.index, label_name].values == gt_df[label_name].values).mean()


    tool_metrics[tool_name] = {
        'pred_df': pred_df,
        'tool_acc': tool_acc,
        'label_wise_acc': label_wise_acc
    }

In [140]:
data = []
for tool_name in tool_metrics.keys():
    data.append(tool_metrics[tool_name]['label_wise_acc'].values())
label_wise_acc_df = pd.DataFrame(data, columns=label_names, index=tool_metrics.keys())


In [141]:
label_wise_acc_df

Unnamed: 0,Atelectasis,Consolidation,Infiltration,Pneumothorax,Edema,Emphysema,Fibrosis,Effusion,Pneumonia,Pleural_Thickening,Cardiomegaly,Nodule,Mass,Hernia,Lung Lesion,Fracture,Lung Opacity,Enlarged Cardiomediastinum
densenet121_res224_all,0.579545,0.647727,0.590909,0.742424,0.878788,0.496212,0.409091,0.806818,0.848485,0.715909,0.655303,0.439394,0.537879,0.939394,0.863636,0.617424,0.666667,0.734848
densenet121_res224_chex,0.651515,0.700758,1.0,0.962121,0.431818,1.0,1.0,0.780303,0.799242,0.984848,0.564394,1.0,1.0,1.0,0.776515,0.42803,0.590909,0.556818
densenet121_res224_mimic_ch,0.488636,0.242424,1.0,0.064394,0.284091,1.0,1.0,0.575758,0.284091,0.984848,0.462121,1.0,1.0,1.0,0.242424,0.291667,0.507576,0.215909
densenet121_res224_mimic_nb,0.651515,0.272727,1.0,0.253788,0.340909,1.0,1.0,0.715909,0.155303,0.984848,0.537879,1.0,1.0,1.0,0.231061,0.534091,0.522727,0.564394
densenet121_res224_nih,0.579545,0.473485,0.810606,0.655303,0.916667,0.477273,0.530303,0.772727,0.689394,0.44697,0.659091,0.75,0.67803,0.556818,0.909091,0.920455,0.560606,0.886364
densenet121_res224_pc,0.507576,0.526515,0.488636,0.590909,0.80303,0.776515,0.931818,0.80303,0.443182,0.859848,0.681818,0.556818,0.666667,0.875,0.909091,0.515152,0.560606,0.886364
densenet121_res224_rsna,0.890152,0.962121,1.0,0.977273,0.984848,1.0,1.0,0.912879,0.746212,0.984848,0.606061,1.0,1.0,1.0,0.909091,0.920455,0.57197,0.886364
densenet_medical_mae_pt_openi,0.897727,0.962121,0.992424,0.977273,0.984848,1.0,1.0,0.950758,0.984848,0.984848,0.617424,0.996212,0.984848,1.0,0.909091,0.920455,0.560606,0.886364
densenet_mocov2_pt_openi,0.897727,0.962121,1.0,0.977273,0.984848,1.0,1.0,0.950758,0.984848,0.984848,0.617424,1.0,0.988636,1.0,0.909091,0.920455,0.560606,0.886364
evax_base_cxr__pt_openi,0.901515,0.962121,0.981061,0.977273,0.984848,1.0,1.0,0.935606,0.984848,0.984848,0.628788,0.992424,0.977273,0.996212,0.909091,0.920455,0.560606,0.886364


In [5]:
model = imports.DySTANceRouter(
    num_tasks=num_tasks,
    vocab_size=1000,   # dummy vocab size for now
    hidden_dim=256,
).to(device)

# criterion = imports.DySTANceLoss(
#     surrogate_type="logistic",
#     lambda_entropy=0.05,
# )

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)


def build_context_tensors(ctx_mgr, task_idx, device):
    """
    Builds task-conditional context tensors for all tools.

    Returns:
        ctx_img_feat : [M, C, Dx] on `device`
        ctx_gt       : [M, C]     on `device`
        ctx_pred     : [M, C]     on `device`
    """
    ctx_img_feats = []
    ctx_gts = []
    ctx_preds = []

    M = ctx_mgr.dataset.M # number of tools
    C = ctx_mgr.examples_per_tool # number of examples

    for tool_idx in range(M):
        ctx = ctx_mgr.sample_context(tool_idx, task_idx)

        if ctx is None:
            # No valid context for this tool-task pair
            ctx_img_feats.append(
                torch.zeros(C, model.img_dim, device=device)
            )
            ctx_gts.append(
                torch.zeros(C, device=device)
            )
            ctx_preds.append(
                torch.zeros(C, device=device)
            )
        else:
            imgs, gt, preds = ctx

            imgs = imgs.to(device)
            gt = gt.to(device)
            preds = preds.to(device)

            with torch.no_grad():
                feats = model.extract_img_feat(imgs)  # [C, Dx]

            ctx_img_feats.append(feats)
            ctx_gts.append(gt)
            ctx_preds.append(preds)

    return (
        torch.stack(ctx_img_feats, dim=0),  # [M, C, Dx]
        torch.stack(ctx_gts, dim=0),        # [M, C]
        torch.stack(ctx_preds, dim=0),      # [M, C]
    )

In [6]:
def train_one_epoch(model, loader, ctx_mgr, optimizer, criterion):
    model.train()

    total_loss = 0.0
    total_batches = 0

    for batch in tqdm(loader, desc="Training"):
        images = batch["image"].to(device)        # [B, 3, H, W]
        gt_all = batch["gt"].to(device)           # [B, L]
        preds_all = batch["tool_preds"].to(device)  # [B, M, L]
        mask_all = batch["tool_mask"].to(device)    # [B, M, L]

        B = images.size(0)

        # ------------------------------------------------------------
        # 1) Sample a task uniformly
        # ------------------------------------------------------------
        task_idx = random.randint(0, num_tasks - 1)
        task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)

        # Task-conditional slices
        gt = gt_all[:, task_idx]                 # [B]
        tool_preds = preds_all[:, :, task_idx]   # [B, M]
        tool_mask  = mask_all[:, :, task_idx]    # [B, M]

        # ------------------------------------------------------------
        # 2) Build context for this task
        # ------------------------------------------------------------
        ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(
            ctx_mgr, task_idx, device
        )

        # ------------------------------------------------------------
        # 3) Forward pass
        # ------------------------------------------------------------
        scores = model(
            images=images,
            text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text
            task_idx=task_ids,
            tool_preds=tool_preds,
            ctx_img_feat=ctx_img_feat,
            ctx_gt=ctx_gt,
            ctx_pred=ctx_pred,
            tool_mask=tool_mask,
        )

        # ------------------------------------------------------------
        # 4) Compute costs (classification task)
        # c_E = 1 - confidence on true label
        # ------------------------------------------------------------
        tool_costs = 1.0 - tool_preds  # [B, M]

        # ------------------------------------------------------------
        # 5) Loss + backward
        # ------------------------------------------------------------
        loss, logs = criterion(scores, tool_costs, tool_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1

    return total_loss / max(1, total_batches)


In [7]:
from typing import Optional

def train_one_epoch_all_tasks(model, loader, ctx_mgr, optimizer, criterion, task_weights: Optional[torch.Tensor] = None):
    """
    Train one epoch but evaluate the comp-sum loss for every task in each batch,
    then average the losses across tasks and take a single optimization step.

    Args:
        model: the DySTANce model (model.train() will be called)
        loader: DataLoader yielding batches with keys:
                "image" -> [B,3,H,W],
                "gt"    -> [B,L],
                "tool_preds" -> [B, M, L],
                "tool_mask"  -> [B, M, L]
        ctx_mgr: ContextManager used to build ANP contexts (build_context_tensors wrapper)
        optimizer: optimizer to step
        criterion: DySTANceLoss instance (callable -> (loss, logs))
        task_weights: optional Tensor [L] to weight tasks (on CPU or device). If None, uniform.
    Returns:
        avg_loss_per_batch: scalar float
    """
    model.train()
    total_loss = 0.0
    total_batches = 0

    # Optional task weights (defaults to uniform)
    if task_weights is None:
        task_weights = torch.ones((num_tasks,), dtype=torch.float32, device=device)
    else:
        # move to device
        task_weights = task_weights.to(device).float()

    for batch in tqdm(loader, desc="Training"):
        images = batch["image"].to(device)                 # [B, 3, H, W]
        gt_all = batch["gt"].to(device)                   # [B, L]
        preds_all = batch["tool_preds"].to(device)        # [B, M, L]
        mask_all = batch["tool_mask"].to(device)          # [B, M, L]

        B = images.size(0)

        # Accumulate loss over tasks for this minibatch
        loss_sum_tasks = 0.0
        weight_sum = 0.0

        # Loop over tasks (labels)
        for t in range(num_tasks):
            # ---------------------------
            # 1) Task slices
            # ---------------------------
            task_ids = torch.full((B,), t, device=device, dtype=torch.long)
            gt = gt_all[:, t]                     # [B]
            tool_preds = preds_all[:, :, t]       # [B, M]
            tool_mask  = mask_all[:, :, t]        # [B, M]

            # Quick check: if *none* of the tools are valid for this task across the batch,
            # then there's nothing to route to — skip or handle specially.
            # Here we skip such tasks (no contribution to loss) to avoid degenerate behavior.
            if tool_mask.sum() == 0:
                # skip this task (no valid tools in the panel for any sample)
                continue

            # ---------------------------
            # 2) Build context tensors for this task
            #    (build_context_tensors must return CPU tensors; move to device)
            # ---------------------------
            ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, device=device)

            # Ensure returned tensors are moved to the correct device
            # (some implementations already return on device; this is idempotent)
            if ctx_img_feat is not None:
                ctx_img_feat = ctx_img_feat.to(device)
                ctx_gt = ctx_gt.to(device)
                ctx_pred = ctx_pred.to(device)
            else:
                # If NO context exists for this task (all tools unseen for this task),
                # we should still provide reasonable placeholders to the model.
                # We set empty (zero) contexts with zero sizes: M=number of tools, C=0 -> handle in ANP.
                # For simplicity, create zero-sized contexts of expected shapes:
                # We assume ctx_mgr and model expect shape [M, C, Dx] etc. If your ANP
                # has special handling for B_t=0, ensure it accepts those tensors.
                M = preds_all.shape[1]
                Dx = model.img_dim if hasattr(model, "img_dim") else 512
                # make shapes consistent with your ANP — here we create C=0 context
                ctx_img_feat = torch.zeros((M, 0, Dx), device=device)
                ctx_gt       = torch.zeros((M, 0), device=device)
                ctx_pred     = torch.zeros((M, 0), device=device)

            # ---------------------------
            # 3) Forward pass for this task
            # ---------------------------
            scores = model(
                images=images,
                text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text or real tokens
                task_idx=task_ids,
                tool_preds=tool_preds,
                ctx_img_feat=ctx_img_feat,
                ctx_gt=ctx_gt,
                ctx_pred=ctx_pred,
                tool_mask=tool_mask,
            )  # [B, M]

            # ---------------------------
            # 4) Build costs (task-specific)
            #    For classification: c_E = 1 - confidence on true label (soft proxy)
            # ---------------------------
            tool_costs = 1.0 - tool_preds  # [B, M]

            # ---------------------------
            # 5) Compute loss for this task
            # ---------------------------
            loss_t, logs = criterion(scores, tool_costs, tool_mask)  # scalar loss_t tensor + logs dict

            # Weight the task (uniform by default)
            w_t = float(task_weights[t].item())
            loss_sum_tasks = loss_sum_tasks + (w_t * loss_t)
            weight_sum += w_t

        # If no tasks contributed (extremely unlikely), skip update
        if weight_sum == 0:
            continue

        # Average across tasks
        loss_batch = loss_sum_tasks / weight_sum

        # Backprop and step once per minibatch
        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()

        total_loss += float(loss_batch.detach().cpu().item())
        total_batches += 1

    avg_loss = total_loss / max(1, total_batches)
    return avg_loss


In [8]:
@torch.no_grad()
def evaluate(model, loader, ctx_mgr):
    """
    Evaluates DySTANce router on:
      - average regret
      - per-task accuracy (binary)

    Task is sampled per batch (same as training).
    """
    model.eval()

    total_regret = 0.0
    total_samples = 0
    correct_total = 0

    for batch in tqdm(loader, desc="Validation"):
        images = batch["image"].to(device)
        gt_all = batch["gt"].to(device)                # [B, L]
        preds_all = batch["tool_preds"].to(device)     # [B, M, L]
        mask_all = batch["tool_mask"].to(device)       # [B, M, L]

        B = images.size(0)

        # ------------------------------------------------------------
        # Sample a task (same protocol as training)
        # ------------------------------------------------------------
        task_idx = random.randint(0, num_tasks - 1)
        task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)

        gt = gt_all[:, task_idx]                       # [B]
        tool_preds = preds_all[:, :, task_idx]         # [B, M]
        tool_mask  = mask_all[:, :, task_idx]          # [B, M]

        # ------------------------------------------------------------
        # Build task-conditional context
        # ------------------------------------------------------------
        ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(
            ctx_mgr, task_idx, device
        )

        # ------------------------------------------------------------
        # Forward pass
        # ------------------------------------------------------------
        scores = model(
            images,
            torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text
            task_ids,
            tool_preds,
            ctx_img_feat,
            ctx_gt,
            ctx_pred,
            tool_mask,
        )

        # ------------------------------------------------------------
        # Routing decision
        # ------------------------------------------------------------
        chosen = scores.argmax(dim=1)  # [B]

        # ------------------------------------------------------------
        # Regret computation
        # ------------------------------------------------------------
        costs = 1.0 - tool_preds       # [B, M]

        chosen_cost = costs[torch.arange(B), chosen]
        oracle_cost = costs.masked_fill(tool_mask == 0, 1e9).min(dim=1).values

        total_regret += (chosen_cost - oracle_cost).sum().item()

        # ------------------------------------------------------------
        # Accuracy computation (binary)
        # ------------------------------------------------------------
        chosen_preds = tool_preds[torch.arange(B), chosen]  # [B]
        chosen_labels = (chosen_preds >= 0.5).long()

        correct_total += (chosen_labels == gt).sum().item()
        total_samples += B

    avg_regret = total_regret / max(1, total_samples)
    accuracy = correct_total / max(1, total_samples)

    return {
        "avg_regret": avg_regret,
        "accuracy": accuracy,
        "num_samples": total_samples,
    }


In [9]:
@torch.no_grad()
def evaluate_all_tasks(model, loader, ctx_mgr):
    """
    Evaluate DySTANce over EVERY task per batch (no random single-task sampling).

    Returns a dict with:
      - avg_regret      : average regret across all considered (sample,task) pairs
      - accuracy        : overall accuracy (router chosen tool vs GT)
      - per_task_acc    : list of per-task accuracies
      - random_acc      : random-router baseline accuracy
      - upper_bound_acc : oracle upper-ceiling accuracy
      - num_pairs       : number of considered (sample,task) pairs
    """
    model.eval()

    # Accumulators
    total_regret = 0.0
    total_correct = 0
    total_pairs = 0

    # Per-task accumulators
    per_task_correct = [0 for _ in range(num_tasks)]
    per_task_pairs = [0 for _ in range(num_tasks)]

    # Baselines
    random_correct = 0
    upper_correct = 0

    # Cache contexts per task to avoid repeated builds
    ctx_cache = {}

    for batch in tqdm(loader, desc="Validation"):
        images = batch["image"].to(device)            # [B, C, H, W]
        gt_all  = batch["gt"].to(device)             # [B, L]
        preds_all = batch["tool_preds"].to(device)   # [B, M, L]
        mask_all  = batch["tool_mask"].to(device)    # [B, M, L]

        B = images.shape[0]

        # For each task, evaluate all samples in batch (but skip samples with no valid tool)
        for t in range(num_tasks):
            # Task slices
            gt = gt_all[:, t]                       # [B]
            tool_preds = preds_all[:, :, t]         # [B, M]
            tool_mask  = mask_all[:, :, t]          # [B, M] {0,1}

            # Determine which samples in this batch have at least one valid tool
            valid_tool_counts = tool_mask.sum(dim=1)        # [B]
            valid_samples_mask = (valid_tool_counts > 0)   # [B] bool

            if valid_samples_mask.sum().item() == 0:
                # no sample in this batch has any valid tool for this task -> skip
                continue

            # Build / fetch cached context for task t (move to device)
            if t in ctx_cache:
                ctx_img_feat, ctx_gt, ctx_pred = ctx_cache[t]
            else:
                ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, t, device=device)
                # move to device (build_context_tensors may return CPU tensors)
                if ctx_img_feat is not None:
                    ctx_img_feat = ctx_img_feat.to(device)
                    ctx_gt = ctx_gt.to(device)
                    ctx_pred = ctx_pred.to(device)
                else:
                    # Create zero-sized contexts if none exist (ANP should handle B_t=0)
                    M = preds_all.shape[1]
                    Dx = model.img_dim if hasattr(model, "img_dim") else 512
                    ctx_img_feat = torch.zeros((M, 0, Dx), device=device)
                    ctx_gt       = torch.zeros((M, 0), device=device)
                    ctx_pred     = torch.zeros((M, 0), device=device)

                ctx_cache[t] = (ctx_img_feat, ctx_gt, ctx_pred)

            # Build task ids for the batch
            task_ids = torch.full((B,), t, dtype=torch.long, device=device)

            # Forward pass (model expected to handle masking internally)
            scores = model(
                images=images,
                text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text (or real tokens if available)
                task_idx=task_ids,
                tool_preds=tool_preds,
                ctx_img_feat=ctx_img_feat,
                ctx_gt=ctx_gt,
                ctx_pred=ctx_pred,
                tool_mask=tool_mask,
            )  # [B, M]

            # Ensure invalid tools have very low score (safety)
            scores = scores.masked_fill(tool_mask == 0, -1e9)

            # Choose best tool per sample
            chosen = scores.argmax(dim=1)  # [B] (may point to invalid tool if all -inf; but we filtered such samples)

            # Consider only samples with at least one valid tool
            idxs = torch.nonzero(valid_samples_mask, as_tuple=False).squeeze(1)  # [B_valid]
            if idxs.numel() == 0:
                continue

            # Subset arrays to valid samples
            chosen_v = chosen[idxs]                          # [B_valid]
            gt_v = gt[idxs]                                 # [B_valid]
            preds_v = tool_preds[idxs]                      # [B_valid, M]
            mask_v  = tool_mask[idxs]                       # [B_valid, M]

            # Costs: c = 1 - confidence (soft proxy)
            costs_v = 1.0 - preds_v                         # [B_valid, M]

            # Chosen cost and oracle cost
            chosen_cost = costs_v[torch.arange(idxs.numel(), device=device), chosen_v]  # [B_valid]

            # Oracle: min cost among valid tools
            inf_mask = (~(mask_v.bool())).float() * 1e9
            costs_for_oracle = costs_v + inf_mask
            oracle_cost = costs_for_oracle.min(dim=1).values  # [B_valid]

            total_regret += (chosen_cost - oracle_cost).sum().item()

            # Accuracy: interpret chosen tool's prediction as binary label (>=0.5)
            chosen_pred_probs = preds_v[torch.arange(idxs.numel(), device=device), chosen_v]  # [B_valid]
            chosen_labels = (chosen_pred_probs >= 0.5).long()
            correct = (chosen_labels == gt_v).sum().item()
            total_correct += int(correct)
            total_pairs += int(idxs.numel())

            # Per-task accounting
            per_task_correct[t] += int(correct)
            per_task_pairs[t] += int(idxs.numel())

            # --- Baselines for these valid samples --------------------------------
            # Random-router baseline: choose uniformly among valid tools for each sample
            # We implement this in vectorized-ish manner:
            Bv, M = preds_v.shape
            # Create list of valid tool indices per sample
            # We'll loop here across Bv (Bv is smallish per batch)
            for i in range(Bv):
                valid_indices = torch.nonzero(mask_v[i].bool(), as_tuple=False).squeeze(1)
                if valid_indices.numel() == 0:
                    # Shouldn't happen due to earlier filtering
                    continue
                # Random pick
                rnd_idx = int(valid_indices[torch.randint(0, valid_indices.numel(), (1,)).item()].item())
                rnd_prob = preds_v[i, rnd_idx].item()
                rnd_label = 1 if rnd_prob >= 0.5 else 0
                if rnd_label == int(gt_v[i].item()):
                    random_correct += 1

                # Upper-ceiling: any valid tool predicts correctly?
                # If any valid tool's binary prediction equals gt, count as correct
                # (treat prob>=0.5 as predicting label=1)
                valid_probs = preds_v[i, valid_indices]
                valid_preds_bin = (valid_probs >= 0.5).long()
                if (valid_preds_bin == int(gt_v[i].item())).any():
                    upper_correct += 1
            # ----------------------------------------------------------------------

    # After all batches
    avg_regret = total_regret / max(1, total_pairs)
    accuracy = total_correct / max(1, total_pairs)
    per_task_acc = [ (per_task_correct[t] / per_task_pairs[t]) if per_task_pairs[t] > 0 else None
                     for t in range(num_tasks) ]

    random_acc = random_correct / max(1, total_pairs)
    upper_acc  = upper_correct / max(1, total_pairs)

    return {
        "avg_regret": avg_regret,
        "accuracy": accuracy,
        "per_task_acc": per_task_acc,
        "random_acc": random_acc,
        "upper_acc": upper_acc,
        "num_pairs": total_pairs,
    }


In [10]:
@torch.no_grad()
def compute_upper_ceiling_accuracy(
    dataloader,
    task_idx: int,
    device="cpu",
):
    """
    Computes the upper-ceiling accuracy for a single task.

    Upper-ceiling = fraction of samples for which at least one
    valid tool predicts the correct label.

    Args:
        dataloader : DataLoader yielding OpenIRoutedDataset batches
        task_idx   : int, which task (label) to evaluate
        device     : torch device

    Returns:
        ceiling_acc : float in [0,1]
        stats       : dict with additional diagnostics
    """
    total = 0
    count_ceiling = 0
    count_any_valid = 0

    for batch in tqdm(dataloader, desc=f"Ceiling task {task_idx}"):
        gt = batch["gt"][:, task_idx].to(device)            # [B]
        preds = batch["tool_preds"][:, :, task_idx].to(device)  # [B, M]
        mask = batch["tool_mask"][:, :, task_idx].to(device)    # [B, M]

        B, M = preds.shape

        # Predicted labels per tool (binary classification)
        pred_labels = (preds >= 0.5).long()  # [B, M]

        # Ground truth expanded
        gt_exp = gt.unsqueeze(1).expand(-1, M)  # [B, M]

        # Correct predictions per tool
        correct = (pred_labels == gt_exp) & (mask.bool())  # [B, M]

        # For each sample: does ANY tool get it right?
        any_correct = correct.any(dim=1)  # [B]

        # For sanity: does sample have ANY valid tool?
        any_valid = mask.any(dim=1)       # [B]

        count_ceiling += any_correct.sum().item()
        count_any_valid += any_valid.sum().item()
        total += B

    ceiling_acc = count_ceiling / max(1, total)

    stats = {
        "total_samples": total,
        "samples_with_any_valid_tool": count_any_valid,
        "fraction_with_any_valid_tool": count_any_valid / max(1, total),
    }

    return ceiling_acc, stats

for task_idx in range(num_tasks):
    ceiling_acc, stats = compute_upper_ceiling_accuracy(
        val_loader,
        task_idx=task_idx,
        device=device,
    )

    print(f"Upper-ceiling accuracy (task {task_idx}): {ceiling_acc:.4f}")
    print(stats)


Ceiling task 0: 100%|██████████| 11/11 [00:00<00:00, 57.00it/s]


Upper-ceiling accuracy (task 0): 0.9924
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 1: 100%|██████████| 11/11 [00:00<00:00, 62.55it/s]


Upper-ceiling accuracy (task 1): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 2: 100%|██████████| 11/11 [00:00<00:00, 64.49it/s]


Upper-ceiling accuracy (task 2): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 3: 100%|██████████| 11/11 [00:00<00:00, 62.00it/s]


Upper-ceiling accuracy (task 3): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 4: 100%|██████████| 11/11 [00:00<00:00, 73.07it/s]


Upper-ceiling accuracy (task 4): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 5: 100%|██████████| 11/11 [00:00<00:00, 59.29it/s]


Upper-ceiling accuracy (task 5): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 6: 100%|██████████| 11/11 [00:00<00:00, 61.34it/s]


Upper-ceiling accuracy (task 6): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 7: 100%|██████████| 11/11 [00:00<00:00, 60.44it/s]


Upper-ceiling accuracy (task 7): 0.9924
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 8: 100%|██████████| 11/11 [00:00<00:00, 70.51it/s]


Upper-ceiling accuracy (task 8): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 9: 100%|██████████| 11/11 [00:00<00:00, 73.83it/s]


Upper-ceiling accuracy (task 9): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 10: 100%|██████████| 11/11 [00:00<00:00, 64.80it/s]


Upper-ceiling accuracy (task 10): 0.9432
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 11: 100%|██████████| 11/11 [00:00<00:00, 68.26it/s]


Upper-ceiling accuracy (task 11): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 12: 100%|██████████| 11/11 [00:00<00:00, 63.14it/s]


Upper-ceiling accuracy (task 12): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 13: 100%|██████████| 11/11 [00:00<00:00, 70.38it/s]


Upper-ceiling accuracy (task 13): 1.0000
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 14: 100%|██████████| 11/11 [00:00<00:00, 64.26it/s]


Upper-ceiling accuracy (task 14): 0.9432
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 15: 100%|██████████| 11/11 [00:00<00:00, 65.10it/s]


Upper-ceiling accuracy (task 15): 0.8598
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 16: 100%|██████████| 11/11 [00:00<00:00, 74.01it/s]


Upper-ceiling accuracy (task 16): 0.9205
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}


Ceiling task 17: 100%|██████████| 11/11 [00:00<00:00, 67.92it/s]

Upper-ceiling accuracy (task 17): 0.8750
{'total_samples': 264, 'samples_with_any_valid_tool': 264, 'fraction_with_any_valid_tool': 1.0}





In [11]:
@torch.no_grad()
def compute_random_router_accuracy(
    dataloader,
    task_idx: int,
    device="cpu",
    seed: int = 0,
):
    """
    Computes accuracy of a random router for a single task.

    For each sample:
      - uniformly sample one VALID tool
      - use its prediction as the output

    Args:
        dataloader : DataLoader yielding OpenIRoutedDataset batches
        task_idx   : int, which task (label) to evaluate
        device     : torch device
        seed       : random seed for reproducibility

    Returns:
        rand_acc : float in [0,1]
        stats    : dict
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    total = 0
    correct_total = 0
    samples_with_valid = 0

    for batch in tqdm(dataloader, desc=f"Random router task {task_idx}"):
        gt = batch["gt"][:, task_idx].to(device)             # [B]
        preds = batch["tool_preds"][:, :, task_idx].to(device)  # [B, M]
        mask = batch["tool_mask"][:, :, task_idx].to(device)    # [B, M]

        B, M = preds.shape

        for i in range(B):
            valid_tools = torch.nonzero(mask[i], as_tuple=False).squeeze(-1)

            if valid_tools.numel() == 0:
                # No valid tool: cannot predict (count as incorrect)
                total += 1
                continue

            samples_with_valid += 1

            # Uniform random choice among valid tools
            j = valid_tools[
                torch.randint(
                    low=0,
                    high=valid_tools.numel(),
                    size=(1,),
                    generator=rng,
                    device=device,
                ).item()
            ]

            # Binary prediction
            pred_label = (preds[i, j] >= 0.5).long()

            correct_total += (pred_label == gt[i]).item()
            total += 1

    rand_acc = correct_total / max(1, total)

    stats = {
        "total_samples": total,
        "samples_with_valid_tool": samples_with_valid,
        "fraction_with_valid_tool": samples_with_valid / max(1, total),
    }

    return rand_acc, stats


for task_idx in range(num_tasks):
    rand_acc, stats = compute_random_router_accuracy(
        val_loader,
        task_idx=task_idx,
        device=device,
    )
    print(f"Random router accuracy (task {task_idx}): {rand_acc:.4f}")
    print(stats)


Random router task 0: 100%|██████████| 11/11 [00:00<00:00, 60.53it/s]


Random router accuracy (task 0): 0.7652
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 1: 100%|██████████| 11/11 [00:00<00:00, 59.63it/s]


Random router accuracy (task 1): 0.7462
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 2: 100%|██████████| 11/11 [00:00<00:00, 65.85it/s]


Random router accuracy (task 2): 0.8409
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 3: 100%|██████████| 11/11 [00:00<00:00, 70.67it/s]


Random router accuracy (task 3): 0.7197
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 4: 100%|██████████| 11/11 [00:00<00:00, 69.32it/s]


Random router accuracy (task 4): 0.8144
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 5: 100%|██████████| 11/11 [00:00<00:00, 60.54it/s]


Random router accuracy (task 5): 0.8182
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 6: 100%|██████████| 11/11 [00:00<00:00, 59.42it/s]


Random router accuracy (task 6): 0.8447
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 7: 100%|██████████| 11/11 [00:00<00:00, 59.22it/s]


Random router accuracy (task 7): 0.8485
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 8: 100%|██████████| 11/11 [00:00<00:00, 67.34it/s]


Random router accuracy (task 8): 0.7424
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 9: 100%|██████████| 11/11 [00:00<00:00, 61.35it/s]


Random router accuracy (task 9): 0.8485
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 10: 100%|██████████| 11/11 [00:00<00:00, 64.81it/s]


Random router accuracy (task 10): 0.6136
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 11: 100%|██████████| 11/11 [00:00<00:00, 64.39it/s]


Random router accuracy (task 11): 0.8106
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 12: 100%|██████████| 11/11 [00:00<00:00, 68.20it/s]


Random router accuracy (task 12): 0.8447
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 13: 100%|██████████| 11/11 [00:00<00:00, 64.10it/s]


Random router accuracy (task 13): 0.8939
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 14: 100%|██████████| 11/11 [00:00<00:00, 65.03it/s]


Random router accuracy (task 14): 0.5152
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 15: 100%|██████████| 11/11 [00:00<00:00, 60.94it/s]


Random router accuracy (task 15): 0.4470
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 16: 100%|██████████| 11/11 [00:00<00:00, 63.74it/s]


Random router accuracy (task 16): 0.5758
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}


Random router task 17: 100%|██████████| 11/11 [00:00<00:00, 69.33it/s]

Random router accuracy (task 17): 0.4621
{'total_samples': 264, 'samples_with_valid_tool': 264, 'fraction_with_valid_tool': 1.0}





In [12]:
criterion = imports.DySTANceLoss_v2(
    surrogate_type="logistic",
    lambda_entropy=0.1,
)

In [13]:
num_epochs = 100

for epoch in range(num_epochs):
    train_loss = train_one_epoch_all_tasks(
        model, train_loader, ctx_mgr, optimizer, criterion
    )

    val_metrics = evaluate_all_tasks(model, val_loader, ctx_mgr)

    print(
        f"[Epoch {epoch:02d}] "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Regret: {val_metrics['avg_regret']:.4f} | "
        f"Val Acc: {val_metrics['accuracy']:.4f}"
    )



Training: 100%|██████████| 34/34 [05:28<00:00,  9.66s/it]
Validation: 100%|██████████| 11/11 [00:11<00:00,  1.06s/it]


[Epoch 00] Train Loss: -143.3824 | Val Regret: 0.0158 | Val Acc: 0.3262


Training: 100%|██████████| 34/34 [05:31<00:00,  9.74s/it]
Validation: 100%|██████████| 11/11 [00:11<00:00,  1.00s/it]


[Epoch 01] Train Loss: -147.4650 | Val Regret: 0.0148 | Val Acc: 0.3287


Training: 100%|██████████| 34/34 [05:38<00:00,  9.95s/it]
Validation: 100%|██████████| 11/11 [00:11<00:00,  1.02s/it]


[Epoch 02] Train Loss: -143.4751 | Val Regret: 0.0045 | Val Acc: 0.3129


Training: 100%|██████████| 34/34 [05:38<00:00,  9.95s/it]
Validation: 100%|██████████| 11/11 [00:10<00:00,  1.00it/s]


[Epoch 03] Train Loss: -143.5812 | Val Regret: 0.0028 | Val Acc: 0.3089


Training:   9%|▉         | 3/34 [00:32<05:31, 10.68s/it]


KeyboardInterrupt: 

In [13]:
@torch.no_grad()
def evaluate_random_router_accuracy(
    loader,
    num_tasks,
    device="cpu",
    seed=0,
):
    """
    Computes random-router accuracy, directly comparable to DySTANce Val Acc.

    For each batch:
      - sample a task
      - randomly select a valid tool
      - check binary correctness

    Returns:
        dict with accuracy and sample count
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    correct_total = 0
    total_samples = 0

    for batch in tqdm(loader, desc="Random router eval"):
        gt_all = batch["gt"].to(device)              # [B, L]
        preds_all = batch["tool_preds"].to(device)   # [B, M, L]
        mask_all = batch["tool_mask"].to(device)     # [B, M, L]

        B = gt_all.size(0)

        task_idx = torch.randint(
            0, num_tasks, (1,), generator=rng, device=device
        ).item()

        gt = gt_all[:, task_idx]                     # [B]
        preds = preds_all[:, :, task_idx]            # [B, M]
        mask = mask_all[:, :, task_idx]              # [B, M]

        for i in range(B):
            valid_tools = torch.nonzero(mask[i], as_tuple=False).squeeze(-1)

            if valid_tools.numel() == 0:
                total_samples += 1
                continue

            j = valid_tools[
                torch.randint(
                    0, valid_tools.numel(), (1,),
                    generator=rng, device=device
                ).item()
            ]

            pred_label = (preds[i, j] >= 0.5).long()
            correct_total += (pred_label == gt[i]).item()
            total_samples += 1

    return {
        "accuracy": correct_total / max(1, total_samples),
        "num_samples": total_samples,
    }

res = evaluate_random_router_accuracy(
    val_loader,
    num_tasks=num_tasks,
    device=device,
    seed=0,
)
res

Random router eval: 100%|██████████| 17/17 [00:00<00:00, 93.57it/s]


{'accuracy': 0.75, 'num_samples': 264}

In [15]:
@torch.no_grad()
def evaluate_upper_ceiling_accuracy(
    loader,
    num_tasks,
    device="cpu",
):
    """
    Computes the upper-ceiling accuracy:
    sample is correct if ANY valid tool predicts correctly.

    Directly comparable to DySTANce Val Acc.
    """

    correct_total = 0
    total_samples = 0

    for batch in tqdm(loader, desc="Upper ceiling eval"):
        gt_all = batch["gt"].to(device)              # [B, L]
        preds_all = batch["tool_preds"].to(device)   # [B, M, L]
        mask_all = batch["tool_mask"].to(device)     # [B, M, L]

        B = gt_all.size(0)

        # same protocol: random task per batch
        task_idx = torch.randint(0, num_tasks, (1,), device=device).item()

        gt = gt_all[:, task_idx]                     # [B]
        preds = preds_all[:, :, task_idx]            # [B, M]
        mask = mask_all[:, :, task_idx]              # [B, M]

        pred_labels = (preds >= 0.5).long()          # [B, M]
        gt_exp = gt.unsqueeze(1).expand_as(pred_labels)

        correct = (pred_labels == gt_exp) & mask.bool()

        correct_total += correct.any(dim=1).sum().item()
        total_samples += B

    return {
        "accuracy": correct_total / max(1, total_samples),
        "num_samples": total_samples,
    }

res = evaluate_upper_ceiling_accuracy(
    val_loader,
    num_tasks=num_tasks,
    device=device,
)
res


Upper ceiling eval: 100%|██████████| 17/17 [00:00<00:00, 97.14it/s]


{'accuracy': 0.9621212121212122, 'num_samples': 264}

In [21]:
@torch.no_grad()
def evaluate_single_tool_accuracy(
    dataloader,
    task_idx: int,
    tool_idx: int,
    device="cpu",
):
    """
    Accuracy of a fixed tool for a given task.
    """
    correct = 0
    total = 0

    for batch in dataloader:
        gt = batch["gt"][:, task_idx].to(device)               # [B]
        preds = batch["tool_preds"][:, tool_idx, task_idx].to(device)
        mask = batch["tool_mask"][:, tool_idx, task_idx].to(device)

        # Only count samples where the tool is valid
        valid = mask.bool()
        if valid.sum() == 0:
            continue

        pred_labels = (preds[valid] >= 0.5).long()
        correct += (pred_labels == gt[valid]).sum().item()
        total += valid.sum().item()

    return correct / max(1, total)


In [22]:
def evaluate_best_single_tool_oracle(
    dataloader,
    task_idx: int,
    device="cpu",
):
    """
    Chooses the best tool based on validation performance itself.
    Oracle baseline.
    """
    M = dataloader.dataset.M
    accs = []

    for tool_idx in range(M):
        acc = evaluate_single_tool_accuracy(
            dataloader, task_idx, tool_idx, device
        )
        accs.append(acc)

    best_acc = max(accs)
    best_tool = accs.index(best_acc)

    return {
        "best_accuracy": best_acc,
        "best_tool_idx": best_tool,
        "all_tool_accuracies": accs,
    }


In [23]:
def select_best_single_tool_on_train(
    train_loader,
    task_idx: int,
    device="cpu",
):
    """
    Selects the best tool using TRAIN data only.
    """
    M = train_loader.dataset.dataset.M
    accs = []

    for tool_idx in range(M):
        acc = evaluate_single_tool_accuracy(
            train_loader, task_idx, tool_idx, device
        )
        accs.append(acc)

    best_tool = accs.index(max(accs))
    return best_tool, accs


In [24]:
def evaluate_best_single_tool_train_selected(
    train_loader,
    val_loader,
    task_idx: int,
    device="cpu",
):
    """
    Train-selected fixed-tool baseline.
    """
    best_tool, train_accs = select_best_single_tool_on_train(
        train_loader, task_idx, device
    )

    val_acc = evaluate_single_tool_accuracy(
        val_loader, task_idx, best_tool, device
    )

    return {
        "selected_tool": best_tool,
        "train_acc": train_accs[best_tool],
        "val_acc": val_acc,
        "all_train_accs": train_accs,
    }


In [27]:
task_idx = 2  # Pneumonia

oracle = evaluate_best_single_tool_oracle(
    val_loader, task_idx, device
)

train_sel = evaluate_best_single_tool_train_selected(
    train_loader, val_loader, task_idx, device
)

print("Oracle best single tool:")
print(f"  Val Acc: {oracle['best_accuracy']:.4f}")
print(f"  Tool idx: {oracle['best_tool_idx']}")

print("\nTrain-selected best single tool:")
print(f"  Train Acc: {train_sel['train_acc']:.4f}")
print(f"  Val Acc: {train_sel['val_acc']:.4f}")
print(f"  Tool idx: {train_sel['selected_tool']}")


Oracle best single tool:
  Val Acc: 1.0000
  Tool idx: 8

Train-selected best single tool:
  Train Acc: 0.9929
  Val Acc: 0.9924
  Tool idx: 7
