In [None]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import Levenshtein
import pandas as pd
import plotly.express as px
import torch
from scipy.stats import kendalltau, pearsonr, spearmanr
from tqdm.auto import tqdm

tqdm.pandas()

from transformers import AutoModelForCausalLM, AutoTokenizer

from pws_in_context.constants import PROJECT_ROOT

# Number of highest-scoring next-token candidates requested from the LLM per context
# (passed as `k` to top_k_predictions_for_context).
top_k = 20
# Directory under PROJECT_ROOT that holds the input data and the CSVs derived from it.
DATA_PATH = PROJECT_ROOT / "data"

In [None]:
# =========================================
# Retrieve data and transform dataframes
# =========================================
association_data = "association_data"

data_after_path = DATA_PATH / "after_pws_in_context_data" / association_data
data_before_path = DATA_PATH / "before_pws_in_context_data" / association_data


def _load_association_frame(folder: Path) -> pd.DataFrame:
    """Read every file in ``folder`` into one frame and add a ``context`` column.

    Args:
        folder: Directory whose files are each parsed with ``pd.read_csv``.

    Returns:
        The row-wise concatenation of all files with the ``trial`` and
        ``PROLIFIC_PID`` columns dropped, a fresh 0..n-1 index, and a
        ``context`` column holding the part of ``sentence`` that precedes the
        first occurrence of ``" <target>"``.

    Raises:
        ValueError: If ``folder`` contains no files (nothing to concatenate).
    """
    # Path.iterdir() yields entries in arbitrary order; sorting keeps the row
    # order (and therefore the CSVs saved later) reproducible. Sub-directories
    # (e.g. notebook checkpoint folders) are skipped because read_csv cannot
    # parse them.
    data_files = sorted(entry for entry in folder.iterdir() if entry.is_file())

    # ignore_index avoids duplicated index labels coming from the per-file frames.
    frame = pd.concat([pd.read_csv(f) for f in data_files], axis=0, ignore_index=True).drop(
        columns=["trial", "PROLIFIC_PID"]
    )

    # If " <target>" does not occur in the sentence, split() returns the whole
    # sentence, so the context is then the full sentence.
    frame["context"] = frame.apply(lambda row: row.sentence.split(f" {row.target}")[0], axis=1)
    return frame


df_after = _load_association_frame(data_after_path)
df_before = _load_association_frame(data_before_path)

In [None]:
# ================================
# Custom supporting functions
# ================================


