In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from gensim.models import KeyedVectors
import gensim.downloader as api
from transformers import AutoTokenizer, AutoModel

In [None]:

# Directory that receives every CSV / PNG / JSON artifact written by the functions below.
OUT_DIR = "results_rq1"
os.makedirs(OUT_DIR, exist_ok=True)

# Seed the Python, NumPy and PyTorch global RNGs. SEED is also passed explicitly
# to the permutation / bootstrap helpers, which build their own NumPy generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Names handed to gensim's downloader (`api.load`) by load_static_embeddings.
STATIC_EMBEDDING_NAMES = [
    "word2vec-google-news-300",   
    "glove-wiki-gigaword-300"     
]


# Model identifier passed to AutoTokenizer / AutoModel.from_pretrained.
BERT_MODEL = "bert-base-uncased"
# None means "use every entry of SEAT_TEMPLATES"; it is resolved to that length
# right after SEAT_TEMPLATES is defined.
MAX_TEMPLATES = None           
# Number of label permutations used by the WEAT and SEAT permutation tests.
N_PERMUTATIONS = 10000          
# Family-wise error level used for the Holm-Bonferroni correction.
HOLM_ALPHA = 0.05
# Bootstrap resamples for the WEAT-vs-SEAT correlation confidence intervals.
BOOTSTRAP_N = 5000              
# Bootstrap resamples for the pooled SEAT effect-size confidence interval.
BOOTSTRAP_SEAT_N = 2000         
# When True, the driver cell also runs the per-template sensitivity analysis.
RUN_TEMPLATE_SENSITIVITY = True 

# Run BERT on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Word lists referenced by WEAT_TESTS below.
male_names = ["John", "Paul", "Mike", "Kevin", "Steve", "Greg", "Jeff", "Bill", "Frank", "George"]
female_names = ["Amy", "Joan", "Lisa", "Sarah", "Diana", "Kate", "Ann", "Donna", "Susan", "Carol"]

male_terms = ["male", "man", "boy", "brother", "he", "him", "his", "son"]
female_terms = ["female", "woman", "girl", "sister", "she", "her", "hers", "daughter"]

career_words = ["executive", "management", "professional", "corporation", "salary", "office", "business", "career"]
family_words = ["home", "parents", "children", "family", "cousins", "marriage", "wedding", "relatives"]

science_words = ["science", "technology", "physics", "chemistry", "einstein", "nasa", "experiment", "astronomy"]
arts_words    = ["poetry", "art", "shakespeare", "dance", "literature", "novel", "symphony", "drama"]
math_words    = ["math", "algebra", "geometry", "calculus", "equations", "computation", "numbers", "theorem"]

# Test definitions keyed by name. In each test, "X" and "Y" are the two word sets
# whose association scores are contrasted, and "A" and "B" are the two word sets
# every X/Y word is scored against (see weat_effect_size).
WEAT_TESTS = {
    "WEAT6_gender_names__career_family": {
        "X": male_names,
        "Y": female_names,
        "A": career_words,
        "B": family_words
    },
    "WEAT7_science_arts__male_female_terms": {
        "X": science_words,
        "Y": arts_words,
        "A": male_terms,
        "B": female_terms
    },
    "WEAT8_math_arts__male_female_terms": {
        "X": math_words,
        "Y": arts_words,
        "A": male_terms,
        "B": female_terms
    },
}

In [None]:

# Sentence templates with a single `{}` slot; bert_encode_templates fills the slot
# with each word via `template.format(word)`.
# NOTE(review): every word of a test (X, Y, A and B) is inserted into every
# template, so person-oriented templates such as "{} is a person." are also
# applied to non-name words like "science" — confirm this is intended.
SEAT_TEMPLATES = [
    "This is {}.",
    "{} is a person.",
    "They talked about {}.",
    "I heard about {} yesterday.",
    "People often mention {}.",
    "The discussion was about {}.",
    "One individual is named {}.",
    "Here we consider {}.",
    "Someone referred to {}.",
    "Everyone knows {}."
]
# Resolve the "use all templates" sentinel to a concrete count. This runs before
# run_seat_on_bert is defined, so that function's default picks up the resolved value.
if MAX_TEMPLATES is None:
    MAX_TEMPLATES = len(SEAT_TEMPLATES)

def cosine_sim(u, v):
    """Return the cosine similarity between vectors ``u`` and ``v``.

    A constant of 1e-12 is added to the product of the two norms, so an
    all-zero vector yields 0 instead of a division by zero.
    """
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return np.dot(u, v) / (norm_product + 1e-12)

