In [1]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt, cvxpy as cp
from datasets import load_dataset
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
MODELS = ['gemini-2.0-flash-grounding', 'gemini-2.5-flash-preview-04-17-grounding', 'gemini-2.5-pro-exp-03-25-grounding', 'gemini-2.5-pro-exp-03-25-wo-search', 'gpt-4o-mini-search-preview', 'gpt-4o-search-preview', 'gpt-4o-search-preview-high', 'gpt-4o-search-preview-high-loc', 'sonar', 'sonar-pro', 'sonar-pro-high', 'sonar-reasoning', 'sonar-reasoning-pro-high']
GROUP_BY = "primary_intent"  # Options: "language" or "primary_intent"
N_BOOT = 200

In [None]:
def load_df(group_by=GROUP_BY):
    ds = load_dataset("lmarena-ai/search-arena-24k", split="test")
    df = ds.to_pandas()
    keep = ['model_a','model_b','winner','languages','timestamp', 'primary_intent']
    df = df[keep].copy()
    # Filter rows where languages is a list with exactly one element
    df = df[df["languages"].apply(lambda x: isinstance(x, (list, tuple, np.ndarray)) and len(x) == 1)].copy()
    def extract_lang(x):
        if isinstance(x, (list, tuple, np.ndarray)) and len(x) == 1:
            val = x[0]
            if isinstance(val, np.ndarray):
                val = val.item() if val.size == 1 else str(val[0])
            return str(val)
        return str(x)
    df["language"] = df["languages"].apply(extract_lang)
    if group_by == "language":
        df["group"] = df["language"]
    elif group_by == "primary_intent":
        df["group"] = df["primary_intent"]
    else:
        raise ValueError(f"Unknown group_by: {group_by}. Must be 'language' or 'primary_intent'")
    return df

def top_groups(df, group_col="group", k=4):
    vc = df[group_col].value_counts(dropna=False)
    groups = vc.head(k).index.tolist()
    print(groups)
    return groups, vc

def top_languages(df, k=4):
    return top_groups(df, "language", k)


def build_margins(df, groups, models, alpha=1.0, group_col="group"):
    m = len(models); K = len(groups)
    idx = {name:i for i,name in enumerate(models)}
    df = df[df[group_col].isin(groups)].copy()
    df = df[df["model_a"].isin(models) & df["model_b"].isin(models)].copy()
    counts = df[group_col].value_counts().reindex(groups).fillna(0).astype(int)
    tot = int(counts.sum())
    if tot == 0: raise ValueError("No rows after filtering; check groups/models.")
    w0 = (counts/tot).to_numpy(float)
    M_list = []
    df = df.copy()
    df["idx_a"] = df["model_a"].map(idx)
    df["idx_b"] = df["model_b"].map(idx)
    df = df.dropna(subset=["idx_a", "idx_b"]).copy()
    
    M_list = []
    for grp in groups:
        rows = df[df[group_col]==grp]
        if len(rows) == 0:
            M_list.append(np.zeros((m,m), float))
            continue
        win = np.zeros((m,m), float)
        idx_a = rows["idx_a"].values.astype(int)
        idx_b = rows["idx_b"].values.astype(int)
        winner = rows["winner"].values
        mask_a = (winner == "model_a")
        np.add.at(win, (idx_a[mask_a], idx_b[mask_a]), 1.0)
        mask_b = (winner == "model_b")
        np.add.at(win, (idx_b[mask_b], idx_a[mask_b]), 1.0)

        M = np.zeros((m,m), float)
        for i in range(m):
            for j in range(i+1, m):
                tot_ij = win[i,j] + win[j,i]
                if tot_ij > 0:
                    mij = (win[i,j] - win[j,i]) / (tot_ij + 2.0*alpha)
                else:
                    mij = 0.0
                M[i,j] = mij; M[j,i] = -mij
        M_list.append(M)
    return M_list, w0, counts, df

