In [4]:
from pathlib import Path

import numpy as np
import pandas as pd

from dapinet.analysis import (
    load_datasets,
    run_model_inference,
)

## Configuration

In [2]:
# Variants to evaluate: model directory name -> human-readable label used in
# printed reports. Insertion order is the evaluation order.
VARIANT_LABELS = {
    "DAPINet": "Base (Full)",  # Base model
    "ablation_dace_simple": "DACE→Simple Stats",
    "ablation_pool_mean": "PMA→Mean Pool",
    "ablation_no_col_attn": "No Col. Attn.",
}

# Derived from the label map so the two can never drift apart.
VARIANTS = list(VARIANT_LABELS)

# Input locations: one sub-directory of MODELS_DIR per variant, plus the
# real-world benchmark datasets.
MODELS_DIR = Path("models")
REAL_WORLD_DIR = Path("datasets/real_world")

# Output location for everything this notebook writes; created up front so
# later cells can write into it unconditionally.
RESULTS_DIR = Path("results/ablation")
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

In [3]:
# Load real-world datasets
# Echo the source directory first, then read everything under REAL_WORLD_DIR
# with dapinet's loader. The result is handed unchanged to
# run_model_inference once per variant in the inference cell.
print(f"Loading real-world datasets from: {REAL_WORLD_DIR}")
real_datasets = load_datasets(REAL_WORLD_DIR)

2026-01-28 09:44:40,333 INFO: Loading datasets from datasets\real_world


Loading real-world datasets from: datasets\real_world


Loading Datasets: 100%|██████████| 20/20 [00:00<00:00, 593.82it/s]
2026-01-28 09:44:40,375 INFO: Loaded 20 datasets.


In [5]:
# Score the real-world datasets with every variant whose model directory
# exists, keeping the prediction frame per variant for the analysis below.
real_results = {}

for variant in VARIANTS:
    model_path = MODELS_DIR / variant

    if model_path.exists():
        print(f"\n{'=' * 60}")
        print(f"Evaluating on Real-World: {VARIANT_LABELS[variant]}")
        print(f"{'=' * 60}")

        # `stats` comes back alongside the predictions; this cell does not
        # use it.
        df_pred, stats = run_model_inference(model_path, real_datasets)

        # Persist the raw predictions, one CSV per variant.
        output_file = RESULTS_DIR / f"real_world_{variant}.csv"
        df_pred.to_csv(output_file, index=False)
        print(f"  Saved to: {output_file}")

        real_results[variant] = df_pred
    else:
        # No model directory for this variant: report it and move on.
        print(f"⚠️  Skipping {variant} (model not found)\n")

2026-01-28 09:44:43,978 INFO: Found 5 models in models\DAPINet



Evaluating on Real-World: Base (Full)


2026-01-28 09:44:44,204 INFO: Updating Config from checkpoint_fold_1.pth...
2026-01-28 09:44:44,231 INFO: Loaded model from models\DAPINet\checkpoint_fold_1.pth (Epoch 4, Loss 0.0456)
2026-01-28 09:44:44,336 INFO: Updating Config from checkpoint_fold_2.pth...
2026-01-28 09:44:44,357 INFO: Loaded model from models\DAPINet\checkpoint_fold_2.pth (Epoch 8, Loss 0.0461)
2026-01-28 09:44:44,442 INFO: Updating Config from checkpoint_fold_3.pth...
2026-01-28 09:44:44,460 INFO: Loaded model from models\DAPINet\checkpoint_fold_3.pth (Epoch 9, Loss 0.0513)
2026-01-28 09:44:44,553 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:44:44,572 INFO: Loaded model from models\DAPINet\checkpoint_fold_4.pth (Epoch 7, Loss 0.0482)
2026-01-28 09:44:44,630 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:44:44,644 INFO: Loaded model from models\DAPINet\checkpoint_fold_5.pth (Epoch 8, Loss 0.0447)
Model Inference: 100%|██████████| 20/20 [00:02<00:00,  9.63it/s]
2026-01-28 09:4

  Saved to: results\ablation\real_world_DAPINet.csv

