In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
import imports

Using device: cuda


In [12]:
# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------
DATA_ROOT = "/home/kell6630/repos/DySTANce//data/openi"
LABELS_DIR = f"{DATA_ROOT}/labels"
IMAGES_DIR = f"{DATA_ROOT}/image"
PRED_DIR   = f"{DATA_ROOT}/predictions"

# ------------------------------------------------------------
# Labels (tasks)
# ------------------------------------------------------------
label_names = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "Lung Lesion", "Fracture", "Lung Opacity", "Enlarged Cardiomediastinum"
]
num_tasks = len(label_names)

# ------------------------------------------------------------
# Tool registry
# ------------------------------------------------------------
registry_all = imports.scan_prediction_files(PRED_DIR)

# Example split: train on non-resnet tools
train_tools = [t for t in registry_all["train"] if "resnet" not in t]

train_registry = {t: registry_all["train"][t] for t in train_tools}
val_registry   = {t: registry_all["val"][t]   for t in train_tools}

# ------------------------------------------------------------
# Datasets
# ------------------------------------------------------------
train_dataset_full = imports.OpenIRoutedDataset(
    label_csv=f"{LABELS_DIR}/Train.csv",
    images_dir=IMAGES_DIR,
    predictions_registry=train_registry,
    label_names=label_names,
    transform=None,  # assume tensor conversion inside dataset
)

val_dataset = imports.OpenIRoutedDataset(
    label_csv=f"{LABELS_DIR}/Valid.csv",
    images_dir=IMAGES_DIR,
    predictions_registry=val_registry,
    label_names=label_names,
    transform=None,
)



In [13]:
ctx_mgr = imports.ContextManager(
    dataset=train_dataset_full,
    context_fraction=0.1,      # 10% context
    examples_per_tool=32,      # B_t
)

train_dataset = ctx_mgr.routing_dataset()


In [14]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)


In [24]:
model = imports.DySTANceRouter(
    num_tasks=num_tasks,
    vocab_size=1000,   # dummy vocab size for now
    hidden_dim=256,
).to(device)

criterion = imports.DySTANceLoss(
    surrogate_type="logistic",
    lambda_entropy=0.05,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)


def build_context_tensors(ctx_mgr, task_idx, device):
    """
    Builds task-conditional context tensors for all tools.

    Returns:
        ctx_img_feat : [M, C, Dx] on `device`
        ctx_gt       : [M, C]     on `device`
        ctx_pred     : [M, C]     on `device`
    """
    ctx_img_feats = []
    ctx_gts = []
    ctx_preds = []

    M = ctx_mgr.dataset.M
    C = ctx_mgr.examples_per_tool

    for tool_idx in range(M):
        ctx = ctx_mgr.sample_context(tool_idx, task_idx)

        if ctx is None:
            # No valid context for this tool-task pair
            ctx_img_feats.append(
                torch.zeros(C, model.img_dim, device=device)
            )
            ctx_gts.append(
                torch.zeros(C, device=device)
            )
            ctx_preds.append(
                torch.zeros(C, device=device)
            )
        else:
            imgs, gt, preds = ctx

            imgs = imgs.to(device)
            gt = gt.to(device)
            preds = preds.to(device)

            with torch.no_grad():
                feats = model.extract_img_feat(imgs)  # [C, Dx]

            ctx_img_feats.append(feats)
            ctx_gts.append(gt)
            ctx_preds.append(preds)

    return (
        torch.stack(ctx_img_feats, dim=0),  # [M, C, Dx]
        torch.stack(ctx_gts, dim=0),        # [M, C]
        torch.stack(ctx_preds, dim=0),      # [M, C]
    )




