In [28]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt, cvxpy as cp
from datasets import load_dataset
import seaborn as sns

In [29]:
MODELS = ['anthropic/claude-3.7-sonnet', 'anthropic/claude-opus-4', 'anthropic/claude-sonnet-4', 'cohere/command-a', 'cohere/command-r7b-12-2024', 'deepseek/deepseek-r1-0528', 'google/gemini-2.0-flash-001', 'google/gemini-2.5-flash', 'google/gemini-2.5-pro', 'mistralai/mistral-nemo', 'openai/o1', 'openai/o3-mini', 'openai/o4-mini', 'x-ai/grok-3', 'x-ai/grok-4']
GROUP_BY = "ethnic_group"  # Options: "political_affiliation", "ethnic_group", or "metric"
ALPHA = 1.0
N_BOOT = 200

In [30]:
def load_df(group_by=None):
    if group_by is None:
        group_by = GROUP_BY
    ds = load_dataset("ProlificAI/humaine-evaluation-dataset", name="feedback_comparisons", split="train")
    df = ds.to_pandas()
    df = df[df["choice"] != "tie"].copy()
    if group_by == "political_affiliation":
        df["group"] = df["political_affiliation"]
    elif group_by == "ethnic_group":
        df["group"] = df["ethnic_group"]
    elif group_by == "metric":
        df["group"] = df["metric"]
        df = df[df["group"] != "overall winner"].copy()
    else:
        raise ValueError(f"Unknown group_by: {group_by}. Must be 'ethnic_group', 'political_affiliation', or 'metric'")
    df = df[df["group"].notna()].copy()
    return df

def top_groups(df, group_col="group", k=5):
    vc = df[group_col].value_counts(dropna=False)
    groups = vc.head(k).index.tolist()
    print(groups)
    return groups, vc

def build_margins(df, groups, models, alpha=0.0, group_col="group"):
    m = len(models); K = len(groups)
    idx = {name:i for i,name in enumerate(models)}
    df = df[df[group_col].isin(groups)].copy()
    df = df[df["model_a"].isin(models) & df["model_b"].isin(models)].copy()
    counts = df[group_col].value_counts().reindex(groups).fillna(0).astype(int)
    tot = int(counts.sum())
    if tot == 0: raise ValueError("No rows after filtering; check groups/models.")
    w0 = (counts/tot).to_numpy(float)
    M_list = []
    df = df.copy()
    df["idx_a"] = df["model_a"].map(idx)
    df["idx_b"] = df["model_b"].map(idx)
    df = df.dropna(subset=["idx_a", "idx_b"]).copy()
    
    M_list = []
    for grp in groups:
        rows = df[df[group_col]==grp]
        if len(rows) == 0:
            M_list.append(np.zeros((m,m), float))
            continue
        win = np.zeros((m,m), float)
        idx_a = rows["idx_a"].values.astype(int)
        idx_b = rows["idx_b"].values.astype(int)
        choice = rows["choice"].values
        mask_a = (choice == "A")
        np.add.at(win, (idx_a[mask_a], idx_b[mask_a]), 1.0)
        mask_b = (choice == "B")
        np.add.at(win, (idx_b[mask_b], idx_a[mask_b]), 1.0)
        M = np.zeros((m,m), float)
        for i in range(m):
            for j in range(i+1, m):
                tot_ij = win[i,j] + win[j,i]
                if tot_ij > 0:
                    mij = (win[i,j] - win[j,i]) / (tot_ij + 2.0*alpha)
                else:
                    mij = 0.0
                M[i,j] = mij; M[j,i] = -mij
        M_list.append(M)
    return M_list, w0, counts, df

In [31]:
def solve_drml_tv(M_list, w0, rho, solvers=("GUROBI","MOSEK","GLPK","ECOS")):
    K = len(M_list); m = M_list[0].shape[0]
    p = cp.Variable(m, nonneg=True); t = cp.Variable()
    mu = cp.Variable(m); lam = cp.Variable(m, nonneg=True)
    gamma = cp.Variable((m,K))
    cons = [cp.sum(p)==1]
    for a in range(m):
        cons += [t <= mu[a] - 2.0*rho*lam[a] + w0 @ gamma[a,:]]
        for k in range(K):
            Mk = M_list[k]
            cons += [mu[a] + gamma[a,k] <= p @ Mk[:,a],
                     gamma[a,k] <= lam[a],
                     gamma[a,k] >= -lam[a]]
    prob = cp.Problem(cp.Maximize(t), cons)
    last = None
    for s in solvers:
        try:
            prob.solve(solver=getattr(cp,s), verbose=False)
            if prob.status in ("optimal","optimal_inaccurate"): break
        except Exception as e:
            last = e
    if prob.status not in ("optimal","optimal_inaccurate"):
        raise RuntimeError(f"LP not solved. status={prob.status}, last={last}")
    pval = np.array(p.value).reshape(-1)
    pval[pval<0]=0
    if pval.sum()>0: pval /= pval.sum()
    return pval, float(t.value)

In [32]:
def per_group_winrate(p, M_list):
    M_stack = np.stack(M_list, axis=0)  # (K, m, m)
    return per_group_winrate_vectorized(p, M_stack)

def per_group_winrate_vectorized(p, M_stack):
    pM = np.einsum('i,kij->kj', p, M_stack)
    margins = np.min(pM, axis=1)
    wr = 0.5 * (1.0 + margins)
    return wr

