# Baseline Accuracy Table

MODEL × DATASET accuracy matrix

In [1]:
import pandas as pd
from data_reader.json_reader import get_baseline_stats
from constants import MODELS, DATASETS, MEDAGENTS_DATASETS, MED_QA_DATASET

In [None]:
# Collect one accuracy row and one sample-count row per model.
results = []
row_counts = []

for model in MODELS:
    accuracy_row = {"Model": model}
    count_row = {"Model": model}

    for dataset in DATASETS:
        # split=None is passed for every dataset:
        #   - med_qa: combines all splits automatically
        #   - others: defaults to "test"
        stats = get_baseline_stats(dataset_name=dataset, model=model, split=None)

        # A falsy result means no stats were available for this pair:
        # record a missing accuracy and a zero row count.
        accuracy_row[dataset] = stats["accuracy"] if stats else None
        count_row[dataset] = stats["n_total"] if stats else 0

    results.append(accuracy_row)
    row_counts.append(count_row)
    print(f"Done: {model}")

# One DataFrame row per model; columns are "Model" plus one per dataset.
baseline_df = pd.DataFrame(results)
counts_df = pd.DataFrame(row_counts)

In [None]:
from med_edge_analysis.constants import DATASETS

# Validate: every model must report the same row count for each dataset
banner = "=" * 60
print(banner)
print("ROW COUNT VALIDATION")
print(banner)

errors = []
for dataset in DATASETS:
    dataset_counts = counts_df[dataset]
    unique_counts = dataset_counts.unique()

    # A single distinct value means all models agree on this dataset.
    if len(unique_counts) <= 1:
        print(f"  {dataset}: {unique_counts[0]} rows")
        continue

    # Otherwise record the per-model counts so the mismatch can be traced.
    per_model = dict(zip(counts_df["Model"], dataset_counts))
    errors.append(f"{dataset}: MISMATCH {per_model}")

if not errors:
    print("\nAll datasets have consistent row counts across models.")
else:
    print("\nERRORS FOUND:")
    for message in errors:
        print(f"  {message}")
    # Stop the notebook here: the accuracy table is not comparable across models.
    raise ValueError("Row count mismatch between models! Fix data before proceeding.")

In [None]:
# Build the display table on a copy, indexed by model name
display_df = baseline_df.copy().set_index("Model")

# Show values as one-decimal percentages ("-" for missing entries) and
# colour the cells with a single gradient computed over the whole table
# (axis=None) rather than per row or per column.
styled = (
    display_df.style
    .format("{:.1%}", na_rep="-")
    .background_gradient(cmap="RdYlGn", axis=None)
)
styled