In [None]:
%load_ext autoreload
%autoreload 2

import Levenshtein
import pandas as pd
import torch
from tqdm.auto import tqdm

# Register tqdm's pandas integration so that `DataFrame.progress_apply`
# (used by the scoring cells further down) is available and shows a progress bar.
tqdm.pandas()

from transformers import AutoModelForCausalLM, AutoTokenizer

from constants import PROJECT_ROOT

# Base directory for the datasets read by this notebook; the per-condition
# sub-directories are joined onto it in the loading cell.
DATA_PATH = PROJECT_ROOT / "data"

In [None]:
# Name of the sub-directory holding the association data for each condition.
association_data = "association_data"

data_after_path = DATA_PATH / "big_pilot_at_after_processed_data" / association_data
# NOTE(review): `data_before_path` is not read by any cell visible here — presumably
# used further down or kept for a planned "before" analysis; confirm.
data_before_path = DATA_PATH / "big_pilot_at_before_processed_data" / association_data

# Read only CSV files, in sorted order. `iterdir()` yields entries in an arbitrary,
# filesystem-dependent order and also returns stray non-CSV entries (e.g. `.DS_Store`),
# which makes the row order of `df_after` non-reproducible or breaks `read_csv`.
after_csv_files = sorted(data_after_path.glob("*.csv"))
if not after_csv_files:
    raise FileNotFoundError(f"No CSV files found in {data_after_path}")

# `ignore_index=True` gives the combined frame a unique 0..n-1 index instead of
# repeating each file's own row labels.
df_after = pd.concat(
    [pd.read_csv(csv_file) for csv_file in after_csv_files], axis=0, ignore_index=True
).drop(columns=["trial", "PROLIFIC_PID"])

# The context is everything in the sentence before the first occurrence of " <target>".
# NOTE(review): if the target string also occurs earlier in the sentence (even as the
# prefix of another word) the context is cut at that earlier position, and if " <target>"
# never occurs the whole sentence is kept — confirm this matches the stimulus design.
df_after["context"] = df_after.apply(lambda row: row.sentence.split(f" {row.target}")[0], axis=1)

In [None]:
# Hugging Face Hub identifier of the causal language model whose next-token
# predictions are scored in the cells below.
model_name = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights in half precision; `device_map="auto"` leaves device placement
# to the library.
# NOTE(review): the `dtype=` keyword is what recent transformers releases expect;
# older releases use `torch_dtype=` — confirm against the installed version.
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float16, device_map="auto")

# Put the model in evaluation mode. As the last expression of the cell, this call
# also displays the returned module's repr in the cell output.
model.eval()

In [None]:
def top_k_predictions_for_context(
    context: str, target: str, k: int = 5, tokenizer=None, model=None
):
    """Mean normalized Levenshtein distance between the top-k next tokens and ``target``.

    The model is run once on ``context``. The ``k`` highest-logit token ids at the
    last sequence position are decoded one at a time, stripped of surrounding
    whitespace, and compared with ``target``. Each edit distance is divided by the
    length of the longer of the two strings.

    Args:
        context: Text fed to the model.
        target: String every predicted token is compared against.
        k: Number of highest-logit tokens to consider.
        tokenizer: Object callable as ``tokenizer(text, return_tensors="pt")`` that
            returns a mapping of tensors and offers
            ``decode(ids, skip_special_tokens=True)``.
        model: Object callable as ``model(**inputs)`` whose result has a ``logits``
            attribute, and which exposes ``parameters()``.

    Returns:
        The mean normalized distance over the predicted tokens that decode to a
        non-empty string, or ``None`` when none of them do (mirroring
        ``normalized_levenshtein_dist``, instead of dividing by zero).

    Raises:
        AssertionError: If ``tokenizer`` or ``model`` is not provided.
    """
    assert tokenizer is not None and model is not None, "Provide tokenizer and model"

    # Run the inputs on whatever device the model's first parameter lives on.
    device = next(model.parameters()).device

    encoded = tokenizer(context, return_tensors="pt")
    # Distinct loop names: the original comprehension reused `k`, the name of the
    # top-k parameter, which was harmless but misleading.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    # Logits at the last sequence position only.
    last_logits = outputs.logits[:, -1, :]
    _, topk_indices = torch.topk(last_logits, k=k, dim=-1)
    # Only the first batch row is used; a single context is encoded per call.
    topk_ids = topk_indices[0].tolist()

    distances = []
    for tok_id in topk_ids:
        token_text = tokenizer.decode([tok_id], skip_special_tokens=True).strip()
        # Tokens that are pure whitespace or special tokens decode to "" and are skipped.
        if not token_text:
            continue
        distances.append(
            Levenshtein.distance(token_text, target) / max(len(token_text), len(target))
        )

    # Every candidate may have been skipped above; avoid a ZeroDivisionError.
    if not distances:
        return None
    return sum(distances) / len(distances)

In [None]:
def normalized_levenshtein_dist(target: str, associations: list[str]):
    """Mean normalized Levenshtein distance between ``target`` and each association.

    Each edit distance is divided by the length of the longer of the two strings.
    Entries of ``associations`` that are not ``str`` are ignored.

    Args:
        target: String every association is compared against.
        associations: Iterable of candidate strings; may contain non-string
            placeholders (presumably NaN from empty CSV cells — confirm).

    Returns:
        The mean normalized distance over the string entries, or ``None`` when
        ``associations`` contains no string at all.
    """
    distances = []

    for association in associations:
        if not isinstance(association, str):
            continue
        longest = max(len(association), len(target))
        if longest == 0:
            # Both strings are empty, hence identical; avoid dividing 0 by 0.
            distances.append(0.0)
        else:
            distances.append(Levenshtein.distance(association, target) / longest)

    if not distances:
        return None
    return sum(distances) / len(distances)

In [None]:
def _model_distance_for_row(row):
    """Score one row: model's top-20 next tokens for the row's context vs. its target."""
    return top_k_predictions_for_context(
        context=row.context,
        target=row.target,
        k=20,
        tokenizer=tokenizer,
        model=model,
    )


# One forward pass per row; `progress_apply` shows a progress bar while it runs.
df_after["pred_lev_dist"] = df_after.progress_apply(_model_distance_for_row, axis=1)

In [None]:
# Every column whose name contains "association" holds one participant response.
association_cols_after = [
    column for column in df_after.columns.tolist() if "association" in column
]


def _human_distance_for_row(row):
    """Score one row: the row's association responses vs. its target."""
    return normalized_levenshtein_dist(
        target=row.target,
        associations=row[association_cols_after],
    )


df_after["human_lev_dist"] = df_after.progress_apply(_human_distance_for_row, axis=1)