def bootstrap_df(df_sub, groups, seed, group_col="group"):
    rng = np.random.default_rng(seed)
    parts = []
    for grp in groups:
        part = df_sub[df_sub[group_col]==grp]
        n = len(part)
        if n == 0: continue
        idx = rng.integers(0, n, size=n)
        parts.append(part.iloc[idx])
    return pd.concat(parts, ignore_index=True)

In [33]:
def bootstrap_margins(df, models, n_boot=200, seed=0, alpha=1.0, group_by=GROUP_BY):
    if group_by == "ethnic_group":
        df["group"] = df["ethnic_group"]
        groups, _ = top_groups(df, "group", k=4)
    elif group_by == "political_affiliation":
        df["group"] = df["political_affiliation"]
        groups, _ = top_groups(df, "group", k=4)
    else:
        raise ValueError(f"Unknown group_by: {group_by}")

    M0, w0, counts, df_sub = build_margins(df, groups, models, alpha=alpha, group_col="group")
    print("groups:", groups)
    print("counts:", dict(zip(groups, counts.tolist())))
    print("w0:", dict(zip(groups, w0.tolist())))
    boot_M = []
    for b in range(n_boot):
        df_b = bootstrap_df(df_sub, groups, seed + 100000*b, group_col="group")
        Mb, _, _, _ = build_margins(df_b, groups, models, alpha=alpha, group_col="group")
        boot_M.append(Mb)
        if (b+1) % max(1, n_boot//10) == 0: print(f"boot {b+1}/{n_boot}")
    return groups, w0, M0, boot_M

def compute_from_bootstrap(groups, w0, boot_M, rhos):
    K = len(groups)
    boot_wr = {(rho,k): [] for rho in rhos for k in range(K)}
    boot_wr_overall = {rho: [] for rho in rhos}
    boot_p = {rho: [] for rho in rhos}
    boot_v = {rho: [] for rho in rhos}
    n_boot = len(boot_M)
    w0_arr = np.array(w0, float)
    for b, Mb in enumerate(boot_M):
        Mb_stack = np.stack(Mb, axis=0)
        for rho in rhos:
            p, v = solve_drml_tv(Mb, w0, rho)
            boot_p[rho].append(p)
            boot_v[rho].append(v)  
            wr = per_group_winrate_vectorized(p, Mb_stack)
            for k in range(K): boot_wr[(rho,k)].append(float(wr[k]))
            M_pooled = np.einsum('k,kij->ij', w0_arr, Mb_stack)
            overall_margin = float(np.min(p @ M_pooled))
            overall_wr = 0.5 * (1 + overall_margin)
            boot_wr_overall[rho].append(float(overall_wr))
        if (b+1) % max(1, n_boot//10) == 0: print(f"compute {b+1}/{n_boot}")
    rows = []
    for rho in rhos:
        for k, lang in enumerate(groups):
            arr = np.array(boot_wr[(rho,k)], float)
            se = float(arr.std() / np.sqrt(len(arr)))
            rows.append({"rho": rho, "group": lang, "boot_mean": float(arr.mean()), "boot_se": se})
        arr_overall = np.array(boot_wr_overall[rho], float)
        se_overall = float(arr_overall.std() / np.sqrt(len(arr_overall)))
        rows.append({"rho": rho, "group": "overall", "boot_mean": float(arr_overall.mean()), "boot_se": se_overall})
    ci = pd.DataFrame(rows).sort_values(["group","rho"])
    return ci, boot_p, boot_v

In [34]:
df = load_df(group_by=GROUP_BY)
rhos = np.linspace(0.0, 1.0, 11)
print("Generating bootstrap margin matrices...")
groups, w0, M0, boot_M = bootstrap_margins(df, MODELS, n_boot=N_BOOT, seed=1, alpha=ALPHA, group_by=GROUP_BY)
print("Computing results from bootstrap margin matrices...")
ci, boot_p, boot_v = compute_from_bootstrap(groups, w0, boot_M, rhos)

Generating bootstrap margin matrices...
['White', 'Black', 'Asian', 'Mixed']
groups: ['White', 'Black', 'Asian', 'Mixed']
counts: {'White': 14266, 'Black': 4134, 'Asian': 3359, 'Mixed': 1392}
w0: {'White': 0.6162152822772234, 'Black': 0.17856680057016974, 'Asian': 0.14509092479806487, 'Mixed': 0.06012699235454192}
boot 20/200
boot 40/200
boot 60/200
boot 80/200
boot 100/200
boot 120/200
boot 140/200
boot 160/200
boot 180/200
boot 200/200
Computing results from bootstrap margin matrices...
compute 20/200
compute 40/200
compute 60/200
compute 80/200
compute 100/200
compute 120/200
compute 140/200
compute 160/200
compute 180/200
compute 200/200


In [35]:
def plot_group_curves(ci, groups, out_pdf=None):
    if out_pdf is None:
        out_pdf = f"group_winrate_vs_rho_{GROUP_BY}_ALPHA={ALPHA}.pdf"
    sns.set_style("whitegrid")
    sns.set_palette("husl")
    
    plt.rcParams.update({
        "font.size": 11,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
    })
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    default_colors = ['#E74C3C', '#3498DB', '#F39C12', '#27AE60', '#9B59B6', '#1ABC9C', '#E67E22']
    colors = {'overall': '#2C3E50'}
    for i, grp in enumerate(groups):
        if grp not in colors:
            colors[grp] = default_colors[i % len(default_colors)]
    
    desired_order = ['overall'] + list(groups)
    
    for group in desired_order:
        if group == 'overall':
            d = ci[ci["group"] == "overall"].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["boot_mean"].to_numpy(float) * 100.0
                se = d["boot_se"].to_numpy(float) * 100.0
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5, 
                       label="Overall", linestyle="--", color=colors['overall'], alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors['overall'])
        elif group in groups:
            d = ci[ci["group"] == group].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["boot_mean"].to_numpy(float) * 100.0
                se = d["boot_se"].to_numpy(float) * 100.0
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5, 
                       label=group, color=colors.get(group, '#95A5A6'), alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors.get(group, '#95A5A6'))
    
    ax.set_xlabel(r"$\rho$", fontsize=12, labelpad=6)
    ax.set_ylabel("Win Rate (%)", fontsize=12, labelpad=6)
    
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.4)
    ax.set_axisbelow(True)
    
    ax.legend(frameon=True, loc='best', framealpha=0.95, facecolor='white', edgecolor='lightgray', borderpad=0.5, labelspacing=0.5, ncol=1, columnspacing=0.8)
    
    fig.tight_layout()
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print("saved:", out_pdf)

