# 10: Consolidated Evaluation

This notebook compares all synthesis methods on the 21-query SQL benchmark:
1. Wide-table DP-SGD (VAE, epsilon=4.0)
2. Per-table DP histogram synthesis (epsilon=4.0 per table)
3. MST marginal-based baseline (epsilon=4.0 per table)
4. Private Evolution (pending completion; included only when `../data/results/synth_pe` contains CSV results)

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import Markdown, display

sys.path.insert(0, str(Path.cwd().parent))
from src.eval.compare import QUERY_METADATA, evaluate_all, results_to_dataframe, detailed_results_to_dataframe

In [None]:
# All query results live under one root: the real outputs plus one
# sub-directory per synthesis method.
_results_root = Path("../data/results")

REAL_DIR = _results_root / "real"

# Display label -> directory holding that method's synthetic query results.
SYNTH_DIRS = dict(
    [
        ("Wide-table DP-SGD", _results_root / "synthetic"),
        ("Per-table DP-SGD", _results_root / "synth_pertable"),
        ("MST baseline", _results_root / "synth_mst"),
    ]
)

# Private Evolution is optional: register it only when its directory exists
# and already contains at least one CSV file.
pe_dir = _results_root / "synth_pe"
pe_has_results = pe_dir.exists() and any(pe_dir.glob("*.csv"))
if pe_has_results:
    SYNTH_DIRS["Private Evolution"] = pe_dir

method_list = ", ".join(SYNTH_DIRS)
display(Markdown(f"Methods to compare: {method_list}"))

---
## Step 1: Run evaluation for all methods

In [None]:
# Evaluation outputs per method, keyed by the display label from SYNTH_DIRS:
# one summary frame and one detailed (per-metric) frame each.
all_summaries = {}
all_details = {}

for method_name, method_dir in SYNTH_DIRS.items():
    eval_results = evaluate_all(REAL_DIR, method_dir)
    all_summaries[method_name] = results_to_dataframe(eval_results)
    all_details[method_name] = detailed_results_to_dataframe(eval_results)

    # Rows with a missing "error" value are the queries that were evaluated.
    summary_df = all_summaries[method_name]
    ok_rows = summary_df.loc[summary_df["error"].isna()]
    passed_count = int(ok_rows["passed"].sum())
    evaluated_count = len(ok_rows)
    mean_score = ok_rows["score"].mean()

    headline = f"{method_name}: {passed_count}/{evaluated_count} queries passed, "
    headline += f"average score = {mean_score:.3f}"
    display(Markdown(headline))

---
## Step 2: Query pass rates by method

In [None]:
all_queries = sorted(QUERY_METADATA.keys())

# Queries dropped from every per-query comparison below.
# NOTE(review): the reason these three are considered infeasible is not
# stated in this notebook -- confirm against the benchmark documentation.
infeasible = {
    "ranked_process_classifications",
    "top_10_processes_per_user_id_ranked_by_total_power_consumption",
    "top_20_most_power_consuming_processes_by_avg_power_consumed",
}
feasible_queries = [q for q in all_queries if q not in infeasible]

# Build one row per feasible query with "<method>_score" and
# "<method>_passed" columns for every evaluated method.
rows = []
for q in feasible_queries:
    row = {"query": q, "type": QUERY_METADATA[q]["type"]}
    for method, summary in all_summaries.items():
        match = summary[summary["query"] == q]
        # A query counts as evaluated only when the method has a row for it
        # and that row carries no error. pd.isna treats both None and float
        # NaN as "no error", matching the .isna() filter used by the summary
        # cells. The previous `error == error` test is False for NaN, so it
        # discarded the scores of successful queries whenever the error
        # column held NaN instead of None.
        evaluated = len(match) > 0 and pd.isna(match.iloc[0]["error"])
        if evaluated:
            row[f"{method}_score"] = match.iloc[0]["score"]
            row[f"{method}_passed"] = bool(match.iloc[0]["passed"])
        else:
            # Missing or errored query: leave both cells empty.
            row[f"{method}_score"] = None
            row[f"{method}_passed"] = None
    rows.append(row)

comparison = pd.DataFrame(rows)
display(comparison)

---
## Step 3: Overall pass rates

In [None]:
methods = list(SYNTH_DIRS.keys())

# For each method keep only the summary rows without a recorded error, then
# derive the three per-method aggregates shown in the figure.
evaluated_frames = []
for name in methods:
    summary_df = all_summaries[name]
    evaluated_frames.append(summary_df[summary_df["error"].isna()])

pass_rates = [int(frame["passed"].sum()) for frame in evaluated_frames]
avg_scores = [frame["score"].mean() for frame in evaluated_frames]
n_evaluated = [len(frame) for frame in evaluated_frames]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
count_ax, score_ax = axes