def weat_effect_size(X, Y, A, B, vecs):
    """Compute the WEAT effect size for word sets X, Y against word sets A, B.

    Each word's association score is its mean cosine similarity to the words
    of ``A`` minus its mean cosine similarity to the words of ``B``. The effect
    size is the difference between the mean X score and the mean Y score,
    divided by the sample standard deviation (ddof=1) of all X and Y scores.

    Args:
        X, Y: iterables of words to score.
        A, B: iterables of words each X/Y word is compared with.
        vecs: mapping from word to its vector; must contain every word used.

    Returns:
        ``(d, x_scores, y_scores)`` where ``d`` is a float (0.0 when the pooled
        standard deviation is exactly zero) and the scores are NumPy arrays in
        the order of ``X`` and ``Y``.
    """
    def association(word):
        sims_to_a = [cosine_sim(vecs[word], vecs[a]) for a in A]
        sims_to_b = [cosine_sim(vecs[word], vecs[b]) for b in B]
        return np.mean(sims_to_a) - np.mean(sims_to_b)

    x_scores = np.array(list(map(association, X)))
    y_scores = np.array(list(map(association, Y)))

    spread = np.std(np.concatenate([x_scores, y_scores]), ddof=1)
    if spread == 0:
        effect = 0.0
    else:
        effect = (x_scores.mean() - y_scores.mean()) / spread
    return float(effect), x_scores, y_scores

def permutation_test_weat(X, Y, A, B, vecs, n_perm=10000, seed=42):
    """Two-sided permutation test for the WEAT effect size.

    The per-word association scores of X and Y (from ``weat_effect_size``) are
    pooled, then repeatedly shuffled and re-split into groups of the original
    sizes. The p-value is the share of shuffles whose absolute effect size is
    at least the observed one, with the usual +1 correction so it is never 0.

    Args:
        X, Y, A, B: word sets, forwarded to ``weat_effect_size``.
        vecs: mapping from word to vector, forwarded to ``weat_effect_size``.
        n_perm: number of random shuffles.
        seed: seed for the local NumPy generator.

    Returns:
        ``(d_obs, p)`` as Python floats.
    """
    rng = np.random.default_rng(seed)
    d_obs, Xs, Ys = weat_effect_size(X, Y, A, B, vecs)
    concat = np.concatenate([Xs, Ys])
    nX = len(Xs)

    # The pooled standard deviation depends only on the multiset of scores, not
    # on their order, so it is identical for every permutation. Computing it
    # once avoids n_perm redundant np.std calls and keeps summation-order
    # rounding noise in the denominator from deciding ties against d_obs.
    pooled_std = np.std(concat, ddof=1)
    threshold = abs(d_obs)

    more_extreme = 0
    for _ in range(n_perm):
        rng.shuffle(concat)  # in place; the slices below are views of it
        Xp = concat[:nX]
        Yp = concat[nX:]
        d_perm = 0.0 if pooled_std == 0 else (Xp.mean() - Yp.mean()) / pooled_std
        if abs(d_perm) >= threshold:
            more_extreme += 1
    p = (more_extreme + 1) / (n_perm + 1)
    return float(d_obs), float(p)

def holm_bonferroni(pvals_dict, alpha=0.05):
    """Apply the Holm-Bonferroni step-down correction to a set of p-values.

    P-values are visited from smallest to largest. The i-th smallest (1-based)
    is compared with ``alpha / (m - i + 1)``. Holm's procedure is step-down:
    once one hypothesis is not rejected, no later (larger) p-value may be
    rejected either. Adjusted p-values are ``p * (m - i + 1)`` capped at 1 and
    forced to be non-decreasing along the sorted order, so tied raw p-values
    receive the same adjusted value.

    Args:
        pvals_dict: mapping from any hashable name to a raw p-value.
        alpha: family-wise error rate.

    Returns:
        Dict with the same keys, in the same order as ``pvals_dict``; each
        value is ``{"p_raw": ..., "p_holm": ..., "reject": bool}``.
    """
    ranked = sorted(pvals_dict.items(), key=lambda kv: kv[1])
    m = len(ranked)
    out = {}
    running_adjusted = 0.0   # cumulative max keeps adjusted p-values monotone
    still_rejecting = True   # flips to False at the first non-rejection and stays there
    for rank, (name, p) in enumerate(ranked, 1):
        multiplier = m - rank + 1
        running_adjusted = max(running_adjusted, min(1.0, p * multiplier))
        still_rejecting = bool(still_rejecting and p <= alpha / multiplier)
        out[name] = {"p_raw": p, "p_holm": running_adjusted, "reject": still_rejecting}
    return {k: out[k] for k in pvals_dict.keys()}