In [36]:
plot_group_curves(ci, groups)

saved: group_winrate_vs_rho_ethnic_group_ALPHA=1.0.pdf


In [37]:
def bootstrap_train_test_splits(df, groups, models, train_frac=0.8, n_boot=100, seed=0, group_col="group"):
    df_filtered = df[df[group_col].isin(groups)].copy()
    df_filtered = df_filtered[df_filtered["model_a"].isin(models) & df_filtered["model_b"].isin(models)].copy()
    
    splits = []
    for b in range(n_boot):
        rng = np.random.default_rng(seed + 100000*b)
        n_total = len(df_filtered)
        boot_idx = rng.integers(0, n_total, size=n_total)
        df_boot = df_filtered.iloc[boot_idx].reset_index(drop=True)
        
        df_train_list = []
        df_test_list = []
        for grp in groups:
            grp_df = df_boot[df_boot[group_col] == grp]
            n = len(grp_df)
            if n == 0:
                continue
            n_train = int(train_frac * n)
            idx = np.arange(n)
            rng.shuffle(idx)
            train_idx = idx[:n_train]
            test_idx = idx[n_train:]
            df_train_list.append(grp_df.iloc[train_idx])
            df_test_list.append(grp_df.iloc[test_idx])
        if len(df_train_list) == 0:
            continue
        df_train = pd.concat(df_train_list, ignore_index=True)
        df_test = pd.concat(df_test_list, ignore_index=True)
        splits.append((df_train, df_test))
        if (b+1) % max(1, n_boot//10) == 0: print(f"boot {b+1}/{n_boot}")
    return splits

def compute_generalization(splits, groups, models, rhos, alpha=1.0, group_col="group"):
    K = len(groups)
    boot_wr_train = {(rho, grp): [] for rho in rhos for grp in groups + ["overall"]}
    boot_wr_test = {(rho, grp): [] for rho in rhos for grp in groups + ["overall"]}
    
    for df_train, df_test in splits:
        M_train, w0_train, _, _ = build_margins(df_train, groups, models, alpha=alpha, group_col=group_col)
        M_test, w0_test, _, _ = build_margins(df_test, groups, models, alpha=alpha, group_col=group_col)
        
        M_train_stack = np.stack(M_train, axis=0)
        w0_train_arr = np.array(w0_train, float)
        M_test_stack = np.stack(M_test, axis=0)
        w0_test_arr = np.array(w0_test, float)
        
        for rho in rhos:
            p_rho, _ = solve_drml_tv(M_train, w0_train, rho)
            
            wr_train = per_group_winrate_vectorized(p_rho, M_train_stack)
            for k, grp in enumerate(groups): 
                boot_wr_train[(rho, grp)].append(float(wr_train[k]))
            
            M_pooled_train = np.einsum('k,kij->ij', w0_train_arr, M_train_stack)
            v_train_overall = float(np.min(p_rho @ M_pooled_train))
            wr_train_overall = 0.5 * (1 + v_train_overall)
            boot_wr_train[(rho, "overall")].append(float(wr_train_overall))
            
            wr_test = per_group_winrate_vectorized(p_rho, M_test_stack)
            for k, grp in enumerate(groups): 
                boot_wr_test[(rho, grp)].append(float(wr_test[k]))
            
            M_pooled_test = np.einsum('k,kij->ij', w0_test_arr, M_test_stack)
            v_test_overall = float(np.min(p_rho @ M_pooled_test))
            wr_test_overall = 0.5 * (1 + v_test_overall)
            boot_wr_test[(rho, "overall")].append(float(wr_test_overall))
    
    gap_rows = []
    for rho in rhos:
        for grp in groups + ["overall"]:
            arr_train = np.array(boot_wr_train[(rho, grp)], float)
            arr_test = np.array(boot_wr_test[(rho, grp)], float)
            if len(arr_train) > 0 and len(arr_test) > 0:
                train_mean = float(arr_train.mean())
                test_mean = float(arr_test.mean())
                gap = train_mean - test_mean
                se_train = float(arr_train.std() / np.sqrt(len(arr_train)))
                se_test = float(arr_test.std() / np.sqrt(len(arr_test)))
                se_gap = float(np.sqrt(arr_train.var() + arr_test.var()) / np.sqrt(len(arr_train)))
                gap_rows.append({
                    "rho": rho,
                    "group": grp,
                    "wr_train_mean": train_mean,
                    "wr_test_mean": test_mean,
                    "wr_gap_mean": gap,
                    "wr_train_se": se_train,
                    "wr_test_se": se_test,
                    "wr_gap_se": se_gap
                })
    gap_df = pd.DataFrame(gap_rows)
    return gap_df, boot_wr_train, boot_wr_test

In [38]:
print("Generating bootstrap train/test splits...")
splits = bootstrap_train_test_splits(df, groups, MODELS, train_frac=0.8, n_boot=N_BOOT, seed=42, group_col="group")
print("Computing generalization gaps from splits...")
gap_df, _, _ = compute_generalization(splits, groups, MODELS, rhos, alpha=ALPHA, group_col="group")

Generating bootstrap train/test splits...
boot 20/200
boot 40/200
boot 60/200
boot 80/200
boot 100/200
boot 120/200
boot 140/200
boot 160/200
boot 180/200
boot 200/200
Computing generalization gaps from splits...


In [39]:
def plot_winrate_heldout(gap_df, groups, out_pdf=None):
    if out_pdf is None:
        out_pdf = f"winrate_heldout_{GROUP_BY}_ALPHA={ALPHA}.pdf"
    lang_names = {'en': 'English', 'pl': 'Polish', 'ru': 'Russian', 'zh': 'Chinese', 'overall': 'Overall'}
    
    sns.set_style("whitegrid")
    sns.set_palette("husl")
    
    plt.rcParams.update({
        "font.size": 11,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
    })
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    default_colors = ['#E74C3C', '#3498DB', '#F39C12', '#27AE60', '#9B59B6', '#1ABC9C', '#E67E22']
    colors = {'overall': '#2C3E50'}
    for i, grp in enumerate(groups):
        if grp not in colors:
            colors[grp] = default_colors[i % len(default_colors)]
    
    desired_order = ['overall'] + list(groups)
    
    for group in desired_order:
        if group == 'overall':
            d = gap_df[gap_df["group"] == "overall"].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["wr_test_mean"].to_numpy(float) * 100.0
                se = d["wr_test_se"].to_numpy(float) * 100.0
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5,
                       label="Overall", linestyle="--", color=colors['overall'], alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors['overall'])
        elif group in groups:
            d = gap_df[gap_df["group"] == group].sort_values("rho")
            if len(d) > 0:
                x = d["rho"].to_numpy(float)
                y = d["wr_test_mean"].to_numpy(float) * 100.0
                se = d["wr_test_se"].to_numpy(float) * 100.0
                group_label = lang_names.get(group, group)
                ax.plot(x, y, marker="o", linewidth=1.2, markersize=5,
                       label=group_label, color=colors.get(group, '#95A5A6'), alpha=0.9)
                ax.fill_between(x, y - se, y + se, alpha=0.15, color=colors.get(group, '#95A5A6'))
    
    ax.set_xlabel(r"$\rho$", fontsize=12, labelpad=6)
    ax.set_ylabel("", fontsize=12, labelpad=6)
    
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.4)
    ax.set_axisbelow(True)
    if GROUP_BY == "political_affiliation":
        ax.legend(frameon=True, loc=(0.7, 0.43), framealpha=0.95, facecolor='white', edgecolor='lightgray', borderpad=0.8, labelspacing=0.6, ncol=1, columnspacing=0.8)
    else:
        ax.legend(frameon=True, loc='best', framealpha=0.95, facecolor='white', edgecolor='lightgray', borderpad=0.8, labelspacing=0.6, ncol=1, columnspacing=0.8)
    fig.tight_layout()
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print("saved:", out_pdf)

