In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

from dapinet.analysis import (
    load_datasets,
    run_model_inference,
)

## Configuration

In [2]:
# Model variants to evaluate, keyed by their sub-directory name under 'models/'
# (these directories should exist there) and mapped to a display label.
VARIANT_LABELS = {
    "DAPINet": "Base (Full)",  # Base model
    "ablation_convex": "Trained on Convex",
    "ablation_manifold": "Trained on Manifold",
}

# Evaluation order is the insertion order of the label table above, so the
# variant names are written down exactly once.
VARIANTS = list(VARIANT_LABELS)

# Input locations
MODELS_DIR = Path("models")
REAL_WORLD_DIR = Path("datasets") / "real_world"
ORACLE_PATH = Path("results") / "oracle_ari.csv"

# Output location; created up front (with any missing parents) so later cells
# can write into it without checking.
RESULTS_DIR = Path("results") / "dataset_ablation"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

## Load Data & Run Inference

In [4]:
# Load the real-world datasets from REAL_WORLD_DIR. The directory is printed
# first so the cell output records which location was read. The result is
# passed unchanged to run_model_inference for every variant further down.
print(f"Loading real-world datasets from: {REAL_WORLD_DIR}")
real_datasets = load_datasets(REAL_WORLD_DIR)

2026-01-28 09:51:33,694 INFO: Loading datasets from datasets\real_world


Loading real-world datasets from: datasets\real_world


Loading Datasets: 100%|██████████| 20/20 [00:00<00:00, 513.94it/s]
2026-01-28 09:51:33,744 INFO: Loaded 20 datasets.


In [5]:
# Evaluate every configured variant on the real-world datasets.
# `results` maps variant directory name -> prediction DataFrame. A variant whose
# model directory is missing is reported and left out of the mapping.
results = {}

for variant in VARIANTS:
    model_path = MODELS_DIR.joinpath(variant)

    if model_path.exists():
        print(f"\n{'=' * 60}")
        print(f"Evaluating: {VARIANT_LABELS[variant]}")
        print(f"{'=' * 60}")

        # `stats` is returned alongside the predictions; it is not used in this cell.
        df_pred, stats = run_model_inference(model_path, real_datasets)

        # Persist the per-dataset predictions, one CSV per variant.
        output_file = RESULTS_DIR.joinpath(f"real_world_{variant}.csv")
        df_pred.to_csv(output_file, index=False)
        print(f"Saved results to: {output_file}")

        results[variant] = df_pred
    else:
        print(f"⚠️  Model directory not found: {model_path}")
        print(f"    Skipping {variant}...\n")

2026-01-28 09:51:38,915 INFO: Found 5 models in models\DAPINet



Evaluating: Base (Full)


2026-01-28 09:51:39,158 INFO: Updating Config from checkpoint_fold_1.pth...
2026-01-28 09:51:39,187 INFO: Loaded model from models\DAPINet\checkpoint_fold_1.pth (Epoch 4, Loss 0.0456)
2026-01-28 09:51:39,302 INFO: Updating Config from checkpoint_fold_2.pth...
2026-01-28 09:51:39,322 INFO: Loaded model from models\DAPINet\checkpoint_fold_2.pth (Epoch 8, Loss 0.0461)
2026-01-28 09:51:39,386 INFO: Updating Config from checkpoint_fold_3.pth...
2026-01-28 09:51:39,408 INFO: Loaded model from models\DAPINet\checkpoint_fold_3.pth (Epoch 9, Loss 0.0513)
2026-01-28 09:51:39,468 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:51:39,486 INFO: Loaded model from models\DAPINet\checkpoint_fold_4.pth (Epoch 7, Loss 0.0482)
2026-01-28 09:51:39,551 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:51:39,570 INFO: Loaded model from models\DAPINet\checkpoint_fold_5.pth (Epoch 8, Loss 0.0447)
Model Inference: 100%|██████████| 20/20 [00:02<00:00,  8.74it/s]
2026-01-28 09:5

Saved results to: results\dataset_ablation\real_world_DAPINet.csv

Evaluating: Trained on Convex


2026-01-28 09:51:42,074 INFO: Updating Config from checkpoint_fold_2.pth...
2026-01-28 09:51:42,097 INFO: Loaded model from models\ablation_convex\checkpoint_fold_2.pth (Epoch 6, Loss 0.0526)
2026-01-28 09:51:42,198 INFO: Updating Config from checkpoint_fold_3.pth...
2026-01-28 09:51:42,214 INFO: Loaded model from models\ablation_convex\checkpoint_fold_3.pth (Epoch 5, Loss 0.0585)
2026-01-28 09:51:42,301 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:51:42,318 INFO: Loaded model from models\ablation_convex\checkpoint_fold_4.pth (Epoch 9, Loss 0.0505)
2026-01-28 09:51:42,392 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:51:42,409 INFO: Loaded model from models\ablation_convex\checkpoint_fold_5.pth (Epoch 9, Loss 0.0466)
Model Inference: 100%|██████████| 20/20 [00:01<00:00, 17.11it/s]
2026-01-28 09:51:43,589 INFO: Found 5 models in models\ablation_manifold
2026-01-28 09:51:43,670 INFO: Updating Config from checkpoint_fold_1.pth...
2026-01-28 09:51:4