def bootstrap_corr(x, y, n_boot=5000, seed=42):
    """Percentile-bootstrap 95% intervals for Pearson and Spearman correlation.

    Pairs ``(x[i], y[i])`` are resampled with replacement ``n_boot`` times.
    A resample in which either variable is constant has no defined correlation
    (SciPy returns NaN and emits a warning), and a single NaN would turn the
    whole percentile into NaN. Such degenerate resamples are therefore skipped,
    so the interval is computed over the valid resamples only. With very few
    points a large share of resamples is degenerate and the interval should be
    read with caution.

    Args:
        x, y: equal-length 1-D sequences of numbers.
        n_boot: number of bootstrap resamples to draw.
        seed: seed for the local NumPy generator.

    Returns:
        ``(pearson_ci, spearman_ci)``, each a two-element list
        ``[2.5th percentile, 97.5th percentile]``. Both are ``[nan, nan]`` when
        fewer than two points are given or no resample was usable.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    undefined = ([float("nan"), float("nan")], [float("nan"), float("nan")])
    if len(x) < 2:
        return undefined

    rng = np.random.default_rng(seed)
    idx = np.arange(len(x))
    pears, spears = [], []
    for _ in range(n_boot):
        bs = rng.choice(idx, size=len(idx), replace=True)
        xb, yb = x[bs], y[bs]
        # Correlation is undefined when either resampled variable is constant.
        if xb.min() == xb.max() or yb.min() == yb.max():
            continue
        pr, _ = pearsonr(xb, yb)
        sr, _ = spearmanr(xb, yb)
        pears.append(pr)
        spears.append(sr)

    if not pears:
        return undefined
    # nanpercentile is a second line of defence should a NaN slip through.
    return (
        np.nanpercentile(pears, [2.5, 97.5]).tolist(),
        np.nanpercentile(spears, [2.5, 97.5]).tolist()
    )

def bootstrap_seat_ci(X_scores, Y_scores, n_boot=2000, seed=42):
    """Percentile-bootstrap 95% interval for the pooled SEAT effect size.

    All per-layer X scores and all per-layer Y scores are concatenated into one
    pool that remembers which group each score came from. Indices of that pool
    are resampled with replacement and the effect size (difference of group
    means over the ddof=1 standard deviation of the resample) is recomputed.

    A resample that happens to contain no X or no Y score has no defined
    effect size; it is skipped rather than contributing a NaN that would turn
    the whole interval into NaN.

    NOTE(review): resampling the pooled scores treats the several per-layer
    scores of one word as independent observations — confirm this is intended.

    Args:
        X_scores, Y_scores: sequences of 1-D arrays (one array per layer).
        n_boot: number of bootstrap resamples to draw.
        seed: seed for the local NumPy generator.

    Returns:
        Two-element list ``[2.5th percentile, 97.5th percentile]``, or
        ``[nan, nan]`` when no resample contained both groups.
    """
    rng = np.random.default_rng(seed)
    X_all = np.concatenate(X_scores)
    Y_all = np.concatenate(Y_scores)
    all_scores = np.concatenate([X_all, Y_all])
    idx = np.arange(len(all_scores))
    from_x = idx < len(X_all)  # True for positions that hold an X score

    ds = []
    for _ in range(n_boot):
        bs_idx = rng.choice(idx, size=len(idx), replace=True)
        sample = all_scores[bs_idx]
        x_mask = from_x[bs_idx]
        Xb = sample[x_mask]
        Yb = sample[~x_mask]
        if Xb.size == 0 or Yb.size == 0:
            continue  # effect size undefined without both groups
        pooled_std = np.std(np.concatenate([Xb, Yb]), ddof=1)
        ds.append(0.0 if pooled_std == 0 else (Xb.mean() - Yb.mean()) / pooled_std)

    if not ds:
        return [float("nan"), float("nan")]
    return np.percentile(ds, [2.5, 97.5]).tolist()


In [None]:

def load_static_embeddings(names):
    """Load each named embedding through gensim's downloader API.

    A failure to load one name is reported on stdout and that name is left out
    of the result; the remaining names are still attempted.

    Args:
        names: iterable of identifiers accepted by ``api.load``.

    Returns:
        Dict mapping every successfully loaded name to the object returned by
        ``api.load``.
    """
    loaded = {}
    for model_name in names:
        try:
            print(f"↓ downloading/loading {model_name} ...")
            vectors = api.load(model_name)
        except Exception as err:
            print(f"⚠️ Could not load {model_name}: {err}")
        else:
            loaded[model_name] = vectors
    return loaded

def run_weat_on_static_embeddings(weat_tests, embeddings, n_perm=N_PERMUTATIONS, out_dir=OUT_DIR):
    """Run every WEAT test on every static embedding and Holm-correct the p-values.

    For each embedding, words missing from its vocabulary are removed from the
    test's word sets (and reported on stdout). A test is skipped for that
    embedding when any of its four sets ends up empty. The Holm correction is
    applied once across all (embedding, test) pairs that were run.

    Writes ``weat_results_raw.csv`` and ``weat_results_holm.csv`` to ``out_dir``.

    Args:
        weat_tests: mapping ``test name -> {"X": [...], "Y": [...], "A": [...], "B": [...]}``.
        embeddings: mapping ``embedding name -> word-vector store`` supporting
            ``word in store`` and ``store[word]``.
        n_perm: number of permutations for each permutation test.
        out_dir: directory for the CSV files.

    Returns:
        DataFrame with one row per (embedding, test) including raw and
        Holm-adjusted p-values, or ``None`` when nothing could be run.
    """
    if len(embeddings) == 0:
        print("⚠️ No static embeddings. Skipping WEAT.")
        return None

    all_rows, all_pvals = [], {}

    for emb_name, kv in embeddings.items():
        print(f"\n[WEAT] Using {emb_name}")

        def get_vec(w):
            # NOTE(review): the lower-cased form is tried before the exact form,
            # so in a case-sensitive vocabulary "John" resolves to the vector of
            # "john" when both exist — confirm this is the intended preference.
            lw = w.lower()
            if lw in kv:
                return kv[lw]
            if w in kv:
                return kv[w]
            return None

        for test_name, test in weat_tests.items():
            kept = {}
            for role in ("X", "Y", "A", "B"):
                kept[role] = [w for w in test[role] if get_vec(w) is not None]
                missing = [w for w in test[role] if get_vec(w) is None]
                if missing:
                    # Surface out-of-vocabulary words instead of dropping them silently.
                    print(f"  [{test_name}] {len(missing)} {role} word(s) not in {emb_name}: {missing}")
            X, Y, A, B = kept["X"], kept["Y"], kept["A"], kept["B"]

            if not (len(X) and len(Y) and len(A) and len(B)):
                print(f"Skipping {test_name} for {emb_name}: missing words.")
                continue

            words = set(X + Y + A + B)
            vecs = {w: get_vec(w) for w in words}

            d, p = permutation_test_weat(X, Y, A, B, vecs, n_perm=n_perm, seed=SEED)
            all_rows.append({
                "embedding": emb_name,
                "test": test_name,
                "weat_d": d,
                "p_value": p,
                "n_perm": n_perm,
                "n_X": len(X), "n_Y": len(Y), "n_A": len(A), "n_B": len(B)
            })
            all_pvals[(emb_name, test_name)] = p

    if not all_rows:
        return None

    df = pd.DataFrame(all_rows)
    df.to_csv(os.path.join(out_dir, "weat_results_raw.csv"), index=False)

    # The (embedding, test) tuples are used directly as Holm keys. Joining them
    # into "emb::test" strings and splitting again would break for any name
    # that itself contains "::".
    holm = holm_bonferroni(all_pvals, alpha=HOLM_ALPHA)

    holm_rows = [
        {
            "embedding": emb,
            "test": tname,
            "p_raw": v["p_raw"],
            "p_holm": v["p_holm"],
            f"reject_at_alpha={HOLM_ALPHA}": v["reject"]
        }
        for (emb, tname), v in holm.items()
    ]
    df_holm = pd.DataFrame(holm_rows)
    df_merged = df.merge(df_holm, on=["embedding", "test"])
    df_merged.to_csv(os.path.join(out_dir, "weat_results_holm.csv"), index=False)

    print("\n[WEAT] Summary:")
    print(df_merged)
    return df_merged

def bert_encode_templates(words, model, tokenizer, templates, max_templates, device=DEVICE):
    """Encode each word in sentence templates and average per-layer vectors.

    Every word is inserted into the first ``max_templates`` templates. For each
    resulting sentence the model is run with ``output_hidden_states=True``; the
    first entry of ``hidden_states`` is discarded and, from every remaining
    entry, the vector at sequence position 0 is kept (presumably the [CLS]
    token for BERT-style tokenizers — confirm for other models). The per-layer
    vectors are then averaged over templates.

    Args:
        words: iterable of strings to encode.
        model: model callable returning an object with ``hidden_states``.
        tokenizer: callable producing model inputs for one sentence.
        templates: list of format strings with one ``{}`` slot.
        max_templates: how many leading templates to use.
        device: device the tokenized inputs are moved to.

    Returns:
        Dict mapping each word to a CPU tensor whose first axis indexes layers.
    """
    model.eval()
    active_templates = templates[:max_templates]
    encoded = {}
    with torch.no_grad():
        for word in words:
            per_template = []
            for tpl in active_templates:
                batch = tokenizer(tpl.format(word), return_tensors="pt").to(device)
                result = model(**batch, output_hidden_states=True, return_dict=True)
                # Skip hidden_states[0]; take sequence position 0 of each later layer.
                first_token_per_layer = [layer[:, 0, :] for layer in result.hidden_states[1:]]
                # Stack over layers, then drop the size-1 batch axis.
                per_template.append(torch.stack(first_token_per_layer, dim=0).squeeze(1))
            encoded[word] = torch.stack(per_template, dim=0).mean(dim=0).cpu()
    return encoded

def seat_layerwise_effect_size(X, Y, A, B, reps_dict):
    """Compute an effect size per layer from layer-wise word representations.

    For each layer, the vectors of the ``A`` words and of the ``B`` words are
    averaged into two centroids. A word's score is its cosine similarity to
    the A centroid minus its cosine similarity to the B centroid. The layer's
    effect size is the difference of the mean X and mean Y scores divided by
    the ddof=1 standard deviation of all X and Y scores (0.0 if that is zero).

    NOTE(review): scoring against the centroid differs from
    ``weat_effect_size``, which averages per-word cosine similarities — confirm
    the two are meant to use different association measures.

    Args:
        X, Y, A, B: sequences of words, all present in ``reps_dict``.
        reps_dict: mapping word -> tensor indexed by layer on its first axis.

    Returns:
        ``(d_per_layer, X_scores, Y_scores)``: a NumPy array with one effect
        size per layer, and two lists holding one score array per layer.
    """
    n_layers = list(reps_dict.values())[0].shape[0]
    cos = torch.nn.functional.cosine_similarity

    effect_sizes = []
    x_by_layer, y_by_layer = [], []
    for layer in range(n_layers):
        a_centroid = torch.stack([reps_dict[a][layer] for a in A], dim=0).mean(dim=0)
        b_centroid = torch.stack([reps_dict[b][layer] for b in B], dim=0).mean(dim=0)

        def association(word):
            vec = reps_dict[word][layer]
            to_a = cos(vec, a_centroid, dim=0).item()
            to_b = cos(vec, b_centroid, dim=0).item()
            return to_a - to_b

        x_vals = np.array([association(w) for w in X])
        y_vals = np.array([association(w) for w in Y])
        spread = np.std(np.concatenate([x_vals, y_vals]), ddof=1)
        if spread == 0:
            effect_sizes.append(0.0)
        else:
            effect_sizes.append((x_vals.mean() - y_vals.mean()) / spread)
        x_by_layer.append(x_vals)
        y_by_layer.append(y_vals)

    return np.array(effect_sizes), x_by_layer, y_by_layer

def permutation_test_seat_layerwise(X_scores, Y_scores, n_perm=10000, seed=42):
    """Two-sided permutation test on SEAT scores pooled across layers.

    All per-layer X scores and all per-layer Y scores are concatenated. The
    observed effect size is the difference of the two pooled means over the
    ddof=1 standard deviation of the whole pool. The pool is then shuffled
    ``n_perm`` times and re-split at the original X length; the p-value is the
    share of shuffles whose absolute effect size is at least the observed one,
    with a +1 correction in numerator and denominator.

    NOTE(review): pooling treats the several per-layer scores of one word as
    exchangeable observations — confirm this is intended.

    Args:
        X_scores, Y_scores: sequences of 1-D arrays (one array per layer).
        n_perm: number of random shuffles.
        seed: seed for the local NumPy generator.

    Returns:
        ``(d_obs, p)`` as Python floats.
    """
    rng = np.random.default_rng(seed)
    X_all = np.concatenate(X_scores)
    Y_all = np.concatenate(Y_scores)
    nX = len(X_all)
    concat = np.concatenate([X_all, Y_all])

    # The pooled standard deviation does not depend on the order of the pool,
    # so it is computed once and shared by the observed and every permuted
    # statistic instead of being recomputed n_perm times inside the loop.
    pooled_std = np.std(concat, ddof=1)
    d_obs = (X_all.mean() - Y_all.mean()) / pooled_std if pooled_std > 0 else 0.0
    threshold = abs(d_obs)

    more_extreme = 0
    for _ in range(n_perm):
        rng.shuffle(concat)  # in place; the slices below are views of it
        Xp = concat[:nX]
        Yp = concat[nX:]
        d_perm = (Xp.mean() - Yp.mean()) / pooled_std if pooled_std > 0 else 0.0
        if abs(d_perm) >= threshold:
            more_extreme += 1
    p = (more_extreme + 1) / (n_perm + 1)
    return float(d_obs), float(p)

def run_seat_on_bert(weat_tests, model_name=BERT_MODEL, templates=SEAT_TEMPLATES,
                     max_templates=MAX_TEMPLATES, n_perm=N_PERMUTATIONS, out_dir=OUT_DIR):
    """Run the layer-wise SEAT analysis for every test on one transformer model.

    For each test the words are encoded with ``bert_encode_templates``, a
    per-layer effect size is computed, and a permutation p-value plus a
    bootstrap interval are derived from the scores pooled over all layers.
    P-values are Holm-corrected across the tests of this call.

    Files written to ``out_dir``: ``seat_results_raw.csv`` (one row per test
    and layer), ``seat_results_holm.csv``, ``seat_results_summary.csv`` and one
    ``seat_<test>_layerwise.png`` plot per test.

    Args:
        weat_tests: mapping ``test name -> {"X": [...], "Y": [...], "A": [...], "B": [...]}``.
        model_name: identifier passed to ``AutoTokenizer`` / ``AutoModel``.
        templates: sentence templates with one ``{}`` slot.
        max_templates: number of leading templates to use.
        n_perm: number of permutations for the permutation test.
        out_dir: directory for CSV and PNG outputs.

    Returns:
        ``(raw_df, summary_df)``, or ``(None, None)`` when no rows were produced.
    """
    print(f"\n[SEAT] Loading {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(DEVICE)

    rows, pvals = [], {}

    for test_name, test in weat_tests.items():
        X, Y, A, B = test["X"], test["Y"], test["A"], test["B"]
        # Encode each distinct word once, even if it appears in several sets.
        words = list(set(X + Y + A + B))
        reps = bert_encode_templates(words, model, tokenizer, templates, max_templates, device=DEVICE)

        d_layers, X_scores, Y_scores = seat_layerwise_effect_size(X, Y, A, B, reps)
        d_mean = float(np.mean(d_layers))
        # Effect size of the layer with the largest magnitude, keeping its sign.
        d_max  = float(np.max(np.abs(d_layers)) * np.sign(d_layers[np.argmax(np.abs(d_layers))]))

        # Both the p-value and the CI are computed on scores pooled over all layers.
        d_pooled, p = permutation_test_seat_layerwise(X_scores, Y_scores, n_perm=n_perm, seed=SEED)
        d_ci = bootstrap_seat_ci(X_scores, Y_scores, n_boot=BOOTSTRAP_SEAT_N, seed=SEED)

        # One row per layer; the test-level statistics are repeated on every row.
        for l, d in enumerate(d_layers):
            rows.append({
                "model": model_name,
                "test": test_name,
                "layer": l + 1,
                "seat_d": float(d),
                "seat_d_mean": d_mean,
                "seat_d_max": d_max,
                "seat_d_pooled": d_pooled,
                "seat_d_pooled_ci_low": d_ci[0],
                "seat_d_pooled_ci_high": d_ci[1],
                "p_value": p,
                "n_perm": n_perm
            })
        pvals[(model_name, test_name)] = p

        # Layer-wise effect-size plot, saved to disk and closed to free the figure.
        plt.figure(figsize=(7, 4))
        plt.plot(range(1, len(d_layers)+1), d_layers, marker='o')
        plt.axhline(0, color='black', linewidth=1)
        plt.title(f"SEAT layer-wise effect size ({test_name})")
        plt.xlabel("Layer")
        plt.ylabel("Cohen's d")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f"seat_{test_name}_layerwise.png"), dpi=200)
        plt.close()

    if not rows:
        return None, None

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(out_dir, "seat_results_raw.csv"), index=False)

    # Holm keys are "model::test" strings; they are split back apart below.
    # NOTE(review): this assumes neither name contains "::" — otherwise the
    # two-way unpacking of split("::") raises.
    pvals_dict = {f"{k[0]}::{k[1]}": v for k, v in pvals.items()}
    holm = holm_bonferroni(pvals_dict, alpha=HOLM_ALPHA)

    holm_rows = []
    for k, v in holm.items():
        mname, tname = k.split("::")
        holm_rows.append({
            "model": mname,
            "test": tname,
            "p_raw": v["p_raw"],
            "p_holm": v["p_holm"],
            f"reject_at_alpha={HOLM_ALPHA}": v["reject"]
        })
    df_holm = pd.DataFrame(holm_rows)
    df_holm.to_csv(os.path.join(out_dir, "seat_results_holm.csv"), index=False)

    # Collapse the per-layer rows to one row per (model, test); the aggregated
    # columns hold the same value on every layer row, so "first" is sufficient.
    seat_summary = (
        df.groupby(["model", "test"])
          .agg(seat_d_mean=("seat_d_mean", "first"),
               seat_d_pooled=("seat_d_pooled", "first"),
               seat_d_pooled_ci_low=("seat_d_pooled_ci_low", "first"),
               seat_d_pooled_ci_high=("seat_d_pooled_ci_high", "first"),
               p_value=("p_value", "first"))
          .reset_index()
    )
    seat_summary = seat_summary.merge(df_holm, on=["model", "test"])
    seat_summary.to_csv(os.path.join(out_dir, "seat_results_summary.csv"), index=False)

    print("\n[SEAT] Summary:")
    print(seat_summary)
    return df, seat_summary

def run_seat_template_sensitivity(weat_tests, model_name=BERT_MODEL,
                                  templates=SEAT_TEMPLATES, out_dir=OUT_DIR):
    print("\n[SEAT] Template sensitivity analysis")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(DEVICE)

    rows = []
    for test_name, test in weat_tests.items():
        X, Y, A, B = test["X"], test["Y"], test["A"], test["B"]
        words = list(set(X + Y + A + B))
        for t in templates:
            reps = bert_encode_templates(words, model, tokenizer, [t], 1, device=DEVICE)
            d_layers, _, _ = seat_layerwise_effect_size(X, Y, A, B, reps)
            rows.append({
                "model": model_name,
                "test": test_name,
                "template": t,
                "seat_d_mean": float(np.mean(d_layers)),
                "seat_d_max": float(np.max(np.abs(d_layers)) * np.sign(d_layers[np.argmax(np.abs(d_layers))]))
            })

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(out_dir, "seat_template_sensitivity.csv"), index=False)

    for test_name in df["test"].unique():
        sub = df[df["test"] == test_name]
        plt.figure(figsize=(8, 4))
        plt.bar(range(len(sub)), sub["seat_d_mean"])
        plt.axhline(0, color='black', linewidth=1)
        plt.xticks(range(len(sub)), [f"T{i+1}" for i in range(len(sub))], rotation=45)
        plt.ylabel("Mean SEAT d (across layers)")
        plt.title(f"Template sensitivity: {test_name}")
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f"seat_template_sensitivity_{test_name}.png"), dpi=200)
        plt.close()

    print("Saved template sensitivity CSV/plots.")
    return df


In [None]:


def correlate_weat_seat(weat_df, seat_summary_df, model_name=BERT_MODEL,
                        out_dir=OUT_DIR, n_boot=BOOTSTRAP_N):
    """Correlate static-embedding WEAT effect sizes with SEAT effect sizes.

    WEAT effect sizes are averaged over embeddings per test and paired, by test
    name, with the ``seat_d_mean`` of ``model_name``. Pearson and Spearman
    correlations are computed over the tests present in both inputs, with
    bootstrap intervals from ``bootstrap_corr``.

    The result is written to ``weat_seat_correlation.json`` in ``out_dir`` and
    printed. Non-finite numbers are stored as ``null`` so the file is valid
    JSON (the ``json`` module would otherwise emit bare ``NaN`` tokens).

    Note: the number of paired points equals the number of tests, so with only
    a few tests the correlation and its interval carry little evidence.

    Args:
        weat_df: DataFrame with ``test`` and ``weat_d`` columns, or ``None``.
        seat_summary_df: DataFrame with ``model``, ``test`` and ``seat_d_mean``
            columns, or ``None``.
        model_name: which model's SEAT rows to use.
        out_dir: directory for the JSON file.
        n_boot: number of bootstrap resamples.

    Returns:
        Dict of results (non-finite values kept as ``nan``), or ``None`` when an
        input is missing or fewer than two tests align.
    """
    if weat_df is None or seat_summary_df is None:
        print("⚠️ Missing WEAT or SEAT results; skipping correlation.")
        return None

    d_weat = (weat_df.groupby("test")["weat_d"].mean()).rename("weat_d_mean")
    seat_use = (seat_summary_df[seat_summary_df["model"] == model_name]
                .set_index("test")["seat_d_mean"])

    tests = sorted(set(d_weat.index) & set(seat_use.index))
    if len(tests) < 2:
        print("Not enough aligned tests to correlate.")
        return None

    x = d_weat.loc[tests].values
    y = seat_use.loc[tests].values

    pear_r, pear_p = pearsonr(x, y)
    spear_r, spear_p = spearmanr(x, y)

    try:
        pear_ci, spear_ci = bootstrap_corr(x, y, n_boot=n_boot, seed=SEED)
    except Exception as e:
        # Still best-effort, but say why the intervals are missing.
        print(f"⚠️ Bootstrap CI failed ({type(e).__name__}: {e}); reporting NaN intervals.")
        pear_ci, spear_ci = [np.nan, np.nan], [np.nan, np.nan]

    res = {
        "tests_used": tests,
        "pearson_r": float(pear_r), "pearson_p": float(pear_p),
        "pearson_95ci": pear_ci,
        "spearman_r": float(spear_r), "spearman_p": float(spear_p),
        "spearman_95ci": spear_ci,
        "n_tests": len(tests)
    }

    def _json_safe(value):
        # Recursively replace NaN / inf with None so the output is strict JSON.
        if isinstance(value, (list, tuple)):
            return [_json_safe(v) for v in value]
        if isinstance(value, float) and not np.isfinite(value):
            return None
        return value

    res_json = {k: _json_safe(v) for k, v in res.items()}

    with open(os.path.join(out_dir, "weat_seat_correlation.json"), "w") as f:
        # allow_nan=False makes any non-finite value that slipped through fail loudly.
        json.dump(res_json, f, indent=2, allow_nan=False)

    print("\n[WEAT vs SEAT] Correlation summary:")
    print(json.dumps(res_json, indent=2, allow_nan=False))
    return res


In [None]:


# Driver: runs the whole pipeline in order. Each later step consumes results of
# the earlier ones, so the order of the calls matters.
if __name__ == "__main__":
    print("DEVICE:", DEVICE)


    # 1) Load the static embeddings; names that fail to load are left out.
    STATIC_EMBEDS = load_static_embeddings(STATIC_EMBEDDING_NAMES)


    # 2) WEAT on every loaded static embedding (None if nothing could be run).
    weat_df = run_weat_on_static_embeddings(
        WEAT_TESTS,
        STATIC_EMBEDS,
        n_perm=N_PERMUTATIONS,
        out_dir=OUT_DIR
    )


    # 3) Layer-wise SEAT on the transformer model.
    seat_raw_df, seat_summary_df = run_seat_on_bert(
        WEAT_TESTS,
        model_name=BERT_MODEL,
        templates=SEAT_TEMPLATES,
        max_templates=MAX_TEMPLATES,
        n_perm=N_PERMUTATIONS,
        out_dir=OUT_DIR
    )


    # 4) Optional: repeat SEAT one template at a time.
    if RUN_TEMPLATE_SENSITIVITY:
        run_seat_template_sensitivity(
            WEAT_TESTS,
            model_name=BERT_MODEL,
            templates=SEAT_TEMPLATES,
            out_dir=OUT_DIR
        )


    # 5) Correlate WEAT and SEAT effect sizes across tests; skipped inside the
    #    function when either input is None.
    correlate_weat_seat(weat_df, seat_summary_df, model_name=BERT_MODEL, out_dir=OUT_DIR)

    print("\n✅ Done. Artifacts saved under:", OUT_DIR)


DEVICE: cpu
↓ downloading/loading word2vec-google-news-300 ...
↓ downloading/loading glove-wiki-gigaword-300 ...

[WEAT] Using word2vec-google-news-300

[WEAT] Using glove-wiki-gigaword-300

[WEAT] Summary:
                  embedding                                   test    weat_d  \
0  word2vec-google-news-300      WEAT6_gender_names__career_family  1.634555   
1  word2vec-google-news-300  WEAT7_science_arts__male_female_terms  1.084427   
2  word2vec-google-news-300     WEAT8_math_arts__male_female_terms  1.112460   
3   glove-wiki-gigaword-300      WEAT6_gender_names__career_family  1.654423   
4   glove-wiki-gigaword-300  WEAT7_science_arts__male_female_terms  1.368223   
5   glove-wiki-gigaword-300     WEAT8_math_arts__male_female_terms  0.989790   

    p_value  n_perm  n_X  n_Y  n_A  n_B     p_raw    p_holm  \
0  0.000200   10000   10   10    8    8  0.000200  0.001000   
1  0.022798   10000    8    8    8    8  0.022798  0.045595   
2  0.018998   10000    8    8    8    8  0.

  pr, _ = pearsonr(x[bs], y[bs])
  sr, _ = spearmanr(x[bs], y[bs])



[WEAT vs SEAT] Correlation summary:
{
  "tests_used": [
    "WEAT6_gender_names__career_family",
    "WEAT7_science_arts__male_female_terms",
    "WEAT8_math_arts__male_female_terms"
  ],
  "pearson_r": 0.8577121023235507,
  "pearson_p": 0.3437704411122802,
  "pearson_95ci": [
    NaN,
    NaN
  ],
  "spearman_r": 0.5,
  "spearman_p": 0.6666666666666667,
  "spearman_95ci": [
    NaN,
    NaN
  ],
  "n_tests": 3
}

✅ Done. Artifacts saved under: results_rq1