def top_k_predictions_for_context(
    context: str, target: str, k: int = 5, tokenizer=None, model=None
) -> float:
    assert tokenizer is not None and model is not None, "Provide tokenizer and model"

    device = next(model.parameters()).device

    inputs = tokenizer(context, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    last_logits = outputs.logits[:, -1, :]
    _, topk_indices = torch.topk(last_logits, k=k, dim=-1)
    topk_ids = topk_indices[0].tolist()

    lev_dist_list = []

    for tok_id in topk_ids:
        tok_str = tokenizer.decode([tok_id], skip_special_tokens=True)
        tok_str_clean = tok_str.strip()
        if tok_str_clean != "" and tok_str_clean.isalpha():
            levenshtein = Levenshtein.distance(tok_str_clean, target) / max(
                len(tok_str_clean), len(target)
            )
            lev_dist_list.append(levenshtein)

    return sum(lev_dist_list) / len(lev_dist_list)


def normalized_levenshtein_dist(target: str, associations: list[str]) -> float | None:
    lev_dist_list = []

    for association in associations:
        if isinstance(association, str):
            levenshtein = Levenshtein.distance(association, target) / max(
                len(association), len(target)
            )
            lev_dist_list.append(levenshtein)

    if len(lev_dist_list) > 0:
        return sum(lev_dist_list) / len(lev_dist_list)
    else:
        return None


def calculate_corr(df: pd.DataFrame, col1: str, col2: str) -> dict:
    """Compute Pearson, Spearman and Kendall correlations between two columns.

    Rows in which either column is missing are dropped before the statistics
    are computed.

    Args:
        df: Frame containing both columns.
        col1: Name of the first column.
        col2: Name of the second column.

    Returns:
        A dict with keys ``"pearson_r"``, ``"spearman_r"`` and
        ``"kendall_tau"``, each mapping to a ``(statistic, p_value)`` tuple of
        Python floats.
    """
    paired = df[[col1, col2]].dropna()
    # Convert once instead of once per statistic.
    x_values = paired[col1].tolist()
    y_values = paired[col2].tolist()

    pearson_r, pearson_p = pearsonr(x_values, y_values)
    spearman_r, spearman_p = spearmanr(x_values, y_values)
    kendall_tau, kendall_p = kendalltau(x_values, y_values)

    # float() accepts both NumPy scalars and plain Python floats, whereas
    # .item() exists only on the former.
    return {
        "pearson_r": (float(pearson_r), float(pearson_p)),
        "spearman_r": (float(spearman_r), float(spearman_p)),
        "kendall_tau": (float(kendall_tau), float(kendall_p)),
    }


def plot_scatter_with_corrs(
    df: pd.DataFrame,
    col1: str,
    col2: str,
    title: str,
    cors: dict,
    range: bool = False,
    path: Path | None = None,
    filename: str | None = None,
) -> None:
    """Show a scatter plot of ``col1`` vs ``col2`` annotated with correlation statistics.

    Points are coloured by the ``class`` column of ``df``. The subtitle lists
    each coefficient and whether its p-value is below 0.05. The figure is
    always displayed; it is also written as ``<filename>.html`` and
    ``<filename>.png`` into ``path`` when both are given.

    Args:
        df: Frame containing ``col1``, ``col2`` and a ``class`` column.
        col1: Column plotted on the x axis.
        col2: Column plotted on the y axis.
        title: Figure title.
        cors: Mapping with keys ``"pearson_r"``, ``"spearman_r"`` and
            ``"kendall_tau"``, each a ``(statistic, p_value)`` pair.
        range: When true, both axes are fixed to ``[0, 1]``.
        path: Output directory; created if missing. ``None`` disables saving.
        filename: Output file stem. ``None`` disables saving.
    """
    p_r, p_a = cors["pearson_r"][0], cors["pearson_r"][1]
    s_r, s_a = cors["spearman_r"][0], cors["spearman_r"][1]
    k_tau, k_a = cors["kendall_tau"][0], cors["kendall_tau"][1]

    # The `range` parameter shadows the builtin; give it a clearer local name.
    unit_axes = [0, 1] if range else None

    fig = px.scatter(
        df,
        # Honour the requested columns instead of hard-coded ones.
        x=col1,
        y=col2,
        color="class",
        width=1000,
        height=600,
        title=title,
        subtitle=f"Pearson: {p_r:.2f}, p-value < 0.05: {p_a < 0.05} | "
        f"Spearman: {s_r:.2f}, p-value < 0.05: {s_a < 0.05} | "
        f"Kendall: {k_tau:.2f}, p-value < 0.05: {k_a < 0.05}\n",
        # Axis captions for the two distance columns; other column names keep
        # their raw name as caption.
        labels={
            "human_lev_dist": "Levenshtein distance (human associations - target)",
            "pred_lev_dist": "Levenshtein distance (LLM predictions - target)",
        },
        range_x=unit_axes,
        range_y=unit_axes,
    )

    fig.update_layout(
        title={
            "text": fig.layout.title.text,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
            "font": {"size": 24},
        },
        margin={"t": 150},
    )

    fig.show()

    # Saving is optional: with the default None arguments the original code
    # failed on `None / str` (or wrote a file literally named "None.html").
    if path is not None and filename is not None:
        out_dir = Path(path)
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.write_html(out_dir / f"{filename}.html")
        fig.write_image(out_dir / f"{filename}.png")

In [None]:
# ================================
# Prepare model and tokenizer
# ================================
# Checkpoint identifier handed to both from_pretrained calls below.
model_name = "meta-llama/Llama-3.2-1B"

# Tokenizer belonging to the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Causal LM loaded with float16 weights; device placement is delegated to the
# library through device_map="auto".
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float16, device_map="auto")

# Evaluation mode: the model is only used for forward passes under torch.no_grad().
model.eval()

In [None]:
# Calculate normalized levenshtein distance on LLM predictions and human data (association task after condition)
# Columns holding the participants' free associations (any column whose name
# contains "association").
association_cols_after = [col for col in df_after.columns if "association" in col]

# LLM side: mean normalized distance between the target and the model's top-k
# next-token predictions for the row's context.
df_after["pred_lev_dist"] = df_after.progress_apply(
    lambda record: top_k_predictions_for_context(
        record.context, record.target, k=top_k, tokenizer=tokenizer, model=model
    ),
    axis=1,
)

# Human side: mean normalized distance between the target and the row's associations.
df_after["human_lev_dist"] = df_after.progress_apply(
    lambda record: normalized_levenshtein_dist(record.target, record[association_cols_after]),
    axis=1,
)

In [None]:
# ===================================================================================================================
# Calculate normalized levenshtein distance on LLM predictions and human data (association task before condition)
# ===================================================================================================================
# Columns holding the participants' free associations (any column whose name
# contains "association").
association_cols_before = [col for col in df_before.columns if "association" in col]