In [4]:
def solve_drml_tv(M_list, w0, rho, solvers=("GUROBI","MOSEK","GLPK","ECOS")):
    K = len(M_list); m = M_list[0].shape[0]
    p = cp.Variable(m, nonneg=True); t = cp.Variable()
    mu = cp.Variable(m); lam = cp.Variable(m, nonneg=True)
    gamma = cp.Variable((m,K))
    cons = [cp.sum(p)==1]
    for a in range(m):
        cons += [t <= mu[a] - 2.0*rho*lam[a] + w0 @ gamma[a,:]]
        for k in range(K):
            Mk = M_list[k]
            cons += [mu[a] + gamma[a,k] <= p @ Mk[:,a],
                     gamma[a,k] <= lam[a],
                     gamma[a,k] >= -lam[a]]
    prob = cp.Problem(cp.Maximize(t), cons)
    last = None
    for s in solvers:
        try:
            prob.solve(solver=getattr(cp,s), verbose=False)
            if prob.status in ("optimal","optimal_inaccurate"): break
        except Exception as e:
            last = e
    if prob.status not in ("optimal","optimal_inaccurate"):
        raise RuntimeError(f"LP not solved. status={prob.status}, last={last}")
    pval = np.array(p.value).reshape(-1)
    pval[pval<0]=0
    if pval.sum()>0: pval /= pval.sum()
    return pval, float(t.value)


In [None]:
def per_group_winrate(p, M_list):
    M_stack = np.stack(M_list, axis=0)  # (K, m, m)
    return per_group_winrate_vectorized(p, M_stack)

def per_group_winrate_vectorized(p, M_stack):
    pM = np.einsum('i,kij->kj', p, M_stack)
    margins = np.min(pM, axis=1)
    wr = 0.5 * (1.0 + margins)
    return wr

def bootstrap_df(df_sub, groups, seed, group_col="group"):
    rng = np.random.default_rng(seed)
    parts = []
    for grp in groups:
        part = df_sub[df_sub[group_col]==grp]
        n = len(part)
        if n == 0: continue
        idx = rng.integers(0, n, size=n)
        parts.append(part.iloc[idx])
    return pd.concat(parts, ignore_index=True)