def print_winrate_heldout(gap_df, groups):
    lang_names = {'en': 'English', 'pl': 'Polish', 'ru': 'Russian', 'zh': 'Chinese', 'overall': 'Overall'}
    gap_df = gap_df.copy()
    gap_df["formatted"] = gap_df.apply(
        lambda row: f"{row['wr_test_mean']:.4f} ± {row['wr_test_se']:.4f}", axis=1
    )
    table = gap_df.pivot(index="group", columns="rho", values="formatted")
    all_groups = groups + ["overall"]
    table = table.reindex(all_groups)
    table.index = [lang_names.get(grp, grp) for grp in table.index]
    display(table)
    return table

In [40]:
print("Winrate on held-out data:")
print_winrate_heldout(gap_df, groups)
plot_winrate_heldout(gap_df, groups)

Winrate on held-out data:


rho,0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0
White,0.4514 ± 0.0033,0.4544 ± 0.0025,0.4531 ± 0.0022,0.4496 ± 0.0021,0.4430 ± 0.0022,0.4364 ± 0.0022,0.4282 ± 0.0023,0.4195 ± 0.0023,0.4146 ± 0.0024,0.4134 ± 0.0024,0.4130 ± 0.0024
Black,0.2909 ± 0.0061,0.3280 ± 0.0053,0.3528 ± 0.0047,0.3638 ± 0.0044,0.3704 ± 0.0041,0.3789 ± 0.0039,0.3854 ± 0.0036,0.3896 ± 0.0034,0.3934 ± 0.0031,0.3951 ± 0.0031,0.3955 ± 0.0031
Asian,0.3636 ± 0.0056,0.3957 ± 0.0042,0.4099 ± 0.0034,0.4125 ± 0.0032,0.4096 ± 0.0030,0.4081 ± 0.0031,0.4039 ± 0.0030,0.4004 ± 0.0030,0.3969 ± 0.0030,0.3958 ± 0.0030,0.3956 ± 0.0030
Mixed,0.3097 ± 0.0055,0.3304 ± 0.0043,0.3451 ± 0.0038,0.3532 ± 0.0037,0.3600 ± 0.0036,0.3669 ± 0.0034,0.3730 ± 0.0034,0.3780 ± 0.0032,0.3810 ± 0.0032,0.3828 ± 0.0031,0.3830 ± 0.0031
Overall,0.4669 ± 0.0025,0.4705 ± 0.0018,0.4698 ± 0.0014,0.4664 ± 0.0014,0.4610 ± 0.0015,0.4553 ± 0.0016,0.4492 ± 0.0016,0.4432 ± 0.0018,0.4391 ± 0.0018,0.4378 ± 0.0018,0.4373 ± 0.0018