colors = ["#e74c3c", "#3498db", "#2ecc71", "#9b59b6"]
method_colors = colors[:len(methods)]
x = np.arange(len(methods))

# Left panel: number of passing queries, annotated as "passed/evaluated".
count_bars = count_ax.bar(x, pass_rates, color=method_colors)
count_ax.set_xticks(x)
count_ax.set_xticklabels(methods, rotation=15, ha="right", fontsize=9)
count_ax.set_ylabel("Queries passed")
count_ax.set_title("Queries passing (score >= 0.5)")
for bar, passed, total in zip(count_bars, pass_rates, n_evaluated):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.2
    count_ax.text(label_x, label_y, f"{passed}/{total}", ha="center", fontsize=9)

# Right panel: mean query score per method, annotated with the value.
score_bars = score_ax.bar(x, avg_scores, color=method_colors)
score_ax.set_xticks(x)
score_ax.set_xticklabels(methods, rotation=15, ha="right", fontsize=9)
score_ax.set_ylabel("Average score")
score_ax.set_title("Average query score (fraction of metrics passing)")
score_ax.set_ylim(0, 1)
for bar, mean_score in zip(score_bars, avg_scores):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.02
    score_ax.text(label_x, label_y, f"{mean_score:.3f}", ha="center", fontsize=9)

plt.tight_layout()
plt.savefig("../report/figures/method_comparison_overall.pdf", bbox_inches="tight")
plt.show()

---
## Step 4: Per-query-type breakdown

In [None]:
# Distinct query types among the feasible queries, in sorted order.
query_types = sorted({QUERY_METADATA[q]["type"] for q in feasible_queries})

# type_scores[method][k] is the mean score of `method` over the queries of
# type query_types[k].
type_scores = {m: [] for m in methods}
for qt in query_types:
    # Filter once per query type: the subset does not depend on the method.
    subset = comparison[comparison["type"] == qt]
    for m in methods:
        valid = subset[f"{m}_score"].dropna()
        # A method with no evaluated query of this type gets 0, so its bar
        # has zero height in the chart.
        type_scores[m].append(valid.mean() if len(valid) > 0 else 0)

x = np.arange(len(query_types))
# Share 80% of each tick's slot between the methods' bars.
width = 0.8 / len(methods)

fig, ax = plt.subplots(figsize=(12, 5))
for i, m in enumerate(methods):
    # Centre the group of bars on the tick position.
    offset = (i - len(methods) / 2 + 0.5) * width
    ax.bar(x + offset, type_scores[m], width, label=m, color=colors[i])

ax.set_xticks(x)
ax.set_xticklabels([t.replace("_", " ") for t in query_types], rotation=20, ha="right", fontsize=9)
ax.set_ylabel("Average score")
ax.set_title("Average query score by query type")
ax.set_ylim(0, 1)
ax.legend(loc="upper right", fontsize=8)
plt.tight_layout()
plt.savefig("../report/figures/method_comparison_by_type.pdf", bbox_inches="tight")
plt.show()

---
## Step 5: Per-query score heatmap

In [None]:
# Query x method matrix of scores taken from the comparison table.
score_cols = [f"{m}_score" for m in methods]
heatmap_data = comparison.set_index("query")[score_cols].copy()
heatmap_data.columns = methods
# Keep missing scores as NaN instead of a negative sentinel. With vmin=0 a
# sentinel such as -0.1 is drawn in the same under-range red as a real 0.00
# score, so unavailable cells looked like the worst possible result.
heatmap_data = heatmap_data.apply(pd.to_numeric, errors="coerce")
score_matrix = heatmap_data.to_numpy(dtype=float)

# NaN cells are drawn with the colormap's "bad" colour; use a neutral grey on
# a copy so the registered colormap is left untouched.
cmap = plt.get_cmap("RdYlGn").copy()
cmap.set_bad("lightgray")

fig, ax = plt.subplots(figsize=(8, 10))
im = ax.imshow(score_matrix, aspect="auto", cmap=cmap, vmin=0, vmax=1)

ax.set_xticks(np.arange(len(methods)))
ax.set_xticklabels(methods, fontsize=9)
ax.set_yticks(np.arange(len(heatmap_data)))

# Row labels: query names with spaces, truncated to 45 characters.
short_names = []
for q in heatmap_data.index:
    name = q.replace("_", " ")
    if len(name) > 45:
        name = name[:42] + "..."
    short_names.append(name)
ax.set_yticklabels(short_names, fontsize=7)