In [6]:
def bootstrap_margins(df, models, n_boot=200, seed=0, alpha=1.0, group_by=GROUP_BY):
    if group_by == "language":
        df["group"] = df["language"]
        groups, _ = top_groups(df, "group", k=4)
    elif group_by == "primary_intent":
        df["group"] = df["primary_intent"]
        groups, _ = top_groups(df, "group", k=4)
    else:
        raise ValueError(f"Unknown group_by: {group_by}")

    M0, w0, counts, df_sub = build_margins(df, groups, models, alpha=alpha, group_col="group")
    print("groups:", groups)
    print("counts:", dict(zip(groups, counts.tolist())))
    print("w0:", dict(zip(groups, w0.tolist())))
    boot_M = []
    for b in range(n_boot):
        df_b = bootstrap_df(df_sub, groups, seed + 100000*b, group_col="group")
        Mb, _, _, _ = build_margins(df_b, groups, models, alpha=alpha, group_col="group")
        boot_M.append(Mb)
        if (b+1) % max(1, n_boot//10) == 0: print(f"boot {b+1}/{n_boot}")
    return groups, w0, M0, boot_M

def compute_from_bootstrap(groups, w0, boot_M, rhos):
    K = len(groups)
    boot_wr = {(rho,k): [] for rho in rhos for k in range(K)}
    boot_wr_overall = {rho: [] for rho in rhos}
    boot_p = {rho: [] for rho in rhos}
    boot_v = {rho: [] for rho in rhos}
    n_boot = len(boot_M)
    w0_arr = np.array(w0, float)
    for b, Mb in enumerate(boot_M):
        Mb_stack = np.stack(Mb, axis=0)
        for rho in rhos:
            p, v = solve_drml_tv(Mb, w0, rho)
            boot_p[rho].append(p)
            boot_v[rho].append(v)  
            wr = per_group_winrate_vectorized(p, Mb_stack)
            for k in range(K): boot_wr[(rho,k)].append(float(wr[k]))
            M_pooled = np.einsum('k,kij->ij', w0_arr, Mb_stack)
            overall_margin = float(np.min(p @ M_pooled))
            overall_wr = 0.5 * (1 + overall_margin)
            boot_wr_overall[rho].append(float(overall_wr))
        if (b+1) % max(1, n_boot//10) == 0: print(f"compute {b+1}/{n_boot}")
    rows = []
    for rho in rhos:
        for k, lang in enumerate(groups):
            arr = np.array(boot_wr[(rho,k)], float)
            se = float(arr.std() / np.sqrt(len(arr)))
            rows.append({"rho": rho, "group": lang, "boot_mean": float(arr.mean()), "boot_se": se})
        arr_overall = np.array(boot_wr_overall[rho], float)
        se_overall = float(arr_overall.std() / np.sqrt(len(arr_overall)))
        rows.append({"rho": rho, "group": "overall", "boot_mean": float(arr_overall.mean()), "boot_se": se_overall})
    ci = pd.DataFrame(rows).sort_values(["group","rho"])
    return ci, boot_p, boot_v

In [None]:
df = load_df(group_by=GROUP_BY)
rhos = np.linspace(0.0, 1.0, 11)
print("Generating bootstrap margin matrices...")
groups, w0, M0, boot_M = bootstrap_margins(df, MODELS, n_boot=N_BOOT, seed=1, alpha=2, group_by=GROUP_BY)
print("Computing results from bootstrap margin matrices...")
ci, boot_p, boot_v = compute_from_bootstrap(groups, w0, boot_M, rhos)

Generating bootstrap margin matrices...
['Factual Lookup', 'Info Synthesis', 'Recommendation', 'Analysis']
groups: ['Factual Lookup', 'Info Synthesis', 'Recommendation', 'Analysis']
counts: {'Factual Lookup': 4269, 'Info Synthesis': 3931, 'Recommendation': 2441, 'Analysis': 2348}
w0: {'Factual Lookup': 0.3286627146046655, 'Info Synthesis': 0.30264069597351606, 'Recommendation': 0.18792824697821234, 'Analysis': 0.18076834244360612}
boot 20/200
boot 40/200
boot 60/200
boot 80/200
boot 100/200
boot 120/200
boot 140/200
boot 160/200
boot 180/200
boot 200/200
Computing results from bootstrap margin matrices...
compute 20/200
compute 40/200
compute 60/200
compute 80/200
compute 100/200
compute 120/200
compute 140/200
compute 160/200
compute 180/200
compute 200/200


In [8]:
def plot_group_curves(ci, groups, out_pdf=None):
    if out_pdf is None:
        out_pdf = f"group_winrate_vs_rho_{GROUP_BY}.pdf"
    sns.set_style("whitegrid")
    sns.set_palette("husl")
    
    plt.rcParams.update({
        "font.size": 11,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
    })
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    default_colors = ['#E74C3C', '#3498DB', '#F39C12', '#27AE60', '#9B59B6', '#1ABC9C', '#E67E22']
    colors = {'overall': '#2C3E50'}
    for i, grp in enumerate(groups):
        if grp not in colors:
            colors[grp] = default_colors[i % len(default_colors)]
    
    desired_order = ['overall'] + list(groups)
    
    for group in desired_order:
        if group == 'overall':
            d = ci[ci["group"] == "overall"].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["boot_mean"].to_numpy(float)
                se = d["boot_se"].to_numpy(float)
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5, 
                       label="Overall", linestyle="--", color=colors['overall'], alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors['overall'])
        elif group in groups:
            d = ci[ci["group"] == group].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["boot_mean"].to_numpy(float)
                se = d["boot_se"].to_numpy(float)
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5, 
                       label=group, color=colors.get(group, '#95A5A6'), alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors.get(group, '#95A5A6'))
    
    ax.set_xlabel(r"$\rho$", fontsize=12, labelpad=6)
    ax.set_ylabel("Win Rate (%)", fontsize=12, labelpad=6)
    
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.4)
    ax.set_axisbelow(True)
    
    ax.legend(frameon=True, loc='best', framealpha=0.95, facecolor='white', 
             edgecolor='lightgray', borderpad=0.8, labelspacing=0.6)
    
    fig.tight_layout()
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print("saved:", out_pdf)

In [9]:
plot_group_curves(ci, groups)

saved: group_winrate_vs_rho_primary_intent.pdf