Saved results to: results\dataset_ablation\real_world_ablation_convex.csv

Evaluating: Trained on Manifold


2026-01-28 09:51:43,870 INFO: Updating Config from checkpoint_fold_3.pth...
2026-01-28 09:51:43,888 INFO: Loaded model from models\ablation_manifold\checkpoint_fold_3.pth (Epoch 7, Loss 0.0273)
2026-01-28 09:51:43,978 INFO: Updating Config from checkpoint_fold_4.pth...
2026-01-28 09:51:43,996 INFO: Loaded model from models\ablation_manifold\checkpoint_fold_4.pth (Epoch 7, Loss 0.0269)
2026-01-28 09:51:44,096 INFO: Updating Config from checkpoint_fold_5.pth...
2026-01-28 09:51:44,122 INFO: Loaded model from models\ablation_manifold\checkpoint_fold_5.pth (Epoch 6, Loss 0.0278)
Model Inference: 100%|██████████| 20/20 [00:01<00:00, 18.36it/s]

Saved results to: results\dataset_ablation\real_world_ablation_manifold.csv





## Comparative Analysis

In [7]:
# Regret analysis: for each variant, take the algorithm with the highest score
# per dataset, look up that algorithm's ARI in the oracle table, and compare it
# with the best ARI any algorithm achieved on the same dataset.
if not ORACLE_PATH.exists():
    print(f"⚠️ Oracle file not found at {ORACLE_PATH}. Cannot compute regret.")
    oracle_df = None
else:
    # Indexed by dataset name; columns are looked up below using the same
    # algorithm column names that appear in the prediction frames.
    oracle_df = pd.read_csv(ORACLE_PATH).set_index("dataset")

# Regret at or below this value counts as "near optimal". The comparison is
# inclusive and in absolute ARI units, although the summary column label reads
# "<10% Regret"; the label is kept unchanged for compatibility with saved output.
NEAR_OPTIMAL_REGRET = 0.10

summary_stats = []  # one summary row per variant
all_regrets = []  # one {"Variant", "Regret"} record per (variant, dataset)

if oracle_df is not None:
    for variant, df_pred in results.items():
        variant_label = VARIANT_LABELS[variant]

        # Every column except the bookkeeping ones is treated as an algorithm score.
        algo_cols = [
            c
            for c in df_pred.columns
            if c not in {"dataset", "inference_time_ms"}
        ]
        per_dataset = []

        for _, row in df_pred.iterrows():
            ds_name = row["dataset"]
            if ds_name not in oracle_df.index:
                # No ground truth for this dataset; it cannot contribute a regret.
                continue

            # Predicted algorithm = column with the highest score in this row.
            probs = row[algo_cols].to_numpy(dtype=float)
            pred_idx = int(np.argmax(probs))
            pred_algo = algo_cols[pred_idx]

            # True ARI of the prediction versus the best ARI available.
            true_ari_row = oracle_df.loc[ds_name, algo_cols]
            pred_ari = float(true_ari_row[pred_algo])
            max_true_ari = float(true_ari_row.max())
            regret = max_true_ari - pred_ari

            per_dataset.append(
                {
                    "dataset": ds_name,
                    "pred_ari": pred_ari,
                    "max_true_ari": max_true_ari,
                    "regret": regret,
                }
            )

            all_regrets.append({"Variant": variant_label, "Regret": regret})

        if not per_dataset:
            # An empty record list yields a DataFrame with no columns, so the
            # column lookups below would raise KeyError. Report and move on.
            print(
                f"⚠️ No datasets for {variant_label} matched the oracle index. "
                "Skipping summary."
            )
            continue

        ds_df = pd.DataFrame(per_dataset)

        summary_stats.append(
            {
                "Variant": variant_label,
                "Mean Pred ARI": ds_df["pred_ari"].mean(),
                "Median Pred ARI": ds_df["pred_ari"].median(),
                "Mean Regret": ds_df["regret"].mean(),
                "Median Regret": ds_df["regret"].median(),
                "Near Optimal (<10% Regret)": (
                    ds_df["regret"] <= NEAR_OPTIMAL_REGRET
                ).mean()
                * 100,
            }
        )

    summary_df = pd.DataFrame(summary_stats)
    print("\nSummary Statistics on Real-World Datasets:")
    print(summary_df.to_string(index=False))


Summary Statistics on Real-World Datasets:
            Variant  Mean Pred ARI  Median Pred ARI  Mean Regret  Median Regret  Near Optimal (<10% Regret)
        Base (Full)       0.352790         0.245972     0.048809       0.005170                        90.0
  Trained on Convex       0.288083         0.201969     0.113516       0.057164                        70.0
Trained on Manifold       0.228758         0.165580     0.172840       0.144153                        50.0
