In [1]:
import os, csv, re
import numpy as np
import torch
from typing import Dict, List, Optional
from pathlib import Path

# Get repository root directory
# Try multiple methods to find the repo root
cwd = Path.cwd()
if (cwd / 'data').exists():
    REPO_ROOT = cwd
elif (cwd.parent / 'data').exists():
    REPO_ROOT = cwd.parent
else:
    # Fallback: assume we're in dev_notebooks and go up one level
    REPO_ROOT = cwd.parent


In [2]:
def infer_split_from_filename(fname: str) -> Optional[str]:
    """
    Helper to figure out the split from the filename
    """
    if fname.endswith('test.csv'):
        return 'test'
    elif fname.endswith('val.csv'):
        return 'val'
    elif fname.endswith('train.csv'):
        return 'train'
    else:
        import warnings
        warnings.warn(f"Could not infer split from filename: {fname}", UserWarning)
        return None

In [3]:
infer_split_from_filename('dada')



In [None]:
def scan_prediction_files(pred_dir: str) -> Dict[str, Dict[str, str]]:
    """
    Helper to scan prediction files, infer the split from the filename, and return 
    registry of the form:   
        registry[split][tool_name] = csv_path
    """
    registry = {"train": {}, "val": {}, "test": {}}
    for fn in os.listdir(pred_dir):
        if not fn.endswith(".csv"):
            continue
        split = infer_split_from_filename(fn)
        if split is None:
            continue
        stem = os.path.splitext(fn)[0]
        tool = re.sub(rf"(_)?{split}(_)?$", "", stem, flags=re.IGNORECASE).strip("_")
        registry[split][tool] = os.path.join(pred_dir, fn)
    return registry

reg = scan_prediction_files(str(REPO_ROOT / 'data' / 'openi' / 'predictions'))
reg