saved: winrate_heldout_ethnic_group_ALPHA=1.0.pdf


In [41]:
def plot_generalization(gap_df, out_pdf=f"generalization_GROUP={GROUP_BY}_ALPHA={ALPHA}.pdf"):
    sns.set_style("whitegrid")
    
    plt.rcParams.update({
        "font.size": 11,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
    })
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    # Plot overall generalization gap
    d = gap_df[gap_df["group"] == "overall"].sort_values("rho")
    if len(d) > 0:
        x = d["rho"].to_numpy(float)
        y = d["wr_gap_mean"].to_numpy(float) * 100.0
        se = d["wr_gap_se"].to_numpy(float) * 100.0
        ax.plot(x, y, marker="o", linewidth=1.2, markersize=5,
               label="Overall", linestyle="--", color='#2C3E50', alpha=0.9)
        ax.fill_between(x, y - se, y + se, alpha=0.15, color='#2C3E50')
    
    ax.set_xlabel(r"$\rho$", fontsize=12, labelpad=6)
    ax.set_ylabel("Win rate generalization gap (%)", fontsize=12, labelpad=6)
    
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.4)
    ax.set_axisbelow(True)
    
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8, alpha=0.5)
    
    fig.tight_layout()
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print("saved:", out_pdf)

plot_generalization(gap_df)

saved: generalization_GROUP=ethnic_group_ALPHA=1.0.pdf


In [42]:
def print_boot_p(boot_p, models, rhos, thresh=0.02):
    rows = []
    for rho in rhos:
        P = np.vstack(boot_p[rho])
        mean = P.mean(axis=0)
        se = P.std(axis=0) / np.sqrt(P.shape[0])
        for i, name in enumerate(models):
            rows.append({"rho": rho, "model": name, "mean": mean[i], "se": se[i]})
    wdf = pd.DataFrame(rows)
    keep = wdf.groupby("model")["mean"].max()
    keep = keep[keep > thresh].index
    wdf = wdf[wdf["model"].isin(keep)].copy()
    wdf["mean%"] = (100*wdf["mean"]).round(2)
    wdf["se%"] = (100*wdf["se"]).round(2)
    
    wdf["formatted"] = wdf.apply(lambda row: f"{row['mean%']:.2f} ± {row['se%']:.2f}", axis=1)
    
    table = wdf.pivot(index="model", columns="rho", values="formatted")
    
    rho0_means = wdf[wdf["rho"] == rhos[0]].set_index("model")["mean"].sort_values(ascending=False)
    table = table.reindex(rho0_means.index)
    
    print(f"(shown if mean> {100*thresh:.1f}% for some rho)")
    display(table)
    return wdf

wdf = print_boot_p(boot_p, MODELS, rhos, thresh=0.02)


(shown if mean> 2.0% for some rho)


rho,0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
google/gemini-2.5-pro,71.78 ± 2.25,63.55 ± 1.37,58.03 ± 0.90,52.60 ± 0.79,47.43 ± 0.77,42.07 ± 0.79,38.05 ± 0.79,34.10 ± 0.78,31.93 ± 0.77,31.05 ± 0.73,30.76 ± 0.72
x-ai/grok-4,16.41 ± 1.53,12.89 ± 1.03,10.17 ± 0.67,10.33 ± 0.66,10.33 ± 0.64,11.11 ± 0.68,10.77 ± 0.66,10.48 ± 0.67,10.59 ± 0.68,10.62 ± 0.69,10.79 ± 0.70
x-ai/grok-3,11.33 ± 1.54,22.74 ± 1.05,29.80 ± 0.73,31.05 ± 0.68,31.61 ± 0.69,31.92 ± 0.73,31.48 ± 0.77,31.43 ± 0.76,30.79 ± 0.76,30.31 ± 0.71,30.07 ± 0.69
anthropic/claude-opus-4,0.38 ± 0.17,0.71 ± 0.21,0.85 ± 0.26,2.78 ± 0.47,5.52 ± 0.72,7.71 ± 0.80,10.03 ± 0.85,12.17 ± 0.86,13.67 ± 0.86,14.66 ± 0.87,14.90 ± 0.87
google/gemini-2.5-flash,0.08 ± 0.05,0.07 ± 0.06,0.91 ± 0.21,2.17 ± 0.31,2.51 ± 0.37,2.22 ± 0.34,2.08 ± 0.30,2.05 ± 0.31,2.04 ± 0.32,2.15 ± 0.33,2.24 ± 0.33
deepseek/deepseek-r1-0528,0.03 ± 0.03,0.05 ± 0.03,0.23 ± 0.08,0.95 ± 0.18,2.14 ± 0.25,3.88 ± 0.35,5.96 ± 0.46,7.01 ± 0.51,7.74 ± 0.56,7.54 ± 0.57,7.52 ± 0.58
anthropic/claude-sonnet-4,0.00 ± 0.00,0.00 ± 0.00,0.01 ± 0.01,0.12 ± 0.05,0.40 ± 0.12,0.84 ± 0.19,1.25 ± 0.25,1.93 ± 0.35,2.07 ± 0.35,2.34 ± 0.38,2.38 ± 0.38