In [10]:
def bootstrap_train_test_splits(df, groups, models, train_frac=0.8, n_boot=100, seed=0, group_col="group"):
    df_filtered = df[df[group_col].isin(groups)].copy()
    df_filtered = df_filtered[df_filtered["model_a"].isin(models) & df_filtered["model_b"].isin(models)].copy()
    
    splits = []
    for b in range(n_boot):
        rng = np.random.default_rng(seed + 100000*b)
        n_total = len(df_filtered)
        boot_idx = rng.integers(0, n_total, size=n_total)
        df_boot = df_filtered.iloc[boot_idx].reset_index(drop=True)
        
        df_train_list = []
        df_test_list = []
        for grp in groups:
            grp_df = df_boot[df_boot[group_col] == grp]
            n = len(grp_df)
            if n == 0:
                continue
            n_train = int(train_frac * n)
            idx = np.arange(n)
            rng.shuffle(idx)
            train_idx = idx[:n_train]
            test_idx = idx[n_train:]
            df_train_list.append(grp_df.iloc[train_idx])
            df_test_list.append(grp_df.iloc[test_idx])
        if len(df_train_list) == 0:
            continue
        df_train = pd.concat(df_train_list, ignore_index=True)
        df_test = pd.concat(df_test_list, ignore_index=True)
        splits.append((df_train, df_test))
        if (b+1) % max(1, n_boot//10) == 0: print(f"boot {b+1}/{n_boot}")
    return splits

def compute_generalization(splits, groups, models, rhos, alpha=1.0, group_col="group"):
    K = len(groups)
    boot_wr_gaps = {(rho, grp): [] for rho in rhos for grp in groups + ["overall"]}
    
    for df_train, df_test in splits:
        M_train, w0_train, _, _ = build_margins(df_train, groups, models, alpha=alpha, group_col=group_col)
        M_test, w0_test, _, _ = build_margins(df_test, groups, models, alpha=alpha, group_col=group_col)
        
        M_train_stack = np.stack(M_train, axis=0)
        M_test_stack = np.stack(M_test, axis=0)
        w0_train_arr = np.array(w0_train, float)
        w0_test_arr = np.array(w0_test, float)
        
        for rho in rhos:
            p_rho, _ = solve_drml_tv(M_train, w0_train, rho)
            
            wr_test = per_group_winrate_vectorized(p_rho, M_test_stack)
            for k, grp in enumerate(groups): boot_wr_gaps[(rho, grp)].append(float(wr_test[k]))
            
            M_pooled_test = np.einsum('k,kij->ij', w0_test_arr, M_test_stack)
            v_test_overall = float(np.min(p_rho @ M_pooled_test))
            wr_test_overall = 0.5 * (1 + v_test_overall)
            boot_wr_gaps[(rho, "overall")].append(float(wr_test_overall))
    
    gap_rows = []
    for rho in rhos:
        for grp in groups + ["overall"]:
            arr = np.array(boot_wr_gaps[(rho, grp)], float)
            if len(arr) > 0:
                se = float(arr.std() / np.sqrt(len(arr)))
                gap_rows.append({
                    "rho": rho,
                    "group": grp,
                    "wr_gap_mean": float(arr.mean()),
                    "wr_gap_se": se
                })
    gap_df = pd.DataFrame(gap_rows)
    return gap_df, boot_wr_gaps

In [None]:
print("Generating bootstrap train/test splits...")
splits = bootstrap_train_test_splits(df, groups, MODELS, train_frac=0.8, n_boot=N_BOOT, seed=42, group_col="group")
print("Computing generalization gaps from splits...")
gap_df, _ = compute_generalization(splits, groups, MODELS, rhos, alpha=1.0, group_col="group")

Generating bootstrap train/test splits...
boot 20/200
boot 40/200
boot 60/200
boot 80/200
boot 100/200
boot 120/200
boot 140/200
boot 160/200
boot 180/200
boot 200/200
Computing generalization gaps from splits...


In [12]:
def plot_generalization(gap_df, groups, out_pdf=None):
    if out_pdf is None:
        out_pdf = f"generalization_gaps_{GROUP_BY}.pdf"
    lang_names = {'en': 'English', 'pl': 'Polish', 'ru': 'Russian', 'zh': 'Chinese', 'overall': 'Overall'}
    
    sns.set_style("whitegrid")
    sns.set_palette("husl")
    
    plt.rcParams.update({
        "font.size": 11,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
    })
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    default_colors = ['#E74C3C', '#3498DB', '#F39C12', '#27AE60', '#9B59B6', '#1ABC9C', '#E67E22']
    colors = {'overall': '#2C3E50'}
    for i, grp in enumerate(groups):
        if grp not in colors:
            colors[grp] = default_colors[i % len(default_colors)]
    
    desired_order = ['overall'] + list(groups)
    
    for group in desired_order:
        if group == 'overall':
            d = gap_df[gap_df["group"] == "overall"].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["wr_gap_mean"].to_numpy(float)
                se = d["wr_gap_se"].to_numpy(float)
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5,
                       label="Overall", linestyle="--", color=colors['overall'], alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors['overall'])
        elif group in groups:
            d = gap_df[gap_df["group"] == group].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["wr_gap_mean"].to_numpy(float)
                se = d["wr_gap_se"].to_numpy(float)
                group_label = lang_names.get(group, group)
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5,
                       label=group_label, color=colors.get(group, '#95A5A6'), alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors.get(group, '#95A5A6'))
    
    ax.set_xlabel(r"$\rho$", fontsize=12, labelpad=6)
    ax.set_ylabel("Win rate on held-out data (%)", fontsize=12, labelpad=6)
    
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.4)
    ax.set_axisbelow(True)
    
    ax.legend(frameon=True, loc='best', framealpha=0.95, facecolor='white', 
             edgecolor='lightgray', borderpad=0.8, labelspacing=0.6)
    
    fig.tight_layout()
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print("saved:", out_pdf)