{'train': {'resnet_mgca_pt_openi': '/home/kell6630/repos/DySTANce/data/openi/predictions/resnet_mgca_pt_openi_train.csv',
  'densenet121_res224_chex': '/home/kell6630/repos/DySTANce/data/openi/predictions/densenet121_res224_chex_train.csv',
  'densenet121_res224_all': '/home/kell6630/repos/DySTANce/data/openi/predictions/densenet121_res224_all_train.csv',
  'densenet_medical_mae_pt_openi': '/home/kell6630/repos/DySTANce/data/openi/predictions/densenet_medical_mae_pt_openi_train.csv',
  'densenet_mocov2_pt_openi': '/home/kell6630/repos/DySTANce/data/openi/predictions/densenet_mocov2_pt_openi_train.csv',
  'densenet121_res224_mimic_nb': '/home/kell6630/repos/DySTANce/data/openi/predictions/densenet121_res224_mimic_nb_train.csv',
  'densenet121_res224_nih': '/home/kell6630/repos/DySTANce/data/openi/predictions/densenet121_res224_nih_train.csv',
  'resnet_biovil_pt_openi': '/home/kell6630/repos/DySTANce/data/openi/predictions/resnet_biovil_pt_openi_train.csv',
  'densenet121_res224_mimic_c

In [None]:
import pandas as pd
def read_predictions_csv(
    csv_path: str,
    label_names: List[str],
    id_candidates=("id", "filename"),
) -> Dict[str, np.ndarray]:
    """
    Takes in a csv path and a list of label names, and returns a dict of image_id -> [L] float array
    This is for the multi-label task, where we treat it as l independent binary classifiers. 
    Note: Missing or unsupported labels are filled with 0.5!

    Returns dict: image_id -> [L] float array
    """
    out = {}
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        header_lc = [h.lower().strip() for h in header]

        id_idx = None
        for cand in id_candidates:
            if cand in header_lc:
                id_idx = header_lc.index(cand)
                break
        if id_idx is None:
            return out

        label_idx = []
        for l in label_names:
            label_idx.append(header_lc.index(l.lower()) if l.lower() in header_lc else None)

        for row in reader:
            img_id = row[id_idx].replace(".jpg", "").strip()
            vec = []
            for j in label_idx:
                if j is None:
                    vec.append(0.5)
                else:
                    try:
                        vec.append(float(row[j]))
                    except Exception:
                        vec.append(0.5)
            out[img_id] = np.asarray(vec, dtype=np.float32)
    return out
df = pd.read_csv(REPO_ROOT / 'data' / 'openi' / 'predictions' / 'densenet_mocov2_pt_openi_train.csv')
labels = df.drop(columns=['filename']).columns.tolist()
out = read_predictions_csv(str(REPO_ROOT / 'data' / 'openi' / 'predictions' / 'densenet_mocov2_pt_openi_train.csv'), labels)

In [None]:
from torch.utils.data import Dataset
from PIL import Image

def _fallback_to_tensor(img: Image.Image) -> torch.Tensor:
    arr = np.asarray(img, dtype=np.float32) / 255.0
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    return torch.from_numpy(arr).permute(2, 0, 1)

class OpenIRoutedDataset(Dataset):
    """
    Dataset for the OpenI dataset

    Returns per-sample:
      image        : Tensor[C,H,W]
      gt           : Tensor[L]
      tool_preds   : Tensor[M, L]
      tool_mask    : Tensor[M, L]  (1 = tool valid for task, 0 invalide)
      id           : str
    """

    def __init__(
        self,
        label_csv: str,
        images_dir: str,
        predictions_registry: Dict[str, str],
        label_names: List[str],
        transform=None,
        check_files=False,
    ):
        self.images_dir = images_dir
        self.transform = transform
        self.label_names = label_names
        self.L = len(label_names)

        # --- Load labels ---
        self.records = []
        with open(label_csv, newline="") as f:
            reader = csv.reader(f)
            header = next(reader)
            hmap = {h.strip(): i for i, h in enumerate(header)}
            for row in reader:
                img_id = row[0].strip()
                path = os.path.join(images_dir, f"{img_id}.jpg")
                if check_files and not os.path.exists(path):
                    continue
                gt = [float(row[hmap[l]]) for l in label_names]
                self.records.append({
                    "id": img_id,
                    "path": path,
                    "gt": torch.tensor(gt, dtype=torch.float32)
                })

        # --- Load tool predictions ---
        self.tool_names = sorted(predictions_registry.keys())
        self.M = len(self.tool_names)

        self.tool_preds = []
        for tool in self.tool_names:
            self.tool_preds.append(
                read_predictions_csv(predictions_registry[tool], label_names)
            )

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]

        img = Image.open(rec["path"]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        else:
            img = _fallback_to_tensor(img)


        preds = torch.full((self.M, self.L), 0.5)
        mask  = torch.zeros((self.M, self.L))

        for m, tool_dict in enumerate(self.tool_preds):
            if rec["id"] in tool_dict:
                p = torch.from_numpy(tool_dict[rec["id"]])
                preds[m] = p
                mask[m] = (torch.abs(p - 0.5) > 1e-4).float()

        return {
            "image": img,
            "gt": rec["gt"],
            "tool_preds": preds,
            "tool_mask": mask,
            "id": rec["id"]
        }


tr_demo_dataset = OpenIRoutedDataset(
    label_csv=str(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Train.csv'),
    images_dir=str(REPO_ROOT / 'data' / 'openi' / 'image'),
    predictions_registry=reg['train'],
    label_names=labels,
)


In [12]:
len(tr_demo_dataset)

935

In [13]:
from torch.utils.data import Subset
import random

class ContextManager:
    """
    Manages the few-shot context sets used to describe tools in DySTANce.

    Key idea (from the paper):
    ------------------------
    Tools are NOT identified by IDs or learned embeddings.
    Instead, each tool E is represented only through its behaviour
    on a small, task-specific context set:

        D_E^t = {(x_b, y_b^t, m_E^t(x_b))}_{b=1}^{B_t}

    This class is responsible for:
      1) Constructing these context sets in a leakage-free way
      2) Ensuring context examples are task- and tool-valid
      3) Enforcing a strict separation between:
           - data used to DESCRIBE tools (context)
           - data used to TRAIN the router (routing set)

    """

    def __init__(
        self,
        dataset: OpenIRoutedDataset,
        context_fraction: float = 0.1,
        examples_per_tool: int = 32,
    ):
        """
        Parameters
        ----------
        dataset : OpenIRoutedDataset
            Full TRAINING dataset containing images, ground-truth labels,
            tool predictions, and tool validity masks.

        context_fraction : float
            Fraction of the training data reserved EXCLUSIVELY for
            tool description (context). These samples are never used
            for routing loss computation.

        examples_per_tool : int
            Number of context examples B_t to sample per (tool, task)
            when constructing the ANP summary.
        """

        self.dataset = dataset
        self.examples_per_tool = examples_per_tool

        # ------------------------------------------------------------------
        # 1) Split dataset indices into CONTEXT and ROUTING partitions
        # ------------------------------------------------------------------
        # This enforces the core invariance:
        #   "An image used to describe a tool is never used to train the router."
        #
        # This prevents information leakage and ensures the ANP summaries
        # remain exogenous to the routing objective.
        # ------------------------------------------------------------------
        N = len(dataset)
        perm = torch.randperm(N).tolist()  # random i.i.d. partition
        split = int(context_fraction * N)

        self.context_idx = perm[:split]    # used ONLY for tool descriptors
        self.routing_idx = perm[split:]    # used ONLY for router training

        # ------------------------------------------------------------------
        # 2) Pre-index valid context examples
        # ------------------------------------------------------------------
        # We build a lookup table:
        #
        #   (tool_idx, task_idx) -> [dataset indices]
        #
        # Only examples where:
        #   - the tool actually produced a meaningful prediction
        #   - for the specific task (label)
        #
        # are included.
        #
        # This is critical because many tools emit "0.5" for unsupported
        # labels, which must NOT contaminate the context set.
        # ------------------------------------------------------------------
        self.pool = {}  # maps (tool_idx, task_idx) to list of dataset indices

        for i in self.context_idx:
            item = dataset[i]
            mask = item["tool_mask"]  # shape: [num_tools, num_tasks]

            # Iterate over all (tool, task) pairs and record valid contexts
            for t in range(dataset.M):
                for l in range(dataset.L):
                    if mask[t, l] > 0.5:
                        # This example is informative for tool t on task l
                        self.pool.setdefault((t, l), []).append(i)

    def sample_context(self, tool_idx: int, task_idx: int):
        """
        Samples a few-shot context set D_E^t for a specific tool and task.

        Returns
        -------
        (images, gt_labels, tool_predictions) or None

        images           : Tensor[B, C, H, W]
        gt_labels        : Tensor[B]
        tool_predictions : Tensor[B]

        This tuple corresponds exactly to:
            (x_b, y_b^t, m_E^t(x_b))_{b=1}^{B_t}

        If no valid context exists for (tool, task), returns None.
        This signals that the tool has no observable behavior for this task.
        """

        key = (tool_idx, task_idx)
        candidates = self.pool.get(key, [])

        # If the tool has never produced a valid prediction for this task,
        # we cannot construct a meaningful context descriptor.
        if len(candidates) == 0:
            return None

        # Randomly sample up to B_t context examples (few-shot, exchangeable)
        idxs = random.sample(
            candidates,
            k=min(self.examples_per_tool, len(candidates))
        )

        imgs, gt, preds = [], [], []
        for i in idxs:
            item = self.dataset[i]

            # Each context triple corresponds to:
            #   image x_b
            #   ground-truth label y_b^t
            #   tool prediction m_E^t(x_b)
            imgs.append(item["image"])
            gt.append(item["gt"][task_idx])
            preds.append(item["tool_preds"][tool_idx, task_idx])

        return (
            torch.stack(imgs),
            torch.stack(gt),
            torch.stack(preds),
        )

    def routing_dataset(self):
        """
        Returns the subset of the dataset used for training the router.

        This subset is guaranteed to be disjoint from the context set,
        ensuring no leakage between tool description and routing loss.
        """
        return Subset(self.dataset, self.routing_idx)


In [14]:
tr_demo_ctx_mgr = ContextManager(tr_demo_dataset)

In [15]:
tr_demo_ctx_mgr.routing_dataset()

<torch.utils.data.dataset.Subset at 0x7e8fa6cb94c0>

In [20]:
tr_demo_ctx_mgr.sample_context(0, 0)

(tensor([[[[0.6078, 0.6392, 0.6627,  ..., 0.3176, 0.3294, 0.3804],
           [0.6471, 0.6784, 0.7059,  ..., 0.3059, 0.3333, 0.3804],
           [0.6275, 0.6431, 0.6627,  ..., 0.3137, 0.3255, 0.3725],
           ...,
           [0.2745, 0.2667, 0.2745,  ..., 0.2667, 0.2824, 0.3020],
           [0.2745, 0.2627, 0.2745,  ..., 0.2706, 0.2863, 0.3059],
           [0.2745, 0.2627, 0.2745,  ..., 0.2824, 0.2980, 0.3137]],
 
          [[0.6078, 0.6392, 0.6627,  ..., 0.3176, 0.3294, 0.3804],
           [0.6471, 0.6784, 0.7059,  ..., 0.3059, 0.3333, 0.3804],
           [0.6275, 0.6431, 0.6627,  ..., 0.3137, 0.3255, 0.3725],
           ...,
           [0.2745, 0.2667, 0.2745,  ..., 0.2667, 0.2824, 0.3020],
           [0.2745, 0.2627, 0.2745,  ..., 0.2706, 0.2863, 0.3059],
           [0.2745, 0.2627, 0.2745,  ..., 0.2824, 0.2980, 0.3137]],
 
          [[0.6078, 0.6392, 0.6627,  ..., 0.3176, 0.3294, 0.3804],
           [0.6471, 0.6784, 0.7059,  ..., 0.3059, 0.3333, 0.3804],
           [0.6275, 0.64

In [85]:
# # batch is list of dicts
# batch = next(iter(train_loader))

# # sample a task
# task_idx = torch.randint(0, L, ()).item()

# # slice task-specific view
# preds = batch["tool_preds"][:, :, task_idx]    # [B, M]
# mask  = batch["tool_mask"][:, :, task_idx]     # [B, M]
# gt    = batch["gt"][:, task_idx]                # [B]

# # hard mask invalid tools BEFORE softmax
# router_logits[mask == 0] = -1e9


In [None]:
###Â examples of data loading
from torch.utils.data import DataLoader
import torchvision.transforms as T

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

PREDICTIONS_DIR = str(REPO_ROOT / 'data' / 'openi' / 'predictions')

registry_all = scan_prediction_files(PREDICTIONS_DIR)

# Example tool split (train on DenseNet + EVA-X, test on ResNets)
train_tools = [t for t in registry_all["train"] if "resnet" not in t]
test_tools  = [t for t in registry_all["train"] if "resnet" in t]

train_registry = {t: registry_all["train"][t] for t in train_tools}
val_registry   = {t: registry_all["val"][t]   for t in train_tools}
test_registry  = {t: registry_all["test"][t]  for t in test_tools}


In [None]:
label_names = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "Lung Lesion", "Fracture", "Lung Opacity", "Enlarged Cardiomediastinum"
]

train_dataset_full = OpenIRoutedDataset(
    label_csv=str(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Train.csv'),
    images_dir=str(REPO_ROOT / 'data' / 'openi' / 'image'),
    predictions_registry=train_registry,
    label_names=label_names,
    transform=transform,
)

val_dataset = OpenIRoutedDataset(
    label_csv=str(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Valid.csv'),
    images_dir=str(REPO_ROOT / 'data' / 'openi' / 'image'),
    predictions_registry=val_registry,
    label_names=label_names,
    transform=transform,
)

test_dataset = OpenIRoutedDataset(
    label_csv=str(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Test.csv'),
    images_dir=str(REPO_ROOT / 'data' / 'openi' / 'image'),
    predictions_registry=test_registry,
    label_names=label_names,
    transform=transform,
)


In [33]:
ctx_mgr = ContextManager(
    dataset=train_dataset_full,
    context_fraction=0.1,        # 10% reserved for context
    examples_per_tool=32,        # B_t
)
train_dataset = ctx_mgr.routing_dataset()


In [34]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)


In [36]:
import torch
import random

num_epochs = 1

num_tasks = len(label_names)
num_tools = train_dataset_full.M

for epoch in range(num_epochs):
    for batch in train_loader:

        # -------------------------------------------------
        # 1. Sample a task (label) for this routing step
        # -------------------------------------------------
        task_idx = random.randint(0, num_tasks - 1)

        images = batch["image"]                 # [B, C, H, W]
        gt     = batch["gt"][:, task_idx]       # [B]
        preds  = batch["tool_preds"][:, :, task_idx]  # [B, M]
        mask   = batch["tool_mask"][:, :, task_idx]   # [B, M]

        # -------------------------------------------------
        # 2. Sample context for each tool (ANP input)
        # -------------------------------------------------
        context_per_tool = []

        for tool_idx in range(num_tools):
            ctx = ctx_mgr.sample_context(tool_idx, task_idx)

            if ctx is None:
                context_per_tool.append(None)
            else:
                ctx_imgs, ctx_gt, ctx_preds = ctx
                context_per_tool.append({
                    "images": ctx_imgs,      # [B_t, C, H, W]
                    "gt": ctx_gt,            # [B_t]
                    "preds": ctx_preds,      # [B_t]
                })

        # -------------------------------------------------
        # 3. Forward pass (router + ANP)
        # -------------------------------------------------
        # router_logits = router(images, context_per_tool, task_idx)
        #
        # IMPORTANT:
        # Mask invalid tools BEFORE softmax
        #
        # router_logits[mask == 0] = -1e9

        # -------------------------------------------------
        # 4. Compute comp-sum loss (task-specific)
        # -------------------------------------------------
        # loss = comp_sum_loss(router_logits, preds, gt, mask)
        # loss.backward()
        # optimizer.step()


In [None]:
router.eval()
with torch.no_grad():
    for batch in val_loader:
        task_idx = random.randint(0, num_tasks - 1)

        images = batch["image"]
        gt     = batch["gt"][:, task_idx]
        preds  = batch["tool_preds"][:, :, task_idx]
        mask   = batch["tool_mask"][:, :, task_idx]

        # Sample context exactly as during training
        context_per_tool = [
            ctx_mgr.sample_context(t, task_idx)
            for t in range(num_tools)
        ]

        # router_logits = router(images, context_per_tool, task_idx)
        # router_logits[mask == 0] = -1e9
        # evaluate routing decision