Evaluating on Real-World: DACE→Simple Stats


2026-01-28 09:44:46,952 INFO: Updating Config from checkpoint_fold_3.pth...
2026-01-28 09:44:46,969 INFO: Loaded model from models\ablation_dace_simple\checkpoint_fold_3.pth (Epoch 8, Loss 0.0572)
2026-01-28 09:44:47,035 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:44:47,052 INFO: Loaded model from models\ablation_dace_simple\checkpoint_fold_4.pth (Epoch 9, Loss 0.0584)
2026-01-28 09:44:47,108 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:44:47,125 INFO: Loaded model from models\ablation_dace_simple\checkpoint_fold_5.pth (Epoch 9, Loss 0.0572)
Model Inference: 100%|██████████| 20/20 [00:01<00:00, 16.62it/s]
2026-01-28 09:44:48,338 INFO: Found 5 models in models\ablation_pool_mean
2026-01-28 09:44:48,392 INFO: Updating Config from checkpoint_fold_1.pth...
2026-01-28 09:44:48,411 INFO: Loaded model from models\ablation_pool_mean\checkpoint_fold_1.pth (Epoch 9, Loss 0.0426)
2026-01-28 09:44:48,477 INFO: Updating Config from checkpoint_fold_2.pth...

  Saved to: results\ablation\real_world_ablation_dace_simple.csv

Evaluating on Real-World: PMA→Mean Pool


2026-01-28 09:44:48,559 INFO: Updating Config from checkpoint_fold_3.pth...
2026-01-28 09:44:48,585 INFO: Loaded model from models\ablation_pool_mean\checkpoint_fold_3.pth (Epoch 8, Loss 0.0463)
2026-01-28 09:44:48,641 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:44:48,663 INFO: Loaded model from models\ablation_pool_mean\checkpoint_fold_4.pth (Epoch 7, Loss 0.0506)
2026-01-28 09:44:48,716 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:44:48,733 INFO: Loaded model from models\ablation_pool_mean\checkpoint_fold_5.pth (Epoch 9, Loss 0.0448)
Model Inference: 100%|██████████| 20/20 [00:01<00:00, 19.17it/s]
2026-01-28 09:44:49,787 INFO: Found 5 models in models\ablation_no_col_attn
2026-01-28 09:44:49,835 INFO: Updating Config from checkpoint_fold_1.pth...
2026-01-28 09:44:49,849 INFO: Loaded model from models\ablation_no_col_attn\checkpoint_fold_1.pth (Epoch 9, Loss 0.0603)
2026-01-28 09:44:49,903 INFO: Updating Config from checkpoint_fold_2.pth...
2

  Saved to: results\ablation\real_world_ablation_pool_mean.csv

Evaluating on Real-World: No Col. Attn.


2026-01-28 09:44:50,004 INFO: Loaded model from models\ablation_no_col_attn\checkpoint_fold_3.pth (Epoch 7, Loss 0.0579)
2026-01-28 09:44:50,084 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:44:50,108 INFO: Loaded model from models\ablation_no_col_attn\checkpoint_fold_4.pth (Epoch 6, Loss 0.0603)
2026-01-28 09:44:50,165 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:44:50,178 INFO: Loaded model from models\ablation_no_col_attn\checkpoint_fold_5.pth (Epoch 6, Loss 0.0574)
Model Inference: 100%|██████████| 20/20 [00:00<00:00, 26.05it/s]

  Saved to: results\ablation\real_world_ablation_no_col_attn.csv





In [8]:
# Real-world ARI analysis (per-variant)
#
# For each variant: on every dataset take the algorithm with the highest
# predicted score, look up the ARI recorded for that algorithm in the oracle
# table, and compare it with the best ARI any scored algorithm reached on the
# same dataset (the difference is the "regret"). Writes one per-dataset CSV per
# variant plus a summary CSV into RESULTS_DIR, and prints the summary.
oracle_path = "results/oracle_ari.csv"
oracle_df = pd.read_csv(oracle_path).set_index("dataset")