def print_generalization_gaps(gap_df, groups, rhos):
    lang_names = {'en': 'English', 'pl': 'Polish', 'ru': 'Russian', 'zh': 'Chinese', 'overall': 'Overall'}
    gap_df = gap_df.copy()
    gap_df["formatted"] = gap_df.apply(
        lambda row: f"{row['wr_gap_mean']:.4f} ± {row['wr_gap_se']:.4f}", axis=1
    )
    table = gap_df.pivot(index="group", columns="rho", values="formatted")
    all_groups = groups + ["overall"]
    table = table.reindex(all_groups)
    table.index = [lang_names.get(grp, grp) for grp in table.index]
    display(table)
    return table

In [13]:
print("\nGeneralization gaps:")
print_generalization_gaps(gap_df, groups, rhos)
plot_generalization(gap_df, groups)


Generalization gaps:


rho,0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0
Factual Lookup,0.3564 ± 0.0051,0.3751 ± 0.0044,0.3860 ± 0.0039,0.3908 ± 0.0036,0.3924 ± 0.0035,0.3925 ± 0.0034,0.3935 ± 0.0034,0.3948 ± 0.0034,0.3931 ± 0.0034,0.3927 ± 0.0034,0.3927 ± 0.0034
Info Synthesis,0.3395 ± 0.0060,0.3637 ± 0.0050,0.3786 ± 0.0042,0.3873 ± 0.0038,0.3897 ± 0.0036,0.3932 ± 0.0032,0.3948 ± 0.0031,0.3965 ± 0.0029,0.3964 ± 0.0029,0.3961 ± 0.0029,0.3961 ± 0.0029
Recommendation,0.3058 ± 0.0056,0.3336 ± 0.0051,0.3563 ± 0.0045,0.3685 ± 0.0041,0.3750 ± 0.0037,0.3811 ± 0.0036,0.3835 ± 0.0035,0.3856 ± 0.0034,0.3868 ± 0.0034,0.3871 ± 0.0034,0.3871 ± 0.0034
Analysis,0.3511 ± 0.0056,0.3694 ± 0.0048,0.3848 ± 0.0042,0.3935 ± 0.0039,0.3955 ± 0.0037,0.3957 ± 0.0035,0.3968 ± 0.0033,0.3977 ± 0.0031,0.3975 ± 0.0032,0.3977 ± 0.0032,0.3977 ± 0.0032
Overall,0.4365 ± 0.0030,0.4434 ± 0.0026,0.4474 ± 0.0021,0.4495 ± 0.0019,0.4507 ± 0.0019,0.4501 ± 0.0018,0.4502 ± 0.0017,0.4504 ± 0.0017,0.4496 ± 0.0017,0.4492 ± 0.0017,0.4492 ± 0.0017


saved: generalization_gaps_primary_intent.pdf