In [43]:
def plot_model_performance_grid(wdf, model_shortcuts=None, model_colors=None, out_pdf=None):
    if out_pdf is None:
        out_pdf = f"model_performance_grid_{GROUP_BY}_ALPHA={ALPHA}.pdf"
    
    sns.set_style("white")
    
    plt.rcParams.update({
        "font.size": 10,
        "axes.labelsize": 11,
        "xtick.labelsize": 9,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.4,
        "xtick.major.width": 0.6,
        "ytick.major.width": 0,
    })
    
    default_shortcuts = {
        'x-ai/grok-3': 'Grok 3',
        'google/gemini-2.5-pro': 'Gemini 2.5 Pro',
        'x-ai/grok-4': 'Grok 4',
        'google/gemini-2.5-flash': 'Gemini 2.5 Flash',
        'deepseek/deepseek-r1-0528': 'DeepSeek R1',
        'anthropic/claude-opus-4': 'Claude Opus 4',
        'anthropic/claude-sonnet-4': 'Claude Sonnet 4',
    }
    
    default_colors = {
        'x-ai/grok-3': '#E74C3C',
        'google/gemini-2.5-pro': '#3498DB',
        'x-ai/grok-4': '#F39C12',
        'google/gemini-2.5-flash': '#27AE60',
        'deepseek/deepseek-r1-0528': '#9B59B6',
        'anthropic/claude-opus-4': '#1ABC9C',
        'anthropic/claude-sonnet-4': '#E67E22',
    }
    
    model_names = model_shortcuts if model_shortcuts is not None else default_shortcuts
    colors = model_colors if model_colors is not None else default_colors
    
    unique_rhos = sorted(wdf["rho"].unique())
    rhos_to_plot = [0.0, 0.4, 0.7, 1.0]
    
    fig, axes = plt.subplots(1, 4, figsize=(18, 5))
    
    models_ordered = wdf[np.isclose(wdf["rho"], 0.0)].sort_values("mean", ascending=True)["model"].values
    
    for idx, rho in enumerate(rhos_to_plot):
        ax = axes[idx]
        
        data = wdf[np.isclose(wdf["rho"], rho)].copy()
        data = data.set_index("model").reindex(models_ordered).reset_index()
        data = data.dropna()
        
        bar_spacing = 0.65
        y_pos = np.arange(len(data)) * bar_spacing
        bar_colors = [colors.get(m, '#95A5A6') for m in data["model"]]
        
        ax.barh(y_pos, data["mean%"], xerr=data["se%"], 
               color=bar_colors,
               error_kw={'elinewidth': 1.2, 'capsize': 2.5, 'alpha': 0.7},
               height=0.5 * bar_spacing, alpha=0.85, edgecolor='white', linewidth=0.5)
        
        ax.set_xlim(0, 75)
        ax.set_xticks([0, 20, 40, 60, 80])
        
        ax.set_axisbelow(True)
        ax.grid(True, axis='x', alpha=0.25, linestyle='-', linewidth=0.5, color='gray')
        ax.grid(True, axis='y', alpha=0.15, linestyle='-', linewidth=0.3, color='gray')
        

        ax.set_yticks(y_pos)
        if idx == 0:
            clean_names = [model_names.get(m, m) for m in data["model"]]
            ax.set_yticklabels(clean_names, fontsize=12)
        else:
            ax.set_yticklabels([]) 
            
        ax.set_xlabel("Probability (%)", fontsize=14, labelpad=4)
        ax.set_xticklabels(['0', '20', '40', '60', '80'], fontsize=14)
        
        rho_label = f'ρ = {rho:.1f}'
        ax.text(0.95, 0.1, rho_label, transform=ax.transAxes,
               fontsize=16, fontweight='bold', ha='right', va='bottom',
               bbox=dict(boxstyle='round,pad=0.4', facecolor='white', 
                        edgecolor='lightgray', alpha=0.8, linewidth=0.5))
        
        ax.spines['left'].set_linewidth(0.8)
        ax.spines['bottom'].set_linewidth(0.8)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    
    fig.tight_layout(pad=1.2)
    fig.savefig(out_pdf, bbox_inches="tight", dpi=300)
    plt.close(fig)
    print(f"saved: {out_pdf}")

In [44]:
plot_model_performance_grid(wdf)

saved: model_performance_grid_ethnic_group_ALPHA=1.0.pdf


