In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights


In [3]:
class InstructionEncoder(nn.Module):
    """
    Encodes the textual instruction q.
    For now: simple embedding + mean pooling.
    In large-scale settings this would be a frozen LLM / CLIP text encoder.
    """
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, token_ids: torch.Tensor):
        """
        Args:
            token_ids: [B, T] integer tokens

        Returns:
            Tensor[B, embed_dim]
        """
        emb = self.embedding(token_ids)          # [B, T, D]
        return emb.mean(dim=1)                   # mean pool over tokens


In [6]:
class ANPToolEncoder(nn.Module):
    """
    Attentive Neural Process encoder for task-conditional tool descriptors.

    Builds a representation z_E^t(p) from a small context set D_E^t.
    """

    def __init__(
        self,
        img_dim: int,
        hidden_dim: int,
        n_heads: int = 4,
    ):
        super().__init__()

        # Encodes individual context elements:
        # [phi_x(x_b) || y_b^t || m_E^t(x_b)] -> hidden
        self.context_proj = nn.Sequential(
            nn.Linear(img_dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Self-attention over context elements
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=n_heads,
            batch_first=True
        )
        self.norm_ctx = nn.LayerNorm(hidden_dim)

        # Cross-attention projections
        self.W_Q = nn.Linear(hidden_dim, hidden_dim)
        self.W_K = nn.Linear(img_dim, hidden_dim)
        self.W_V = nn.Linear(hidden_dim, hidden_dim)

        self.scale = hidden_dim ** -0.5
        self.norm_out = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        query_embed: torch.Tensor,
        ctx_img_feat: torch.Tensor,
        ctx_gt: torch.Tensor,
        ctx_pred: torch.Tensor,
    ):
        """
        Args:
            query_embed : [B, H]              = u(p)
            ctx_img_feat: [M, C, Dx]          = phi_x(x_b)
            ctx_gt      : [M, C]              = y_b^t
            ctx_pred    : [M, C]              = m_E^t(x_b)

        Returns:
            z_E : [B, M, H]  task-conditional tool descriptors
        """
        M, C, Dx = ctx_img_feat.shape
        B = query_embed.shape[0]

        # ------------------------------------------------------------------
        # 1) Encode context elements
        # ------------------------------------------------------------------
        ctx_input = torch.cat(
            [
                ctx_img_feat,
                ctx_gt.unsqueeze(-1),
                ctx_pred.unsqueeze(-1),
            ],
            dim=-1
        )  # [M, C, Dx+2]

        ctx_emb = self.context_proj(ctx_input)  # [M, C, H]

        # ------------------------------------------------------------------
        # 2) Self-attention over context (per tool)
        # ------------------------------------------------------------------
        ctx_emb_sa, _ = self.self_attn(ctx_emb, ctx_emb, ctx_emb)
        ctx_emb = self.norm_ctx(ctx_emb + ctx_emb_sa)  # residual

        # ------------------------------------------------------------------
        # 3) Cross-attention: query u(p) attends to each toolâ€™s context
        # ------------------------------------------------------------------
        Q = self.W_Q(query_embed)                # [B, H]
        Q = Q.unsqueeze(1).expand(-1, M, -1)     # [B, M, H]

        K = self.W_K(ctx_img_feat)               # [M, C, H]
        V = self.W_V(ctx_emb)                    # [M, C, H]

        # Attention scores: [B, M, C]
        attn_logits = torch.einsum(
            "bmh,mch->bmc", Q, K
        ) * self.scale

        attn = F.softmax(attn_logits, dim=-1)

        # Weighted sum of values -> [B, M, H]
        z = torch.einsum("bmc,mch->bmh", attn, V)

        return self.norm_out(z)