In [14]:
def print_boot_p(boot_p, models, rhos, thresh=0.02):
    rows = []
    for rho in rhos:
        P = np.vstack(boot_p[rho])
        mean = P.mean(axis=0)
        se = P.std(axis=0) / np.sqrt(P.shape[0])
        for i, name in enumerate(models):
            rows.append({"rho": rho, "model": name, "mean": mean[i], "se": se[i]})
    wdf = pd.DataFrame(rows)
    keep = wdf.groupby("model")["mean"].max()
    keep = keep[keep > thresh].index
    wdf = wdf[wdf["model"].isin(keep)].copy()
    wdf["mean%"] = (100*wdf["mean"]).round(2)
    wdf["se%"] = (100*wdf["se"]).round(2)
    
    wdf["formatted"] = wdf.apply(lambda row: f"{row['mean%']:.2f} ± {row['se%']:.2f}", axis=1)
    
    table = wdf.pivot(index="model", columns="rho", values="formatted")
    
    rho0_means = wdf[wdf["rho"] == rhos[0]].set_index("model")["mean"].sort_values(ascending=False)
    table = table.reindex(rho0_means.index)
    
    print(f"(shown if mean> {100*thresh:.1f}% for some rho)")
    display(table)
    return wdf

wdf = print_boot_p(boot_p, MODELS, rhos, thresh=0.02)


(shown if mean> 2.0% for some rho)


rho,0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
sonar-reasoning-pro-high,48.91 ± 2.91,46.05 ± 2.15,42.37 ± 1.60,36.86 ± 1.34,34.71 ± 1.17,33.66 ± 1.08,33.60 ± 1.05,34.25 ± 1.02,34.41 ± 1.04,34.12 ± 1.04,34.12 ± 1.04
gemini-2.5-pro-exp-03-25-grounding,34.79 ± 2.68,34.77 ± 2.13,32.06 ± 1.75,31.23 ± 1.54,30.19 ± 1.44,28.88 ± 1.32,27.55 ± 1.25,26.36 ± 1.20,25.41 ± 1.18,25.50 ± 1.19,25.50 ± 1.19
sonar-pro-high,9.05 ± 1.49,11.65 ± 1.23,14.94 ± 0.99,17.46 ± 0.87,18.96 ± 0.85,19.55 ± 0.83,19.89 ± 0.81,19.90 ± 0.78,20.74 ± 0.79,20.85 ± 0.79,20.85 ± 0.79
sonar,4.80 ± 0.91,5.43 ± 0.82,7.10 ± 0.80,8.73 ± 0.81,9.59 ± 0.78,10.16 ± 0.72,9.85 ± 0.69,9.30 ± 0.67,8.75 ± 0.65,8.77 ± 0.65,8.77 ± 0.65
sonar-reasoning,1.73 ± 0.39,1.51 ± 0.28,2.22 ± 0.35,3.41 ± 0.45,3.25 ± 0.44,2.92 ± 0.39,2.81 ± 0.38,2.80 ± 0.37,2.66 ± 0.37,2.66 ± 0.37,2.66 ± 0.37
sonar-pro,0.32 ± 0.13,0.42 ± 0.14,0.82 ± 0.22,1.20 ± 0.26,1.77 ± 0.29,2.14 ± 0.31,2.48 ± 0.33,2.70 ± 0.34,2.65 ± 0.33,2.68 ± 0.33,2.68 ± 0.33
gemini-2.5-pro-exp-03-25-wo-search,0.26 ± 0.12,0.10 ± 0.05,0.26 ± 0.09,0.60 ± 0.14,0.77 ± 0.16,1.36 ± 0.21,1.91 ± 0.26,2.24 ± 0.28,2.56 ± 0.30,2.54 ± 0.30,2.54 ± 0.30
gemini-2.5-flash-preview-04-17-grounding,0.13 ± 0.13,0.09 ± 0.06,0.23 ± 0.09,0.50 ± 0.15,0.75 ± 0.18,1.23 ± 0.25,1.76 ± 0.28,2.22 ± 0.31,2.58 ± 0.33,2.64 ± 0.34,2.64 ± 0.34