In [45]:
def build_margins(df, groups, models, alpha=0.0, group_col="group"):
    m = len(models); K = len(groups)
    idx = {name:i for i,name in enumerate(models)}
    df = df[df[group_col].isin(groups)].copy()
    df = df[df["model_a"].isin(models) & df["model_b"].isin(models)].copy()
    counts = df[group_col].value_counts().reindex(groups).fillna(0).astype(int)
    tot = int(counts.sum())
    if tot == 0: raise ValueError("No rows after filtering; check groups/models.")
    w0 = (counts/tot).to_numpy(float)
    
    df = df.copy()
    df["idx_a"] = df["model_a"].map(idx)
    df["idx_b"] = df["model_b"].map(idx)
    df = df.dropna(subset=["idx_a", "idx_b"]).copy()
    
    M_list = []
    T_list = [] 
    
    for grp in groups:
        rows = df[df[group_col]==grp]
        if len(rows) == 0:
            M_list.append(np.zeros((m,m), float))
            T_list.append(np.zeros((m,m), float))  
            continue
        
        win = np.zeros((m,m), float)
        idx_a = rows["idx_a"].values.astype(int)
        idx_b = rows["idx_b"].values.astype(int)
        choice = rows["choice"].values
        
        mask_a = (choice == "A")
        np.add.at(win, (idx_a[mask_a], idx_b[mask_a]), 1.0)
        mask_b = (choice == "B")
        np.add.at(win, (idx_b[mask_b], idx_a[mask_b]), 1.0)

        T = np.zeros((m, m), float)
        for i in range(m):
            for j in range(i+1, m):
                tot_ij = win[i,j] + win[j,i]
                T[i, j] = tot_ij
                T[j, i] = tot_ij
        T_list.append(T)
        
        M = np.zeros((m,m), float)
        for i in range(m):
            for j in range(i+1, m):
                tot_ij = win[i,j] + win[j,i]
                if tot_ij > 0:
                    mij = (win[i,j] - win[j,i]) / (tot_ij + 2.0*alpha)
                else:
                    mij = 0.0
                M[i,j] = mij; M[j,i] = -mij
        M_list.append(M)
    
    return M_list, T_list, w0, counts, df 

def top_k_opposite_sign_details(
    M_list: List[np.ndarray],
    T_list: List[np.ndarray],
    models: List[str],
    groups: List[str],
    k: int = 5,
    model_shortcuts: Dict[str, str] = None
):

    if model_shortcuts is None:
        model_shortcuts = {m: m for m in models}

    n = len(models)
    
    if len(M_list) != len(groups) or len(T_list) != len(groups):
        raise ValueError("M_list and T_list must have same length as groups")

    M_stack = np.stack(M_list, axis=0)
    T_stack = np.stack(T_list, axis=0)  

    Ts_total = T_stack.sum(axis=0) 
    pos_weights = np.where(M_stack > 0, T_stack, 0).sum(axis=0)
    neg_weights = np.where(M_stack < 0, T_stack, 0).sum(axis=0)
    
    with np.errstate(divide='ignore', invalid='ignore'):
        p_pos = pos_weights / Ts_total
        p_neg = neg_weights / Ts_total
        prob_weighted = 2.0 * p_pos * p_neg
    prob_weighted = np.nan_to_num(prob_weighted, nan=0.0)

    mask_upper = np.triu(np.ones((n, n), dtype=bool))
    candidate_mask = (~mask_upper) & (Ts_total > 0)  
    flat_idx = np.where(candidate_mask.ravel())[0]
    flat_probs = prob_weighted.ravel()[flat_idx]
    
    if flat_probs.size == 0:
        print("No valid lower-triangle pairs with comparisons found.")
        return

    top_k_idx = flat_idx[np.argsort(flat_probs)[::-1][:k]]
    rows, cols = np.unravel_index(top_k_idx, (n, n))

    print(f"\n{'='*80}")
    print(f"TOP {k} MODEL PAIRS WITH MOST INCONSISTENT RANKINGS ACROSS GROUPS")
    print(f"{'='*80}\n")

    for rank, (i, j) in enumerate(zip(rows, cols), start=1):
        prob_val = prob_weighted[i, j]
        total_counts_ij = Ts_total[i, j]
        
        pos_grp_idx = np.where(M_stack[:, i, j] > 0)[0]
        neg_grp_idx = np.where(M_stack[:, i, j] < 0)[0]
        
        pos_groups_list = [groups[g] for g in pos_grp_idx]
        neg_groups_list = [groups[g] for g in neg_grp_idx]
        
        pos_counts_by_grp = {groups[g]: int(T_stack[g, i, j]) for g in pos_grp_idx}
        neg_counts_by_grp = {groups[g]: int(T_stack[g, i, j]) for g in neg_grp_idx}

        model_i_name = model_shortcuts.get(models[i], models[i])
        model_j_name = model_shortcuts.get(models[j], models[j])

        print(f"Rank {rank}: {model_i_name} vs {model_j_name}")
        print(f"  Inconsistency score (weighted): {prob_val:.3f}")
        print(f"  Total comparisons across groups: {int(total_counts_ij)}")
        print(f"\n  Groups where '{model_i_name}' beats '{model_j_name}' ({len(pos_groups_list)}):")
        if pos_groups_list:
            for grp in pos_groups_list:
                print(f"    - {grp}: {pos_counts_by_grp[grp]} comparisons")
        else:
            print("    (none)")
        
        print(f"\n  Groups where '{model_j_name}' beats '{model_i_name}' ({len(neg_groups_list)}):")
        if neg_groups_list:
            for grp in neg_groups_list:
                print(f"    - {grp}: {neg_counts_by_grp[grp]} comparisons")
        else:
            print("    (none)")
        print("\n" + "-" * 80 + "\n")


df = load_df(group_by=GROUP_BY)
M_list, T_list, w0, counts, df_filtered = build_margins(df, groups, MODELS, alpha=1.0, group_col="group")

model_shortcuts = {
    'anthropic/claude-3.7-sonnet': 'Claude 3.7 Sonnet',
    'anthropic/claude-opus-4': 'Claude Opus 4',
    'anthropic/claude-sonnet-4': 'Claude Sonnet 4',
    'cohere/command-a': 'Command A',
    'cohere/command-r7b-12-2024': 'Command R7B',
    'deepseek/deepseek-r1-0528': 'DeepSeek R1',
    'google/gemini-2.0-flash-001': 'Gemini 2.0 Flash',
    'google/gemini-2.5-flash': 'Gemini 2.5 Flash',
    'google/gemini-2.5-pro': 'Gemini 2.5 Pro',
    'mistralai/mistral-nemo': 'Mistral Nemo',
    'openai/o1': 'o1',
    'openai/o3-mini': 'o3 mini',
    'openai/o4-mini': 'o4 mini',
    'x-ai/grok-3': 'Grok 3',
    'x-ai/grok-4': 'Grok 4',
}