In [25]:
def train_one_epoch(model, loader, ctx_mgr, optimizer, criterion):
    model.train()

    total_loss = 0.0
    total_batches = 0

    for batch in tqdm(loader, desc="Training"):
        images = batch["image"].to(device)        # [B, 3, H, W]
        gt_all = batch["gt"].to(device)           # [B, L]
        preds_all = batch["tool_preds"].to(device)  # [B, M, L]
        mask_all = batch["tool_mask"].to(device)    # [B, M, L]

        B = images.size(0)

        # ------------------------------------------------------------
        # 1) Sample a task uniformly
        # ------------------------------------------------------------
        task_idx = random.randint(0, num_tasks - 1)
        task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)

        # Task-conditional slices
        gt = gt_all[:, task_idx]                 # [B]
        tool_preds = preds_all[:, :, task_idx]   # [B, M]
        tool_mask  = mask_all[:, :, task_idx]    # [B, M]

        # ------------------------------------------------------------
        # 2) Build context for this task
        # ------------------------------------------------------------
        ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(
            ctx_mgr, task_idx, device
        )

        # ------------------------------------------------------------
        # 3) Forward pass
        # ------------------------------------------------------------
        scores = model(
            images=images,
            text_tokens=torch.zeros((B, 1), dtype=torch.long, device=device),  # dummy text
            task_idx=task_ids,
            tool_preds=tool_preds,
            ctx_img_feat=ctx_img_feat,
            ctx_gt=ctx_gt,
            ctx_pred=ctx_pred,
            tool_mask=tool_mask,
        )

        # ------------------------------------------------------------
        # 4) Compute costs (classification task)
        # c_E = 1 - confidence on true label
        # ------------------------------------------------------------
        tool_costs = 1.0 - tool_preds  # [B, M]

        # ------------------------------------------------------------
        # 5) Loss + backward
        # ------------------------------------------------------------
        loss, logs = criterion(scores, tool_costs, tool_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1

    return total_loss / max(1, total_batches)


In [26]:
@torch.no_grad()
def evaluate(model, loader, ctx_mgr):
    model.eval()

    total_regret = 0.0
    total_samples = 0

    for batch in tqdm(loader, desc="Validation"):
        images = batch["image"].to(device)
        gt_all = batch["gt"].to(device)
        preds_all = batch["tool_preds"].to(device)
        mask_all = batch["tool_mask"].to(device)

        B = images.size(0)

        task_idx = random.randint(0, num_tasks - 1)
        task_ids = torch.full((B,), task_idx, device=device, dtype=torch.long)

        tool_preds = preds_all[:, :, task_idx]
        tool_mask  = mask_all[:, :, task_idx]

        ctx_img_feat, ctx_gt, ctx_pred = build_context_tensors(
            ctx_mgr, task_idx, device
        )

        scores = model(
            images,
            torch.zeros((B, 1), dtype=torch.long, device=device),
            task_ids,
            tool_preds,
            ctx_img_feat,
            ctx_gt,
            ctx_pred,
            tool_mask,
        )

        # Routing decision
        chosen = scores.argmax(dim=1)  # [B]

        # Costs
        costs = 1.0 - tool_preds
        chosen_cost = costs[torch.arange(B), chosen]

        oracle_cost = costs.masked_fill(tool_mask == 0, 1e9).min(dim=1).values

        regret = (chosen_cost - oracle_cost).sum().item()
        total_regret += regret
        total_samples += B

    return total_regret / max(1, total_samples)


In [28]:
num_epochs = 100

for epoch in range(num_epochs):
    train_loss = train_one_epoch(
        model, train_loader, ctx_mgr, optimizer, criterion
    )

    val_regret = evaluate(model, val_loader, ctx_mgr)

    print(
        f"[Epoch {epoch:02d}] "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Regret: {val_regret:.4f}"
    )


Training: 100%|██████████| 53/53 [00:12<00:00,  4.30it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.45it/s]


[Epoch 00] Train Loss: -913.2017 | Val Regret: 0.2824


Training: 100%|██████████| 53/53 [00:12<00:00,  4.22it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.42it/s]


[Epoch 01] Train Loss: -879.4879 | Val Regret: 0.2499


Training: 100%|██████████| 53/53 [00:11<00:00,  4.52it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.43it/s]


[Epoch 02] Train Loss: -838.0365 | Val Regret: 0.2382


Training: 100%|██████████| 53/53 [00:13<00:00,  4.04it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.06it/s]


[Epoch 03] Train Loss: -980.4909 | Val Regret: 0.2547


Training: 100%|██████████| 53/53 [00:12<00:00,  4.24it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.49it/s]


[Epoch 04] Train Loss: -917.4980 | Val Regret: 0.1943


Training: 100%|██████████| 53/53 [00:11<00:00,  4.57it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  3.87it/s]


[Epoch 05] Train Loss: -853.4120 | Val Regret: 0.2440


Training: 100%|██████████| 53/53 [00:11<00:00,  4.82it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.82it/s]


[Epoch 06] Train Loss: -846.7866 | Val Regret: 0.2779


Training: 100%|██████████| 53/53 [00:11<00:00,  4.46it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  5.03it/s]


[Epoch 07] Train Loss: -873.3522 | Val Regret: 0.2120


Training: 100%|██████████| 53/53 [00:11<00:00,  4.43it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.34it/s]


[Epoch 08] Train Loss: -865.8953 | Val Regret: 0.3530


Training: 100%|██████████| 53/53 [00:12<00:00,  4.09it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.24it/s]


[Epoch 09] Train Loss: -971.5137 | Val Regret: 0.3454


Training: 100%|██████████| 53/53 [00:12<00:00,  4.16it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.89it/s]


[Epoch 10] Train Loss: -896.7213 | Val Regret: 0.3619


Training: 100%|██████████| 53/53 [00:12<00:00,  4.26it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.09it/s]


[Epoch 11] Train Loss: -851.3491 | Val Regret: 0.3524


Training: 100%|██████████| 53/53 [00:12<00:00,  4.41it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  3.96it/s]


[Epoch 12] Train Loss: -866.1831 | Val Regret: 0.2070


Training: 100%|██████████| 53/53 [00:12<00:00,  4.37it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  5.17it/s]


[Epoch 13] Train Loss: -892.8921 | Val Regret: 0.2081


Training: 100%|██████████| 53/53 [00:11<00:00,  4.42it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.06it/s]


[Epoch 14] Train Loss: -921.3738 | Val Regret: 0.2205


Training: 100%|██████████| 53/53 [00:12<00:00,  4.15it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  3.67it/s]


[Epoch 15] Train Loss: -932.6384 | Val Regret: 0.1525


Training: 100%|██████████| 53/53 [00:12<00:00,  4.30it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.15it/s]


[Epoch 16] Train Loss: -934.9215 | Val Regret: 0.1769


Training: 100%|██████████| 53/53 [00:12<00:00,  4.18it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.18it/s]


[Epoch 17] Train Loss: -923.2560 | Val Regret: 0.2111


Training: 100%|██████████| 53/53 [00:13<00:00,  3.96it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.87it/s]


[Epoch 18] Train Loss: -979.3263 | Val Regret: 0.1537


Training: 100%|██████████| 53/53 [00:12<00:00,  4.29it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.37it/s]


[Epoch 19] Train Loss: -910.6538 | Val Regret: 0.2841


Training: 100%|██████████| 53/53 [00:11<00:00,  4.45it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.62it/s]


[Epoch 20] Train Loss: -884.3328 | Val Regret: 0.1521


Training: 100%|██████████| 53/53 [00:12<00:00,  4.37it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  5.13it/s]


[Epoch 21] Train Loss: -891.2880 | Val Regret: 0.2262


Training: 100%|██████████| 53/53 [00:12<00:00,  4.29it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.06it/s]


[Epoch 22] Train Loss: -942.6681 | Val Regret: 0.2312


Training: 100%|██████████| 53/53 [00:12<00:00,  4.29it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.78it/s]


[Epoch 23] Train Loss: -928.8299 | Val Regret: 0.2676


Training: 100%|██████████| 53/53 [00:12<00:00,  4.21it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.74it/s]


[Epoch 24] Train Loss: -938.7985 | Val Regret: 0.2772


Training: 100%|██████████| 53/53 [00:11<00:00,  4.70it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.47it/s]


[Epoch 25] Train Loss: -816.2926 | Val Regret: 0.3574


Training: 100%|██████████| 53/53 [00:11<00:00,  4.45it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.71it/s]


[Epoch 26] Train Loss: -827.5361 | Val Regret: 0.3090


Training: 100%|██████████| 53/53 [00:12<00:00,  4.15it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.32it/s]


[Epoch 27] Train Loss: -919.7971 | Val Regret: 0.2692


Training: 100%|██████████| 53/53 [00:11<00:00,  4.44it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.08it/s]


[Epoch 28] Train Loss: -848.9411 | Val Regret: 0.2265


Training: 100%|██████████| 53/53 [00:12<00:00,  4.25it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  3.99it/s]


[Epoch 29] Train Loss: -932.5237 | Val Regret: 0.3048


Training: 100%|██████████| 53/53 [00:11<00:00,  4.63it/s]
Validation: 100%|██████████| 17/17 [00:04<00:00,  4.01it/s]


[Epoch 30] Train Loss: -823.0023 | Val Regret: 0.2792


Training: 100%|██████████| 53/53 [00:11<00:00,  4.69it/s]
Validation: 100%|██████████| 17/17 [00:03<00:00,  4.42it/s]


[Epoch 31] Train Loss: -840.5516 | Val Regret: 0.2751


Training: 100%|██████████| 53/53 [00:11<00:00,  4.44it/s]
Validation:  18%|█▊        | 3/17 [00:01<00:04,  2.94it/s]


KeyboardInterrupt: 