# Annotate every cell with its score, or "N/A" when no score is available.
for (i, j), val in np.ndenumerate(score_matrix):
    if np.isnan(val):
        text = "N/A"
        color = "dimgray"
    else:
        text = f"{val:.2f}"
        # White text keeps low scores readable on the dark red end.
        color = "white" if val < 0.3 else "black"
    ax.text(j, i, text, ha="center", va="center", fontsize=7, color=color)

fig.colorbar(im, ax=ax, label="Query score", shrink=0.6)
ax.set_title("Per-query scores across methods")
plt.tight_layout()
plt.savefig("../report/figures/method_comparison_heatmap.pdf", bbox_inches="tight")
plt.show()

---
## Step 6: Detailed metric comparison for key queries

In [None]:
# Queries whose individual metrics are listed method by method below.
key_queries = [
    "avg_platform_power_c0_freq_temp_by_chassis",
    "most_popular_browser_in_each_country_by_system_count",
    "ram_utilization_histogram",
    "battery_power_on_geographic_summary",
    "popular_browsers_by_count_usage_percentage",
]


def _format_metric_value(value):
    """Render a metric value with four decimals; missing or +inf becomes "inf"."""
    if pd.isna(value) or value == float('inf'):
        return "inf"
    return f"{value:.4f}"


for q in key_queries:
    heading = q.replace('_', ' ').title()
    display(Markdown(f"### {heading}"))

    # Collect the per-metric rows of this query from every method, in method order.
    metric_rows = []
    for m in methods:
        method_detail = all_details[m]
        for _, record in method_detail[method_detail["query"] == q].iterrows():
            metric_rows.append({
                "method": m,
                "column": record["column"],
                "metric": record["metric_type"],
                "value": _format_metric_value(record["value"]),
                "passed": record["passed"],
                "detail": record["detail"],
            })

    if not metric_rows:
        display(Markdown("No data available for this query."))
    else:
        display(pd.DataFrame(metric_rows))
    print()

---
## Step 7: Summary table for report

In [None]:
def _report_row(method_name):
    """Build one summary-table row from a method's error-free queries."""
    summary_df = all_summaries[method_name]
    evaluated = summary_df[summary_df["error"].isna()]
    passed = evaluated["passed"]
    score = evaluated["score"]
    return {
        "Method": method_name,
        "Queries evaluated": len(evaluated),
        "Queries passed": int(passed.sum()),
        "Pass rate": f"{passed.mean():.1%}",
        "Avg score": f"{score.mean():.3f}",
        "Median score": f"{score.median():.3f}",
    }


report_df = pd.DataFrame([_report_row(m) for m in methods])
display(Markdown("### Method comparison summary"))
display(report_df)

# Same table as LaTeX source, printed as plain text.
display(Markdown("### LaTeX table"))
latex = report_df.to_latex(index=False, escape=True)
print(latex)

---
## Step 8: Strengths and weaknesses analysis

In [None]:
display(Markdown("### Queries where methods diverge"))

for q in feasible_queries:
    query_rows = comparison.loc[comparison["query"] == q]

    # Gather the available (non-missing) score of each method for this query.
    scores = {}
    for m in methods:
        observed = query_rows[f"{m}_score"].values
        if len(observed) == 0 or pd.isna(observed[0]):
            continue
        scores[m] = observed[0]

    # A spread needs at least two methods with a score.
    if len(scores) < 2:
        continue

    best = max(scores, key=scores.get)
    worst = min(scores, key=scores.get)
    spread = scores[best] - scores[worst]
    # Report only queries whose best and worst scores differ by 0.3 or more.
    if spread < 0.3:
        continue

    report_line = (
        f"  {q}: spread={spread:.2f}, "
        f"best={best} ({scores[best]:.2f}), "
        f"worst={worst} ({scores[worst]:.2f})"
    )
    display(Markdown(report_line))

In [None]:
display(Markdown("### Queries all methods fail"))

# Queries whose best available score is below 0.5, together with the
# per-method scores that were available for them.
all_fail = []
failing_scores = {}
for q in feasible_queries:
    query_rows = comparison.loc[comparison["query"] == q]
    available = {}
    for m in methods:
        observed = query_rows[f"{m}_score"].values
        if len(observed) > 0 and pd.notna(observed[0]):
            available[m] = observed[0]
    if available and max(available.values()) < 0.5:
        all_fail.append(q)
        failing_scores[q] = available

for q in all_fail:
    parts = [f"{m}: {score:.2f}" for m, score in failing_scores[q].items()]
    scores_str = ", ".join(parts)
    display(Markdown(f"  {q} ({scores_str})"))

display(Markdown(f"\nTotal: {len(all_fail)}/{len(feasible_queries)} queries fail all methods"))