In [7]:
class DySTANceRouter(nn.Module):
    """
    Full DySTANce routing model.

    Given a query (image, instruction, task) and a panel of tools,
    outputs a scalar routing score for each tool.
    """

    def __init__(
        self,
        num_tasks: int,
        vocab_size: int,
        hidden_dim: int = 256,
    ):
        super().__init__()

        # ------------------------------------------------------------
        # Image encoder phi_x
        # ------------------------------------------------------------
        resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.img_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.img_dim = 512

        # ------------------------------------------------------------
        # Instruction encoder phi_q
        # ------------------------------------------------------------
        self.text_encoder = InstructionEncoder(vocab_size, 64)

        # ------------------------------------------------------------
        # Task embedding
        # ------------------------------------------------------------
        self.task_embed = nn.Embedding(num_tasks, 32)

        # ------------------------------------------------------------
        # Prompt fusion u(p)
        # ------------------------------------------------------------
        self.prompt_fusion = nn.Sequential(
            nn.Linear(self.img_dim + 64 + 32, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # ------------------------------------------------------------
        # ANP module psi_E^t(p)
        # ------------------------------------------------------------
        self.anp = ANPToolEncoder(
            img_dim=self.img_dim,
            hidden_dim=hidden_dim,
        )

        # ------------------------------------------------------------
        # Router head g_theta
        # ------------------------------------------------------------
        # Input: [u(p) || z_E || m_E^t(x)]
        self.router_head = nn.Sequential(
            nn.Linear(hidden_dim * 2 + 1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def extract_img_feat(self, images: torch.Tensor):
        return self.img_encoder(images).flatten(1)

    def forward(
        self,
        images: torch.Tensor,
        text_tokens: torch.Tensor,
        task_idx: torch.Tensor,
        tool_preds: torch.Tensor,
        ctx_img_feat: torch.Tensor,
        ctx_gt: torch.Tensor,
        ctx_pred: torch.Tensor,
        tool_mask: torch.Tensor,
    ):
        """
        Args:
            images     : [B, 3, H, W]
            text_tokens: [B, T]
            task_idx   : [B]
            tool_preds : [B, M]          m_E^t(x)
            ctx_img_feat: [M, C, Dx]
            ctx_gt     : [M, C]
            ctx_pred   : [M, C]
            tool_mask  : [B, M]          1 if tool supports task

        Returns:
            scores : [B, M]
        """

        # ------------------------------------------------------------
        # Build u(p)
        # ------------------------------------------------------------
        img_feat = self.extract_img_feat(images)
        txt_feat = self.text_encoder(text_tokens)
        task_feat = self.task_embed(task_idx)

        u_p = torch.cat([img_feat, txt_feat, task_feat], dim=-1)
        u_p = self.prompt_fusion(u_p)  # [B, H]

        # ------------------------------------------------------------
        # Tool descriptors z_E^t(p)
        # ------------------------------------------------------------
        z_E = self.anp(u_p, ctx_img_feat, ctx_gt, ctx_pred)  # [B, M, H]

        # ------------------------------------------------------------
        # Router head
        # ------------------------------------------------------------
        u_exp = u_p.unsqueeze(1).expand(-1, z_E.size(1), -1)
        tool_preds = tool_preds.unsqueeze(-1)

        router_in = torch.cat([u_exp, z_E, tool_preds], dim=-1)
        scores = self.router_head(router_in).squeeze(-1)  # [B, M]

        # ------------------------------------------------------------
        # Hard mask invalid tools
        # ------------------------------------------------------------
        scores = scores.masked_fill(tool_mask == 0, -1e9)

        return scores


In [8]:
class DySTANceLoss(nn.Module):
    """
    Population comp-sum surrogate loss for DySTANce routing.

    This loss:
    - supports soft costs in [0,1]
    - allows multiple near-optimal tools
    - handles variable panel sizes
    - is compatible with the theory in the paper
    """

    def __init__(
        self,
        surrogate_type: str = "logistic",
        lambda_entropy: float = 0.05,
        eps: float = 1e-7,
    ):
        super().__init__()
        self.surrogate_type = surrogate_type
        self.lambda_entropy = lambda_entropy
        self.eps = eps

    def forward(
        self,
        router_logits: torch.Tensor,  # [B, M]
        tool_costs: torch.Tensor,     # [B, M] in [0,1]
        validity_mask: torch.Tensor,  # [B, M] in {0,1}
    ):
        B, M = router_logits.shape

        # ------------------------------------------------------------
        # 1) Masked softmax over valid tools
        # ------------------------------------------------------------
        masked_logits = router_logits.masked_fill(validity_mask == 0, -1e9)
        pi = F.softmax(masked_logits, dim=1)  # [B, M]

        # ------------------------------------------------------------
        # 2) Effective panel size per sample
        # ------------------------------------------------------------
        m_eff = validity_mask.sum(dim=1, keepdim=True).clamp(min=1.0)

        # ------------------------------------------------------------
        # 3) Cost centering (KEY FIX)
        # ------------------------------------------------------------
        # This allows multiple tools to be "correct"
        # and stabilizes comp-sum weights for soft costs.
        active_costs = tool_costs * validity_mask
        min_cost, _ = torch.min(
            active_costs + (1 - validity_mask) * 1e9,
            dim=1,
            keepdim=True,
        )
        centered_costs = active_costs - min_cost  # >= 0

        # ------------------------------------------------------------
        # 4) Comp-sum weights
        # w_j = sum_{k!=j} c_k - m + 2
        # ------------------------------------------------------------
        sum_costs = centered_costs.sum(dim=1, keepdim=True)
        w = (sum_costs - centered_costs) - m_eff + 2.0

        # ------------------------------------------------------------
        # 5) Surrogate Psi(pi)
        # ------------------------------------------------------------
        if self.surrogate_type == "logistic":
            psi = -torch.log(pi + self.eps)
        elif self.surrogate_type == "mae":
            psi = 1.0 - pi
        else:
            raise ValueError(f"Unknown surrogate: {self.surrogate_type}")

        # ------------------------------------------------------------
        # 6) Aggregate comp-sum loss
        # ------------------------------------------------------------
        loss_per_sample = (w * psi * validity_mask).sum(dim=1)
        loss_main = loss_per_sample.mean()

        # ------------------------------------------------------------
        # 7) Entropy regularization (panel-normalized)
        # ------------------------------------------------------------
        log_pi = torch.log(pi + self.eps)
        entropy = -(pi * log_pi).sum(dim=1)
        loss_entropy = -self.lambda_entropy * entropy.mean()

        total_loss = loss_main + loss_entropy

        return total_loss, {
            "loss_main": loss_main.item(),
            "loss_entropy": loss_entropy.item(),
            "avg_panel_size": m_eff.mean().item(),
            "avg_min_cost": min_cost.mean().item(),
        }


In [9]:
##### TESTS

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def masked_softmax(logits, mask):
    logits = logits.masked_fill(mask == 0, -1e9)
    return F.softmax(logits, dim=-1)

def run_loss(
    loss_fn,
    logits,
    costs,
    mask,
):
    logits = logits.clone().requires_grad_(True)
    loss, logs = loss_fn(logits, costs, mask)
    loss.backward()
    return loss.item(), logits.grad.detach(), logs

In [11]:
def scenario_single_best():
    logits = torch.tensor([[0.0, 0.0, 0.0]])
    costs  = torch.tensor([[0.0, 1.0, 1.0]])  # tool 0 perfect
    mask   = torch.tensor([[1.0, 1.0, 1.0]])
    return logits, costs, mask

def scenario_two_good():
    logits = torch.tensor([[0.0, 0.0, 0.0]])
    costs  = torch.tensor([[0.0, 0.0, 1.0]])
    mask   = torch.tensor([[1.0, 1.0, 1.0]])
    return logits, costs, mask

def scenario_all_bad():
    logits = torch.tensor([[0.0, 0.0, 0.0]])
    costs  = torch.tensor([[1.0, 1.0, 1.0]])
    mask   = torch.tensor([[1.0, 1.0, 1.0]])
    return logits, costs, mask

def scenario_with_mask():
    logits = torch.tensor([[0.0, 0.0, 0.0]])
    costs  = torch.tensor([[0.0, 1.0, 1.0]])
    mask   = torch.tensor([[1.0, 0.0, 1.0]])
    return logits, costs, mask


In [12]:
loss_fn = DySTANceLoss(
    surrogate_type="logistic",
    lambda_entropy=0.0,  # turn OFF entropy for clarity
)


In [13]:
def test_scenario(name, scenario_fn):
    logits, costs, mask = scenario_fn()
    loss, grad, logs = run_loss(loss_fn, logits, costs, mask)

    pi = masked_softmax(logits, mask)

    print(f"\n=== {name} ===")
    print("Costs:", costs.numpy())
    print("Probs:", pi.detach().numpy())
    print("Loss:", loss)
    print("Gradients:", grad.numpy())


In [17]:
test_scenario("Single best tool", scenario_single_best)
test_scenario("Two equally good tools", scenario_two_good)
test_scenario("All tools bad", scenario_all_bad)
test_scenario("Masked tool", scenario_with_mask)

## pass, pass, pass, pass


=== Single best tool ===
Costs: [[0. 1. 1.]]
Probs: [[0.33333334 0.33333334 0.33333334]]
Loss: 1.0986120700836182
Gradients: [[-0.66666645  0.33333325  0.33333325]]

=== Two equally good tools ===
Costs: [[0. 0. 1.]]
Probs: [[0.33333334 0.33333334 0.33333334]]
Loss: -1.0986120700836182
Gradients: [[-0.33333325 -0.33333325  0.66666645]]

=== All tools bad ===
Costs: [[1. 1. 1.]]
Probs: [[0.33333334 0.33333334 0.33333334]]
Loss: -3.2958362102508545
Gradients: [[0. 0. 0.]]

=== Masked tool ===
Costs: [[0. 1. 1.]]
Probs: [[0.5 0.  0.5]]
Loss: 0.6931469440460205
Gradients: [[-0.49999988  0.          0.49999988]]