top_k_opposite_sign_details(
    M_list=M_list,
    T_list=T_list,
    models=MODELS,
    groups=groups,
    k=10,
    model_shortcuts=model_shortcuts  # optional
)


TOP 10 MODEL PAIRS WITH MOST INCONSISTENT RANKINGS ACROSS GROUPS

Rank 1: Grok 3 vs Gemini 2.5 Flash
  Inconsistency score (weighted): 0.475
  Total comparisons across groups: 314

  Groups where 'Grok 3' beats 'Gemini 2.5 Flash' (2):
    - Black: 43 comparisons
    - Asian: 79 comparisons

  Groups where 'Gemini 2.5 Flash' beats 'Grok 3' (2):
    - White: 179 comparisons
    - Mixed: 13 comparisons

--------------------------------------------------------------------------------

Rank 2: o1 vs Command A
  Inconsistency score (weighted): 0.469
  Total comparisons across groups: 194

  Groups where 'o1' beats 'Command A' (2):
    - White: 120 comparisons
    - Mixed: 1 comparisons

  Groups where 'Command A' beats 'o1' (2):
    - Black: 30 comparisons
    - Asian: 43 comparisons

--------------------------------------------------------------------------------

Rank 3: Gemini 2.0 Flash vs Claude Sonnet 4
  Inconsistency score (weighted): 0.455
  Total comparisons across groups: 186

  G

In [46]:
from typing import Dict, List, Tuple
def plot_heatmap_inconsistency(
    M_list: List[np.ndarray],
    T_list: List[np.ndarray],
    models: List[str],
    groups: List[str],
    model_shortcuts: Dict[str, str] = None,
    out_pdf: str = None,
    figsize: Tuple[int, int] = (12, 10),
    annot_fontsize: int = 10,
    cmap: str = "Blues",
):

    if out_pdf is None:
        out_pdf = f"heatmap_inconsistency_{GROUP_BY}_ALPHA={ALPHA}.pdf"
    
    if model_shortcuts is None:
        model_shortcuts = {m: m for m in models}
    
    n = len(models)
    
    M_stack = np.stack(M_list, axis=0)   
    T_stack = np.stack(T_list, axis=0)   

    Ts_total = T_stack.sum(axis=0)
    pos_weights = np.where(M_stack > 0, T_stack, 0).sum(axis=0)
    neg_weights = np.where(M_stack < 0, T_stack, 0).sum(axis=0)
    
    with np.errstate(divide='ignore', invalid='ignore'):
        p_pos = pos_weights / Ts_total
        p_neg = neg_weights / Ts_total
        prob_weighted = 2.0 * p_pos * p_neg
    prob_weighted = np.nan_to_num(prob_weighted, nan=0.0)

    W_overall = np.zeros((n, n), float)
    T_overall = Ts_total.copy()
    for k in range(len(M_list)):
        W_overall += M_list[k] * T_list[k]
    
    winrates = []
    for i in range(n):
        total_wins = W_overall[i, :].sum()
        total_comps = T_overall[i, :].sum()
        winrates.append(total_wins / total_comps if total_comps > 0 else 0.0)
    
    sorted_indices = sorted(range(n), key=lambda i: (winrates[i], i), reverse=True)
    sorted_models = [models[i] for i in sorted_indices]

    prob_sorted = prob_weighted[np.ix_(sorted_indices, sorted_indices)]
    T_sorted = T_overall[np.ix_(sorted_indices, sorted_indices)]
    
    mask = np.triu(np.ones((n, n), dtype=bool))
    
    annot = np.full(prob_sorted.shape, "", dtype=object)
    eps = 1e-12
    for i in range(n):
        for j in range(n):
            if not mask[i, j] and T_sorted[i, j] > 0 and prob_sorted[i, j] > eps:
                annot[i, j] = f"{prob_sorted[i, j]:.2f}"
    

    display_names = [model_shortcuts.get(model, model) for model in sorted_models]

    plt.rcParams.update({
        "font.size": 10,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "sans-serif",
    })
    
    fig, ax = plt.subplots(figsize=figsize)
    
    global_vmax = prob_sorted.max() if prob_sorted.max() > 0 else 1.0
    
    sns.heatmap(
        prob_sorted,
        mask=mask,
        xticklabels=display_names,
        yticklabels=display_names,
        cmap=cmap,
        vmin=0.0,
        vmax=global_vmax,
        square=True,
        cbar = False,
        annot=annot,
        fmt='',
        annot_kws={'fontsize': annot_fontsize},
        ax=ax
    )
    
    ax.set_xlabel('Loser', fontsize=14)
    ax.set_ylabel('Winner', fontsize=14)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    
    plt.tight_layout()
    plt.savefig(out_pdf, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"saved: {out_pdf}")


# Usage:
plot_heatmap_inconsistency(
    M_list=M_list,
    T_list=T_list,
    models=MODELS,
    groups=groups,
    model_shortcuts=model_shortcuts,
    out_pdf=f"heatmap_inconsistency_{GROUP_BY}_ALPHA={ALPHA}.pdf"
)

saved: heatmap_inconsistency_ethnic_group_ALPHA=1.0.pdf
