In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights

In [2]:
# NOTE: the instruction encoder is not strictly needed at this stage.
# For this project we assume the user supplies an image and a question, and that the request is
# always routed to the correct pool of tools (e.g. an image plus the word "segment" goes to the segmentation tools).

class InstructionEncoder(nn.Module):
    """
    Bag-of-tokens encoder for the textual instruction q.

    Every token id is looked up in a learned embedding table and the token
    vectors are averaged, giving one fixed-size vector per instruction.
    This is a placeholder: in large-scale settings it would be a frozen
    LLM / CLIP text encoder.
    """
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        # One embed_dim-sized vector per vocabulary entry.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    def forward(self, token_ids: torch.Tensor):
        """
        Args:
            token_ids: [B, T] integer tokens

        Returns:
            Tensor[B, embed_dim]: unweighted mean of the T token embeddings
            (every position counts equally; no padding mask is applied).
        """
        token_vectors = self.embedding(token_ids)   # [B, T, D]
        return torch.mean(token_vectors, dim=1)     # average out the token axis


In [3]:
class ANPToolEncoder(nn.Module):
    """
    Attentive Neural Process (ANP) encoder producing task-conditional tool descriptors.

    For every tool, a small context set D_E^t of (image feature, ground truth,
    tool prediction) triples is summarised into one vector z_E^t(p), where the
    summary weights depend on the query embedding u(p).
    """

    def __init__(
        self,
        img_dim: int,
        hidden_dim: int,
        n_heads: int = 4,
    ):
        super().__init__()

        # Per-element context encoder. Its input is the image feature with two
        # extra scalars appended: [phi_x(x_b) || y_b^t || m_E^t(x_b)].
        context_layers = [
            nn.Linear(img_dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.context_proj = nn.Sequential(*context_layers)

        # Context elements of one tool attend to each other (tool axis acts as batch).
        self.self_attn = nn.MultiheadAttention(
            hidden_dim, n_heads, batch_first=True
        )
        self.norm_ctx = nn.LayerNorm(hidden_dim)

        # Single-head cross-attention: queries come from u(p), keys from the raw
        # context image features, values from the self-attended context.
        self.W_Q = nn.Linear(hidden_dim, hidden_dim)
        self.W_K = nn.Linear(img_dim, hidden_dim)
        self.W_V = nn.Linear(hidden_dim, hidden_dim)

        # 1 / sqrt(H) scaling applied to the dot-product logits.
        self.scale = hidden_dim ** -0.5
        self.norm_out = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        query_embed: torch.Tensor,
        ctx_img_feat: torch.Tensor,
        ctx_gt: torch.Tensor,
        ctx_pred: torch.Tensor,
    ):
        """
        Args:
            query_embed : [B, H]              = u(p)
            ctx_img_feat: [M, C, Dx]          = phi_x(x_b)
            ctx_gt      : [M, C]              = y_b^t
            ctx_pred    : [M, C]              = m_E^t(x_b)

        Returns:
            z_E : [B, M, H]  task-conditional tool descriptors
        """
        n_tools, _, _ = ctx_img_feat.shape
        n_queries = query_embed.shape[0]

        # -- Step 1: embed each context element ---------------------------------
        gt_col = ctx_gt.unsqueeze(-1)        # [M, C, 1]
        pred_col = ctx_pred.unsqueeze(-1)    # [M, C, 1]
        ctx_tokens = self.context_proj(
            torch.cat((ctx_img_feat, gt_col, pred_col), dim=-1)  # [M, C, Dx+2]
        )                                                        # [M, C, H]

        # -- Step 2: self-attention within each tool's context, then residual + norm
        attended, _ = self.self_attn(ctx_tokens, ctx_tokens, ctx_tokens)
        ctx_tokens = self.norm_ctx(ctx_tokens + attended)

        # -- Step 3: every query attends over every tool's context --------------
        queries = self.W_Q(query_embed).unsqueeze(1).expand(n_queries, n_tools, -1)  # [B, M, H]
        keys = self.W_K(ctx_img_feat)     # [M, C, H]
        values = self.W_V(ctx_tokens)     # [M, C, H]

        # Scaled dot-product scores [B, M, C], normalised over the C context elements.
        scores = torch.einsum("bmh,mch->bmc", queries, keys) * self.scale
        weights = F.softmax(scores, dim=-1)

        # Attention-weighted sum of values -> [B, M, H]
        pooled = torch.einsum("bmc,mch->bmh", weights, values)
        return self.norm_out(pooled)


In [4]:
class DySTANceRouter(nn.Module):
    """
    Full DySTANce routing model.

    Given a query (image, instruction, task) and a panel of tools,
    outputs a scalar routing score for each tool.

    Pipeline:
        1. u(p): image features, instruction embedding and task embedding are
           concatenated and fused by a small MLP.
        2. z_E^t(p): an ANP encoder summarises each tool's context set,
           conditioned on u(p).
        3. The router head scores [u(p) || z_E || m_E^t(x)] per tool.
        4. Tools whose mask entry is 0 receive a score of -1e9.
    """

    def __init__(
        self,
        num_tasks: int,
        vocab_size: int,
        hidden_dim: int = 256,
        text_dim: int = 64,
        task_dim: int = 32,
    ):
        """
        Args:
            num_tasks : number of tasks (rows of the task-embedding table).
            vocab_size: vocabulary size of the instruction encoder.
            hidden_dim: width H of u(p), of the tool descriptors and of the router head.
            text_dim  : width of the instruction embedding (default 64, the
                        previously hard-coded value).
            task_dim  : width of the task embedding (default 32, the
                        previously hard-coded value).
        """
        super().__init__()

        # ------------------------------------------------------------
        # Image encoder phi_x
        # ------------------------------------------------------------
        resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Everything except the final classification layer (head).
        self.img_encoder = nn.Sequential(*list(resnet.children())[:-1])
        # Feature width is read from the dropped head instead of being
        # hard-coded; it is 512 for ResNet-18.
        self.img_dim = resnet.fc.in_features

        # ------------------------------------------------------------
        # Instruction encoder phi_q
        # ------------------------------------------------------------
        ### CHECK THIS! Possibly not needed at the moment.
        self.text_dim = text_dim
        self.text_encoder = InstructionEncoder(vocab_size, text_dim)

        # ------------------------------------------------------------
        # Task embedding
        # ------------------------------------------------------------
        self.task_dim = task_dim
        self.task_embed = nn.Embedding(num_tasks, task_dim)

        # ------------------------------------------------------------
        # Prompt fusion u(p)
        # ------------------------------------------------------------
        # Input width follows the three encoders above, so changing any of
        # them no longer requires editing this layer by hand.
        self.prompt_fusion = nn.Sequential(
            nn.Linear(self.img_dim + text_dim + task_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # ------------------------------------------------------------
        # ANP module psi_E^t(p)  (ANP = attentive neural process)
        # ------------------------------------------------------------
        self.anp = ANPToolEncoder(
            img_dim=self.img_dim,
            hidden_dim=hidden_dim,
        )

        # ------------------------------------------------------------
        # Router head g_theta
        # ------------------------------------------------------------
        # Input: [u(p) || z_E || m_E^t(x)]  ->  2*H + 1 features per tool
        self.router_head = nn.Sequential(
            nn.Linear(hidden_dim * 2 + 1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def extract_img_feat(self, images: torch.Tensor):
        """Run the image encoder and flatten everything after the batch axis."""
        return self.img_encoder(images).flatten(1)

    def forward(
        self,
        images: torch.Tensor,
        text_tokens: torch.Tensor,
        task_idx: torch.Tensor,
        tool_preds: torch.Tensor,
        ctx_img_feat: torch.Tensor,
        ctx_gt: torch.Tensor,
        ctx_pred: torch.Tensor,
        tool_mask: torch.Tensor,
    ):
        """
        Args:
            images     : [B, 3, H, W]
            text_tokens: [B, T]
            task_idx   : [B] (pathology)
            tool_preds : [B, M]          m_E^t(x)
            ctx_img_feat: [M, C, Dx]
            ctx_gt     : [M, C]
            ctx_pred   : [M, C]
            tool_mask  : [B, M]          1 if tool supports task

        Returns:
            scores : [B, M]; entries where tool_mask == 0 are set to -1e9
        """

        # ------------------------------------------------------------
        # Build u(p): fused embedding of what is being asked
        # (image, question, task/pathology)
        # ------------------------------------------------------------
        img_feat = self.extract_img_feat(images)
        txt_feat = self.text_encoder(text_tokens)
        task_feat = self.task_embed(task_idx)   # the best tool depends on the task (pathology/label)

        u_p = torch.cat([img_feat, txt_feat, task_feat], dim=-1)
        u_p = self.prompt_fusion(u_p)  # [B, H]

        # ------------------------------------------------------------
        # Tool descriptors z_E^t(p)
        #   query  : u(p)
        #   context: ctx_img_feat, ctx_gt, ctx_pred
        # ------------------------------------------------------------
        z_E = self.anp(u_p, ctx_img_feat, ctx_gt, ctx_pred)  # [B, M, H]

        # ------------------------------------------------------------
        # Router head
        # ------------------------------------------------------------
        u_exp = u_p.unsqueeze(1).expand(-1, z_E.size(1), -1)  # [B, M, H]
        tool_preds = tool_preds.unsqueeze(-1)                 # [B, M, 1]

        router_in = torch.cat([u_exp, z_E, tool_preds], dim=-1)
        scores = self.router_head(router_in).squeeze(-1)  # [B, M]

        # ------------------------------------------------------------
        # Hard mask invalid tools
        # ------------------------------------------------------------
        scores = scores.masked_fill(tool_mask == 0, -1e9)

        return scores


In [5]:
#### INVESTIGATE

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
from tqdm import tqdm
from pathlib import Path

# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Locate the repository root: it is the working directory when that holds a
# `data/` folder. In every other case we assume the notebook runs from a
# sub-folder (e.g. dev_notebooks) and go up one level.
cwd = Path.cwd()
REPO_ROOT = cwd if (cwd / 'data').exists() else cwd.parent

import imports

# ------------------------------------------------------------
# Paths
# All data locations are derived from REPO_ROOT/data/openi.
# ------------------------------------------------------------
DATA_ROOT = REPO_ROOT / "data" / "openi"
LABELS_DIR = DATA_ROOT / "labels"
IMAGES_DIR = DATA_ROOT / "image"
PRED_DIR   = DATA_ROOT / "predictions"

# ------------------------------------------------------------
# Labels (tasks)
# Each label is treated as one routing task; num_tasks is therefore
# the length of this list (18 entries).
# ------------------------------------------------------------
label_names = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "Lung Lesion", "Fracture", "Lung Opacity", "Enlarged Cardiomediastinum"
]
num_tasks = len(label_names)

# ------------------------------------------------------------
# Tool registry
# registry_all is indexed first by split ("train" / "val") and then by tool
# key, as the lookups below show.
# NOTE(review): the exact structure returned by imports.scan_prediction_files
# is not visible here - confirm against the helper.
# ------------------------------------------------------------
registry_all = imports.scan_prediction_files(str(PRED_DIR))

# Example split: train on non-resnet tools
# (keeps every train-split tool key that does not contain "resnet").
train_tools = [t for t in registry_all["train"] if "resnet" not in t]

# Both registries use the same tool keys; the val lookup raises KeyError if a
# training tool has no entry in the "val" split.
train_registry = {t: registry_all["train"][t] for t in train_tools}
val_registry   = {t: registry_all["val"][t]   for t in train_tools}

# ------------------------------------------------------------
# Datasets
# ------------------------------------------------------------
train_dataset_full = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Train.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=train_registry,
    label_names=label_names,
    transform=None,  # assumes the dataset converts images to tensors itself - TODO confirm
)

val_dataset = imports.OpenIRoutedDataset(
    label_csv=str(LABELS_DIR / "Valid.csv"),
    images_dir=str(IMAGES_DIR),
    predictions_registry=val_registry,
    label_names=label_names,
    transform=None,
)

# The context manager wraps the full training set.
# NOTE(review): presumably it holds out `context_fraction` of the data as
# context examples and serves `examples_per_tool` of them per tool - verify
# against imports.ContextManager.
ctx_mgr = imports.ContextManager(
    dataset=train_dataset_full,
    context_fraction=0.1,      # 10% context
    examples_per_tool=32,      # B_t
)

# Dataset actually used to train the router, as provided by the context manager.
train_dataset = ctx_mgr.routing_dataset()

# Training batches are shuffled and pinned; validation batches keep file order.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)


Using device: cuda


In [6]:
# Pull one batch to prototype with. An empty loader is reported here, instead
# of leaving `batch` undefined and failing with a NameError in a later cell.
batch = next(iter(train_loader), None)
if batch is None:
    raise RuntimeError("train_loader produced no batches; check the dataset and context split")

In [7]:
def build_context_tensors(ctx_mgr, task_idx, device, encoder=None):
    """
    Builds task-conditional context tensors for all tools.

    Args:
        ctx_mgr : object exposing `dataset.M` (number of tools),
                  `examples_per_tool` (context size C) and
                  `sample_context(tool_idx, task_idx)`, which returns either
                  None or a tuple `(images, gt, preds)`.
        task_idx: index of the task the context is sampled for.
        device  : device the returned tensors live on.
        encoder : object providing `img_dim` and `extract_img_feat(images)`.
                  When omitted, the notebook-global `model` is used, so
                  existing three-argument calls behave as before.

    Returns:
        ctx_img_feat : [M, C, Dx] on `device`
        ctx_gt       : [M, C]     on `device`
        ctx_pred     : [M, C]     on `device`

    Tools without any context for this task get all-zero rows.
    """
    if encoder is None:
        # Resolved at call time; passing `encoder` explicitly avoids relying on
        # a global that is only defined in a later cell.
        encoder = model

    n_tools = ctx_mgr.dataset.M              # number of tools
    n_examples = ctx_mgr.examples_per_tool   # context examples per tool

    ctx_img_feats = []
    ctx_gts = []
    ctx_preds = []

    for tool_idx in range(n_tools):
        ctx = ctx_mgr.sample_context(tool_idx, task_idx)

        if ctx is None:
            # No valid context for this tool-task pair: zero placeholders keep
            # the stacked tensors rectangular.
            ctx_img_feats.append(torch.zeros(n_examples, encoder.img_dim, device=device))
            ctx_gts.append(torch.zeros(n_examples, device=device))
            ctx_preds.append(torch.zeros(n_examples, device=device))
            continue

        imgs, gt, preds = ctx

        # Context features are inputs, not something to backpropagate through.
        # NOTE(review): no_grad does not stop BatchNorm statistics from updating
        # if the encoder is in train mode - confirm eval() is set by the caller.
        with torch.no_grad():
            feats = encoder.extract_img_feat(imgs.to(device))  # [C, Dx]

        ctx_img_feats.append(feats)
        ctx_gts.append(gt.to(device))
        ctx_preds.append(preds.to(device))

    return (
        torch.stack(ctx_img_feats, dim=0),  # [M, C, Dx]
        torch.stack(ctx_gts, dim=0),        # [M, C]
        torch.stack(ctx_preds, dim=0),      # [M, C]
    )

In [8]:
# Prototype on the CPU, overriding the device picked in the setup cell.
device = 'cpu'

# Router with a 1000-token vocabulary and hidden size 256, in inference mode.
model = DySTANceRouter(num_tasks, 1000, 256)
model.eval()

# Move the prototype batch to the working device.
#   images    : [B, 3, H, W]
#   gt_all    : [B, L]     (L = 18 labels here)
#   preds_all : [B, M, L]  (M = 14 tools here, see train_loader.dataset.dataset.tool_names)
#   mask_all  : [B, M, L]  (1 where the tool supports the task, see 01_dataloading.ipynb)
images, gt_all, preds_all, mask_all = (
    batch[key].to(device) for key in ("image", "gt", "tool_preds", "tool_mask")
)

B = images.shape[0]
print(B)

# ------------------------------------------------------------
# 1) Pick one task uniformly at random.
#    Training currently samples a single task per batch, which can be noisy.
#    Scoring all tasks at once and averaging (i.e. treating this as a binary
#    relevance problem) is probably better.
# ------------------------------------------------------------
task_idx = random.randint(0, num_tasks - 1)
print(f"task_idx: {task_idx}")
task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)
print(f"task_ids: {task_ids}")

# Slice out the chosen task along the label axis.
gt = gt_all.select(1, task_idx)             # [B]
tool_preds = preds_all.select(2, task_idx)  # [B, M]
tool_mask = mask_all.select(2, task_idx)    # [B, M]

# ------------------------------------------------------------
# 2) Build the context for this task (build_context_tensors, defined above).
#    Open concern: context images go through the same feature extractor as
#    the query image, but ground truth and predictions are NOT embedded with
#    nn.Embedding(); the raw 0/1 values are simply concatenated to the image
#    features. L2D-Pop does not do it this way.
# ------------------------------------------------------------
ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(ctx_mgr, task_idx, device)
print(f"ctx_img_feat.size(): {ctx_img_feat.size()} (should be [M, C, Dx])") # [M, C, Dx]
print(f"ctx_gt.size(): {ctx_gt.size()} (should be [M, C])") # [M, C]
print(f"ctx_pred.size(): {ctx_pred.size()} (should be [M, C])") # [M, C]

16
task_idx: 9
task_ids: tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
ctx_img_feat.size(): torch.Size([14, 32, 512]) (should be [M, C, Dx])
ctx_gt.size(): torch.Size([14, 32]) (should be [M, C])
ctx_pred.size(): torch.Size([14, 32]) (should be [M, C])


In [9]:
### FORWARD LOOP IN TRAINING
## Step 1: encode the query information (image, question, task/pathology).

# ------------------------------------------------------------
# Build u(p), the prompt fusion: an embedding of what is being asked
# (image, question, task/pathology).
# The "should be" shapes below are also asserted where the expected size is
# known from the model, so a mismatch stops the cell instead of only printing.
# ------------------------------------------------------------
img_feat = model.extract_img_feat(images)
assert img_feat.shape == (B, model.img_dim), "image features must be [B, Dx]"
print(f"img_feat.size(): {img_feat.size()} (should be [B, Dx])") # [B, Dx]

# No real instruction text yet: feed a single dummy token (id 0) per sample.
txt_feat = model.text_encoder(torch.zeros((B, 1), dtype=torch.long, device=device))
print(f"txt_feat.size(): {txt_feat.size()} (should be [B, Dq])") # [B, Dq]

task_feat = model.task_embed(task_ids)   # the best tool depends on the task (pathology/label)
print(f"task_feat.size(): {task_feat.size()} (should be [B, Dt])") # [B, Dt]

# Concatenate image, text and task features; this is the raw description of
# what is being asked of the system.
u_p = torch.cat([img_feat, txt_feat, task_feat], dim=-1)
assert u_p.shape[-1] == model.prompt_fusion[0].in_features, "u_p width must match prompt_fusion input"
print(f"u_p.size(): {u_p.size()} (should be [B, Dx+Dq+Dt])") # [B, Dx+Dq+Dt]

embedded_u_p = model.prompt_fusion(u_p)  # [B, H]
print(f"embedded_u_p.size(): {embedded_u_p.size()} (should be [B, H])") # [B, H]

img_feat.size(): torch.Size([16, 512]) (should be [B, Dx])
txt_feat.size(): torch.Size([16, 64]) (should be [B, Dq])
task_feat.size(): torch.Size([16, 32]) (should be [B, Dt])
u_p.size(): torch.Size([16, 608]) (should be [B, Dp+Dq+Dt])
embedded_u_p.size(): torch.Size([16, 256]) (should be [B, H])


![alt text](l2dpop_anp.png "Title")

In [10]:
### next bit is the ANP module -- model.anp = ANPToolEncoder

#    self.anp = ANPToolEncoder(
#             img_dim=self.img_dim,
#             hidden_dim=hidden_dim,
#         )

    # def forward(
    #     self,
    #     query_embed: torch.Tensor,
    #     ctx_img_feat: torch.Tensor,
    #     ctx_gt: torch.Tensor,
    #     ctx_pred: torch.Tensor,
    # ):

### ....
    #  z_E = self.anp(u_p, ctx_img_feat, ctx_gt, ctx_pred)  # [B, M, H]


# Manual, step-by-step replay of ANPToolEncoder.forward using the sub-modules
# of `model.anp`, printing the shape after every stage.
# M = number of tools, C = context examples per tool, Dx = image feature width,
# B = number of queries in the batch.
M, C, Dx = ctx_img_feat.shape
B = u_p.shape[0]
print(f"M: {M}, C: {C}, Dx: {Dx}, B: {B}")
# ------------------------------------------------------------------
# 1) Encode context elements
#    Each context element is its image feature with the ground-truth value
#    and the tool prediction appended as two extra scalar channels.
# ------------------------------------------------------------------
ctx_input = torch.cat(
    [
        ctx_img_feat, # [M, C, Dx]
        ctx_gt.unsqueeze(-1), # [M, C, 1]
        ctx_pred.unsqueeze(-1), # [M, C, 1]
    ],
    dim=-1
)  # [M, C, Dx+2]
print(f"ctx_input.size(): {ctx_input.size()} (should be [M, C, Dx+2])") # [M, C, Dx+2]

# Encodes individual context elements:
# [phi_x(x_b) || y_b^t || m_E^t(x_b)] -> hidden
ctx_emb = model.anp.context_proj(ctx_input)  # [M, C, H]
print(f"ctx_emb.size(): {ctx_emb.size()} (should be [M, C, H])") # [M, C, H]


# ------------------------------------------------------------------
# 2) Self-attention over context (per tool)
# Purpose: let the context points of one tool interact with each other before
# they are summarised, e.g. to capture
#  - which context points are representative of the expert's skills
#  - which ones are redundant, noisy, or contradictory
#  - whether there are clusters of competence (e.g. "this expert is good on cars but bad on animals")
# Without self-attention each context point is encoded independently of the
# others. With it, the embedding of a context point can change depending on
# which other context points are present.
#
# After self-attention a residual connection and nn.LayerNorm are applied:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
# LayerNorm normalises every context vector over its H features, keeping
# ctx_emb + ctx_emb_sa on a stable scale. It acts on each vector separately,
# so it does not mix information across context points.
# ------------------------------------------------------------------
ctx_emb_sa, _ = model.anp.self_attn(ctx_emb, ctx_emb, ctx_emb)
ctx_emb = model.anp.norm_ctx(ctx_emb + ctx_emb_sa)  # residual
# NOTE(review): the label printed below reads "cntx_emb_sa" although the variable is ctx_emb_sa.
print(f"cntx_emb_sa.size(): {ctx_emb_sa.size()} (should be [M, C, H])") # [M, C, H]
print(f"ctx_emb.size(): {ctx_emb.size()} (should be [M, C, H])") # [M, C, H]

# ------------------------------------------------------------------
# 3) Cross-attention: query u(p) attends to each tool's context
#    Queries come from the fused prompt, keys from the raw context image
#    features, values from the self-attended context embeddings.
# ------------------------------------------------------------------
Q =  model.anp.W_Q(embedded_u_p)                # [B, H]
Q = Q.unsqueeze(1).expand(-1, M, -1)     # [B, M, H] - the same query is reused for every tool

K = model.anp.W_K(ctx_img_feat)               # [M, C, H]
V = model.anp.W_V(ctx_emb)                    # [M, C, H]

# Attention scores: [B, M, C] (dot product over H, scaled by model.anp.scale)
attn_logits = torch.einsum(
    "bmh,mch->bmc", Q, K
) * model.anp.scale

# Normalise over the C context elements of each (query, tool) pair.
attn = F.softmax(attn_logits, dim=-1)

# Weighted sum of values -> [B, M, H]
z = torch.einsum("bmc,mch->bmh", attn, V)

# Final LayerNorm gives the tool descriptors used by the router head.
z_E = model.anp.norm_out(z)
print(f"z_E.size(): {z_E.size()} (should be [B, M, H])") # [B, M, H]


M: 14, C: 32, Dx: 512, B: 16
ctx_input.size(): torch.Size([14, 32, 514]) (should be [M, C, Dx+2])
ctx_emb.size(): torch.Size([14, 32, 256]) (should be [M, C, H])
cntx_emb_sa.size(): torch.Size([14, 32, 256]) (should be [M, C, H])
ctx_emb.size(): torch.Size([14, 32, 256]) (should be [M, C, H])
z_E.size(): torch.Size([16, 14, 256]) (should be [B, M, H])


In [11]:
# ------------------------------------------------------------
# Router head: the rejector can now be conditioned on the tool descriptors
# in addition to the query embedding and the tool's own prediction.
# ------------------------------------------------------------
u_exp = embedded_u_p.unsqueeze(1).expand(-1, z_E.size(1), -1)
# Add the trailing feature axis only once, so re-running this cell does not
# turn [B, M, 1] into [B, M, 1, 1] and break the concatenation below.
if tool_preds.dim() == 2:
    tool_preds = tool_preds.unsqueeze(-1)
# u_exp is the fused embedding repeated per tool, hence width H.
print(f"u_exp.size(): {u_exp.size()} (should be [B, M, H])") # [B, M, H]
print(f"tool_preds.size(): {tool_preds.size()} (should be [B, M, 1])") # [B, M, 1]

router_in = torch.cat([u_exp, z_E, tool_preds], dim=-1)  # [B, M, 2H+1]
scores = model.router_head(router_in).squeeze(-1)  # [B, M]
print(f"scores.size(): {scores.size()} (should be [B, M])") # [B, M]

u_exp.size(): torch.Size([16, 14, 256]) (should be [B, M, Dp+Dq+Dt])
tool_preds.size(): torch.Size([16, 14, 1]) (should be [B, M, 1])
scores.size(): torch.Size([16, 14]) (should be [B, M])


In [15]:
from typing import Tuple

class DySTANceLoss(nn.Module):
    """
    Population comp-sum surrogate loss for DySTANce routing.

    Per sample i, with routing distribution pi over the valid tools:

        L_i = sum_j w_ij * Psi(pi_ij)            (valid j only)
        w_ij = sum_{k != j} c_ik - m_i + 2       (c = min-centred costs, m_i = #valid tools)

    plus an optional entropy bonus -lambda_entropy * H(pi_i) / normaliser.

    Properties:
      - pi is explicitly renormalised over valid tools
      - entropy is computed over valid tools only, optionally normalised by log(panel_size)
      - degenerate panels (no valid tool) contribute 0 instead of NaN
      - costs are centred per sample by the best valid cost, so several
        near-optimal tools can share the mass
      - `validity_mask` may be float, integer or bool

    NOTE(review): w can be negative (e.g. three valid tools with equal costs
    give w = -1 for every tool), so the loss value itself can be negative.
    """

    _SURROGATES = ("logistic", "mae")

    def __init__(
        self,
        surrogate_type: str = "logistic",   # "logistic" or "mae"
        lambda_entropy: float = 0.05,
        entropy_normalize_by_log: bool = True,  # normalize entropy by log(m_eff)
        eps: float = 1e-8,
    ):
        """
        Args:
            surrogate_type: "logistic" (Psi = -log pi) or "mae" (Psi = 1 - pi).
            lambda_entropy: weight of the entropy bonus (0 disables it).
            entropy_normalize_by_log: divide the entropy by log(m_eff) when
                True, by m_eff otherwise.
            eps: small constant guarding logs and divisions.

        Raises:
            ValueError: if `surrogate_type` is not a known surrogate.
        """
        super().__init__()
        # Fail at construction time rather than on the first forward pass.
        if surrogate_type not in self._SURROGATES:
            raise ValueError(f"Unknown surrogate_type={surrogate_type}")
        self.surrogate_type = surrogate_type
        self.lambda_entropy = lambda_entropy
        self.eps = eps
        self.entropy_normalize_by_log = entropy_normalize_by_log

    def forward(
        self,
        router_logits: torch.Tensor,  # [B, M]
        tool_costs: torch.Tensor,     # [B, M] in [0,1] (lower better)
        validity_mask: torch.Tensor,  # [B, M] in {0,1}
    ) -> Tuple[torch.Tensor, dict]:
        """
        Returns:
            total_loss: scalar tensor
            info: dict of python floats for logging
                  (loss_main, loss_entropy, avg_panel_size, avg_min_cost)
        """
        # Work in the logits' dtype so bool / integer masks behave exactly like
        # float ones in the arithmetic below (`1.0 - mask` fails on bool tensors).
        mask = validity_mask.to(dtype=router_logits.dtype)
        invalid = mask == 0

        # ----------------------------
        # 1) Masked softmax -> pi over valid tools
        # ----------------------------
        masked_logits = router_logits.masked_fill(invalid, -1e9)
        pi = F.softmax(masked_logits, dim=1)  # [B, M]

        # Zero the invalid slots and renormalise to remove numerical leakage.
        # A row without any valid tool ends up all-zero (eps keeps the
        # division finite) and contributes nothing to the loss.
        pi = pi * mask
        pi = pi / (pi.sum(dim=1, keepdim=True) + self.eps)

        # ----------------------------
        # 2) Effective panel size per sample
        # ----------------------------
        m_eff = mask.sum(dim=1, keepdim=True)          # [B,1]
        m_eff_clamped = torch.clamp(m_eff, min=1.0)    # avoid 0 for empty panels

        # ----------------------------
        # 3) Cost centering (subtract per-sample min among active tools)
        # ----------------------------
        big = 1e9
        active_costs = tool_costs * mask                    # zeros at invalid positions
        costs_for_min = active_costs + (1.0 - mask) * big   # invalid -> +big so min ignores them
        min_cost, _ = torch.min(costs_for_min, dim=1, keepdim=True)  # [B,1]
        # Rows with no active tool would report `big`; use 0 there instead.
        min_cost = torch.where(min_cost > big / 2.0, torch.zeros_like(min_cost), min_cost)

        # Re-apply the mask after subtracting: otherwise invalid slots hold
        # -min_cost and leak into the row sum used for the weights.
        centered_costs = (active_costs - min_cost) * mask   # >= 0 on valid slots, 0 elsewhere

        # ----------------------------
        # 4) Comp-sum weights w_j = sum_{k!=j} c_k - m + 2
        # ----------------------------
        sum_centered = centered_costs.sum(dim=1, keepdim=True)  # [B,1]
        w = (sum_centered - centered_costs) - m_eff_clamped + 2.0  # [B,M]
        # w at invalid entries is masked out in step 6.

        # ----------------------------
        # 5) Surrogate Psi(pi)
        # ----------------------------
        if self.surrogate_type == "logistic":
            psi = -torch.log(pi + self.eps)
        elif self.surrogate_type == "mae":
            psi = 1.0 - pi
        else:
            # Reachable only if the attribute was changed after construction.
            raise ValueError(f"Unknown surrogate_type={self.surrogate_type}")

        # ----------------------------
        # 6) Aggregate comp-sum loss over valid tools, then average the batch
        # ----------------------------
        loss_per_sample = (w * psi * mask).sum(dim=1)  # [B]
        loss_main = loss_per_sample.mean()

        # ----------------------------
        # 7) Entropy regularization (valid tools only)
        # ----------------------------
        log_pi = torch.log(pi + self.eps)
        entropy_per_row = -(pi * log_pi * mask).sum(dim=1, keepdim=True)  # [B,1]

        if self.entropy_normalize_by_log:
            # log(m_eff) maps the entropy to roughly [0, 1] for panels of size >= 2.
            # Panels of size <= 1 have log(m_eff) == 0, so divide by 1 there.
            # Selecting on the panel size avoids comparing a float log to 0.
            denom = torch.where(
                m_eff_clamped > 1.0,
                torch.log(m_eff_clamped),
                torch.ones_like(m_eff_clamped),
            )
        else:
            # Simple division by the panel size (already clamped to >= 1).
            denom = m_eff_clamped
        entropy_norm = entropy_per_row / denom

        # Negative sign: higher entropy lowers the loss.
        loss_entropy = -self.lambda_entropy * entropy_norm.mean()

        total_loss = loss_main + loss_entropy

        info = {
            "loss_main": float(loss_main.item()),
            "loss_entropy": float(loss_entropy.item()),
            "avg_panel_size": float(m_eff.mean().item()),
            "avg_min_cost": float(min_cost.mean().item()),
        }

        return total_loss, info

In [9]:
##### TESTS

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def masked_softmax(logits, mask):
    """Softmax over the last axis with positions where `mask == 0` pushed to -1e9 first."""
    blocked = mask == 0
    filled = logits.masked_fill(blocked, -1e9)
    return torch.softmax(filled, dim=-1)

def run_loss(loss_fn, logits, costs, mask):
    """
    Evaluate `loss_fn` once and backpropagate to the logits.

    Works on a copy of `logits` with gradients enabled, so the caller's
    tensor is left untouched.

    Returns:
        (loss value as a Python float, gradient w.r.t. the logits, the logs
        object returned by `loss_fn`)
    """
    leaf = logits.clone()
    leaf.requires_grad_(True)

    value, logs = loss_fn(leaf, costs, mask)
    value.backward()

    return value.item(), leaf.grad.detach(), logs

In [11]:
def scenario_single_best():
    """Three valid tools with uniform logits; only tool 0 is perfect (cost 0)."""
    costs = torch.tensor([[0.0, 1.0, 1.0]])
    return torch.zeros(1, 3), costs, torch.ones(1, 3)

def scenario_two_good():
    """Three valid tools with uniform logits; tools 0 and 1 are both perfect."""
    costs = torch.tensor([[0.0, 0.0, 1.0]])
    return torch.zeros(1, 3), costs, torch.ones(1, 3)

def scenario_all_bad():
    """Three valid tools with uniform logits; every tool has the worst cost (1)."""
    costs = torch.ones(1, 3)
    return torch.zeros(1, 3), costs, torch.ones(1, 3)

def scenario_with_mask():
    """Same costs as the single-best case, but tool 1 is masked out as invalid."""
    costs = torch.tensor([[0.0, 1.0, 1.0]])
    validity = torch.tensor([[1.0, 0.0, 1.0]])
    return torch.zeros(1, 3), costs, validity


In [12]:
# Loss under test: logistic surrogate, with the entropy term switched off
# (weight 0.0) so the comp-sum part can be inspected on its own.
loss_fn = DySTANceLoss(surrogate_type="logistic", lambda_entropy=0.0)


In [13]:
def test_scenario(name, scenario_fn):
    """
    Run one hand-built scenario through the global `loss_fn` and print a report.

    `scenario_fn` must return `(logits, costs, mask)`. The report lists the
    costs, the masked-softmax routing probabilities, the loss value and the
    gradient of the loss with respect to the logits.
    """
    logits, costs, mask = scenario_fn()
    loss, grad, _logs = run_loss(loss_fn, logits, costs, mask)
    probs = masked_softmax(logits, mask).detach().numpy()

    print(f"\n=== {name} ===")
    report = (
        ("Costs:", costs.numpy()),
        ("Probs:", probs),
        ("Loss:", loss),
        ("Gradients:", grad.numpy()),
    )
    for label, value in report:
        print(label, value)


In [17]:
# Run every hand-built scenario, in this order, and print its report.
for title, builder in (
    ("Single best tool", scenario_single_best),
    ("Two equally good tools", scenario_two_good),
    ("All tools bad", scenario_all_bad),
    ("Masked tool", scenario_with_mask),
):
    test_scenario(title, builder)

## pass, pass, pass, pass


=== Single best tool ===
Costs: [[0. 1. 1.]]
Probs: [[0.33333334 0.33333334 0.33333334]]
Loss: 1.0986120700836182
Gradients: [[-0.66666645  0.33333325  0.33333325]]

=== Two equally good tools ===
Costs: [[0. 0. 1.]]
Probs: [[0.33333334 0.33333334 0.33333334]]
Loss: -1.0986120700836182
Gradients: [[-0.33333325 -0.33333325  0.66666645]]

=== All tools bad ===
Costs: [[1. 1. 1.]]
Probs: [[0.33333334 0.33333334 0.33333334]]
Loss: -3.2958362102508545
Gradients: [[0. 0. 0.]]

=== Masked tool ===
Costs: [[0. 1. 1.]]
Probs: [[0.5 0.  0.5]]
Loss: 0.6931469440460205
Gradients: [[-0.49999988  0.          0.49999988]]