# Largest absolute regret (best ARI minus chosen ARI) still counted as a hit.
# NOTE(review): the summary column is labelled "Within 10% of Max %", but this
# is an absolute gap of 0.1 ARI, not 10% of the best ARI. Confirm which is
# intended before changing either the label or the rule.
REGRET_TOLERANCE = 0.1

realworld_variant_stats = []

for variant, df_pred in real_results.items():
    if df_pred is None or df_pred.empty:
        print(f"⚠️  Skipping {VARIANT_LABELS[variant]} (no predictions)")
        continue

    # Every column except the bookkeeping ones is treated as an algorithm score.
    algo_cols = [
        c
        for c in df_pred.columns
        if c not in {"dataset", "inference_time_ms"}
    ]

    # Only algorithms present in both tables can be evaluated. This does not
    # depend on the row, so compute it once per variant.
    valid_cols = [c for c in algo_cols if c in oracle_df.columns]
    if not valid_cols:
        print(f"⚠️  Skipping {VARIANT_LABELS[variant]} (no matching oracle ARIs)")
        continue

    unscored_cols = [c for c in algo_cols if c not in oracle_df.columns]
    if unscored_cols:
        # Make the restriction visible instead of silently dropping columns.
        print(
            f"⚠️  {VARIANT_LABELS[variant]}: ignoring prediction columns "
            f"missing from the oracle table: {unscored_cols}"
        )

    per_dataset = []

    for _, row in df_pred.iterrows():
        ds_name = row["dataset"]
        if ds_name not in oracle_df.index:
            continue

        # Choose the top-scored algorithm among those the oracle can evaluate.
        # An argmax over all prediction columns could select a column absent
        # from `true_ari_row`, making the lookup below raise KeyError.
        probs = row[valid_cols].to_numpy(dtype=float)
        pred_idx = int(np.argmax(probs))
        pred_algo = valid_cols[pred_idx]

        true_ari_row = oracle_df.loc[ds_name, valid_cols]
        pred_ari = float(true_ari_row[pred_algo])
        max_true_ari = float(true_ari_row.max())
        regret = max_true_ari - pred_ari

        per_dataset.append(
            {
                "dataset": ds_name,
                "pred_algo": pred_algo,
                "pred_ari": pred_ari,
                "max_true_ari": max_true_ari,
                "regret": regret,
            }
        )

    if not per_dataset:
        # None of this variant's datasets appear in the oracle index.
        print(f"⚠️  Skipping {VARIANT_LABELS[variant]} (no matching oracle ARIs)")
        continue

    ds_df = pd.DataFrame(per_dataset)

    realworld_variant_stats.append(
        {
            "Variant": VARIANT_LABELS[variant],
            "Mean Pred ARI": ds_df["pred_ari"].mean(),
            "Median Pred ARI": ds_df["pred_ari"].median(),
            "Mean Regret": ds_df["regret"].mean(),
            "Median Regret": ds_df["regret"].median(),
            "Within 10% of Max %": (ds_df["regret"] <= REGRET_TOLERANCE).mean() * 100,
        }
    )

    # Save per-dataset stats for inspection
    ds_df.to_csv(RESULTS_DIR / f"real_world_ari_{variant}.csv", index=False)

realworld_summary_df = pd.DataFrame(realworld_variant_stats)
realworld_summary_df.to_csv(RESULTS_DIR / "real_world_ari_summary.csv", index=False)

print("\n" + "=" * 60)
print("REAL-WORLD ARI SUMMARY (by predicted top algo)")
print("=" * 60)
print(realworld_summary_df.to_string(index=False))
print("=" * 60)



REAL-WORLD ARI SUMMARY (by predicted top algo)
          Variant  Mean Pred ARI  Median Pred ARI  Mean Regret  Median Regret  Within 10% of Max %
      Base (Full)       0.352790         0.245972     0.048809       0.005170                 90.0
DACE→Simple Stats       0.280008         0.187040     0.121591       0.057294                 65.0
    PMA→Mean Pool       0.302165         0.189768     0.099433       0.017015                 75.0
    No Col. Attn.       0.316549         0.245972     0.085049       0.017015                 70.0