In [15]:
def plot_model_performance_grid(wdf, out_pdf=None):
    if out_pdf is None:
        out_pdf = f"model_performance_grid_{GROUP_BY}.pdf"
    sns.set_style("white")
    
    plt.rcParams.update({
        "font.size": 10,
        "axes.labelsize": 11,
        "xtick.labelsize": 9,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.4,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0,
    })
    
    unique_rhos = sorted(wdf["rho"].unique())
    rhos_to_plot =[0.0, 0.4, 0.7, 1.0]
    
    model_names = {
        'gemini-2.0-flash-grounding': 'Gemini 2.0 Flash',
        'gemini-2.5-flash-preview-04-17-grounding': 'Gemini 2.5 Flash',
        'gemini-2.5-pro-exp-03-25-grounding': 'Gemini 2.5 Pro',
        'gemini-2.5-pro-exp-03-25-wo-search': 'Gemini 2.5 Pro (no search)',
        'gpt-4o-mini-search-preview': 'GPT-4o Mini',
        'gpt-4o-search-preview': 'GPT-4o',
        'gpt-4o-search-preview-high': 'GPT-4o High',
        'gpt-4o-search-preview-high-loc': 'GPT-4o High Loc',
        'sonar': 'Sonar',
        'sonar-pro': 'Sonar Pro',
        'sonar-pro-high': 'Sonar Pro High',
        'sonar-reasoning': 'Sonar Reasoning',
        'sonar-reasoning-pro-high': 'Sonar Reasoning Pro High',
    }
    
    colors = {
        'sonar-reasoning-pro-high': '#E74C3C',
        'gemini-2.5-pro-exp-03-25-grounding': '#3498DB',
        'sonar-pro-high': '#F39C12',
        'sonar': '#27AE60',
        'sonar-pro': '#9B59B6',
        'sonar-reasoning': '#1ABC9C',
        'gemini-2.5-pro-exp-03-25-wo-search': '#E67E22',
        'gemini-2.5-flash-preview-04-17-grounding': '#E91E63',
    }
    
    fig, axes = plt.subplots(2, 2, figsize=(9.5, 10))
    axes = axes.flatten()
    
    models_ordered = wdf[np.isclose(wdf["rho"], 0.0)].sort_values("mean", ascending=True)["model"].values
    
    for idx, rho in enumerate(rhos_to_plot):
        ax = axes[idx]
        
        data = wdf[np.isclose(wdf["rho"], rho)].copy()
        data = data.set_index("model").reindex(models_ordered).reset_index()
        data = data.dropna()
        
        bar_spacing = 0.65
        y_pos = np.arange(len(data)) * bar_spacing
        bar_colors = [colors.get(m, '#95A5A6') for m in data["model"]]
        
        ax.barh(y_pos, data["mean%"], xerr=data["se%"], 
               color=bar_colors,
               error_kw={'elinewidth': 1.2, 'capsize': 2.5, 'alpha': 0.7},
               height=0.5 * bar_spacing, alpha=0.85, edgecolor='white', linewidth=0.5)
        
        ax.set_xlim(0, 75)
        ax.set_xticks([0, 20, 40, 60, 80])
        
        ax.set_axisbelow(True)
        ax.grid(True, axis='x', alpha=0.25, linestyle='-', linewidth=0.5, color='gray')
        ax.grid(True, axis='y', alpha=0.15, linestyle='-', linewidth=0.3, color='gray')
        
        if idx in [0, 2]:
            ax.set_yticks(y_pos)
            clean_names = [model_names.get(m, m) for m in data["model"]]
            ax.set_yticklabels(clean_names, fontsize=12)
        else:
            ax.set_yticks([])
        
        if idx in [2, 3]:
            ax.set_xlabel("Probability (%)", fontsize=14, labelpad=4)
            ax.set_xticklabels(['0', '20', '40', '60', '80'], fontsize=14)
        else:
            ax.set_xticklabels([])
        
        rho_label = f'ρ = {rho:.1f}'
        ax.text(0.95, 0.1, rho_label, transform=ax.transAxes,
               fontsize=16, fontweight='bold', ha='right', va='bottom',
               bbox=dict(boxstyle='round,pad=0.4', facecolor='white', 
                        edgecolor='lightgray', alpha=0.8, linewidth=0.5))
        
        ax.spines['left'].set_linewidth(0.8)
        ax.spines['bottom'].set_linewidth(0.8)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    
    fig.tight_layout(pad=1.2)
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print(f"saved: {out_pdf}")

In [16]:
plot_model_performance_grid(wdf)

saved: model_performance_grid_primary_intent.pdf