# LLM side: mean normalized distance between the target and the model's top-k
# next-token predictions for the row's context.
df_before["pred_lev_dist"] = df_before.progress_apply(
    lambda record: top_k_predictions_for_context(
        record.context, record.target, k=top_k, tokenizer=tokenizer, model=model
    ),
    axis=1,
)

# Human side: mean normalized distance between the target and the row's associations.
df_before["human_lev_dist"] = df_before.progress_apply(
    lambda record: normalized_levenshtein_dist(record.target, record[association_cols_before]),
    axis=1,
)

In [None]:
# ===========================================
# Save dataframes for further analysis
# ===========================================
LEVENSHTEIN_DATA_PATH = DATA_PATH / "levenshtein_dist"
# Create the output folder together with any missing parents; nothing happens
# if it already exists.
LEVENSHTEIN_DATA_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Write both frames (including the two distance columns) to CSV; the value of
# top_k is part of the file name so runs with different k do not overwrite each other.
df_before.to_csv(LEVENSHTEIN_DATA_PATH / f"lev_before_k={top_k}.csv", index=False)

df_after.to_csv(LEVENSHTEIN_DATA_PATH / f"lev_after_k={top_k}.csv", index=False)

In [None]:
# Replace the in-memory frames with the versions stored on disk for the current top_k
# (presumably so the cells below can run from the saved CSVs without redoing the
# model inference — confirm intended workflow).
df_before = pd.read_csv(LEVENSHTEIN_DATA_PATH / f"lev_before_k={top_k}.csv")
df_after = pd.read_csv(LEVENSHTEIN_DATA_PATH / f"lev_after_k={top_k}.csv")

In [None]:
# ==============================
# Retrieve 'classes' of targets
# ==============================
# Table of target items: every column lists targets, and the column name is what
# class_category() inspects to decide the class of the items in it.
targets_df = pd.read_csv(DATA_PATH / "new_targets.csv")


def class_category(label):
    """Translate a column label into the display name of its word class.

    A label containing ``"low"`` maps to low-prevalence words, otherwise one
    containing ``"high"`` maps to high-prevalence words; every other label is
    treated as pseudowords. The match is case-sensitive and ``"low"`` takes
    precedence when both substrings occur.
    """
    keyword_to_class = (
        ("low", "Low Prevalence Words"),
        ("high", "High Prevalence Words"),
    )
    for keyword, class_name in keyword_to_class:
        if keyword in label:
            return class_name
    return "Pseudowords"


# Lookup from target word (surrounding asterisks removed) to the class of the
# column it is listed under in targets_df.
class_map = {
    word.strip("*"): class_category(column_name)
    for column_name in targets_df.columns
    for word in targets_df[column_name].dropna()
}

# Plain dict indexing on purpose: a target that is missing from the lookup
# raises KeyError instead of silently becoming NaN.
df_before["class"] = df_before["target"].map(lambda word: class_map[word])
df_after["class"] = df_after["target"].map(lambda word: class_map[word])

In [None]:
GRAPHS_PATH = PROJECT_ROOT / "graphs"
# Create the figure output folder together with any missing parents; nothing
# happens if it already exists.
GRAPHS_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# =========================
# Plot after condition
# =========================
# Scatter of human vs. LLM distances for the "after" data. The correlation
# statistics are computed on the same frame and shown in the subtitle; the
# figure is displayed and written to GRAPHS_PATH as HTML and PNG.
plot_scatter_with_corrs(
    df=df_after,
    col1="human_lev_dist",
    col2="pred_lev_dist",
    cors=calculate_corr(df_after, "human_lev_dist", "pred_lev_dist"),
    title="Levenshtein distance (Human vs LLM against target) (after condition)",
    range=False,  # let plotly choose the axis ranges instead of fixing them to [0, 1]
    path=GRAPHS_PATH,
    filename=f"lev_dist_after_k={top_k}",
)

In [None]:
# =========================
# Plot before condition
# =========================
# Scatter of human vs. LLM distances for the "before" data. The correlation
# statistics are computed on the same frame and shown in the subtitle; the
# figure is displayed and written to GRAPHS_PATH as HTML and PNG.
plot_scatter_with_corrs(
    df=df_before,
    col1="human_lev_dist",
    col2="pred_lev_dist",
    cors=calculate_corr(df_before, "human_lev_dist", "pred_lev_dist"),
    title="Levenshtein distance (Human vs LLM against target) (before condition)",
    range=False,  # let plotly choose the axis ranges instead of fixing them to [0, 1]
    path=GRAPHS_PATH,
    filename=f"lev_dist_before_k={top_k}",
)