In [None]:
import os
import yaml
import ast

import pandas as pd
import transformers
from datasets import load_from_disk, load_dataset, Dataset

from tqdm.auto import tqdm

from ..utils import tokenize_and_align_labels

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [39]:
# experiment_name = "with_spec"
experiment_name = "without_spec"
# experiment_name = "without_spec_t1_vs_t2"

if experiment_name == "with_spec":
    dataset_name = "./training_datasets/training_data_all_with_spec.ds"

elif experiment_name == "without_spec":
    dataset_name = "./training_datasets/training_data_all_no_spec.ds"

elif experiment_name == "without_spec_t1_vs_t2":
    dataset_name = "training_t1_test_t2_no_spec.ds"

elif experiment_name == "without_spec_t2_vs_t1":
    dataset_name = "training_t2_test_t1_no_spec.ds"

elif experiment_name == "with_spec_t1_vs_t2":
    dataset_name = "training_t1_test_t2_with_spec.ds"

elif experiment_name == "with_spec_t2_vs_t1":
    dataset_name = "training_t2_test_t1_with_spec.ds"

else:
    raise ValueError("experiment_name must be one of the predefined options.")

run_path = "/home/gpucce/Repos/abstraction_ladders/acl_abstraction_ladders/primary_school_data/primary_school_experiments/multirun/2025-07-15/18-23-29"
model_name_or_path = None
runs = os.listdir(run_path)
for run in runs:
    with open(os.path.join(run_path, run, ".hydra", "config.yaml")) as f:
        config = yaml.safe_load(f)
    if config["args"]["dataset_name"].split('/')[-1] == dataset_name.split('/')[-1]:
        model_name_or_path = os.path.join(run_path, run)
        break

assert model_name_or_path is not None, f"Model not found for dataset {dataset_name}"

In [41]:
model = transformers.AutoModelForTokenClassification.from_pretrained(
    model_name_or_path,trust_remote_code=True,)
model.to("cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, trust_remote_code=True,)

In [None]:
# ds = load_from_disk(dataset_name)
# ref_ds = load_from_disk(dataset_name)
# if "validation" in ds:
#     ds = ds["validation"]
#     ref_ds = ref_ds["validation"]

df = pd.read_csv("word_ladders_cleaned.csv", sep="\t").loc[:, ["start", "ladder"]]
df.ladder = df.ladder.apply(lambda x: ast.literal_eval(x))
df.ladder = df.ladder.apply(lambda x: [l if l != "entita'" else "entità" for l in x])  # Fixing the apostrophe issue
df["full_list"] = df["ladder"]
df["label"] = df["full_list"].apply(lambda x: [model.config.label2id.get(word, "O") for word in x])
df["clean_list"] = df["full_list"]

ds = Dataset.from_pandas(df)
ref_ds = Dataset.from_pandas(df)

ds = ds.map(lambda x: tokenize_and_align_labels(x, tokenizer, label_to_id=model.config.label2id), batched=True)
new_ds = ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

Map:   0%|          | 0/5943 [00:00<?, ? examples/s]

In [43]:
def get_results(ds, ref_ds, idx, do_print=False):
    results = []
    out = model(**{i:j.to("cuda") for i, j in ds[[idx]].items()})
    preds = out.logits.cpu().argmax(-1)[0].tolist()[1:]
    tokens = ref_ds[idx]["full_list"]

    # tokenization check
    refs = ds[idx]["input_ids"][ds[idx]["attention_mask"].bool()].tolist()
    check = []
    for tok in tokens:
        check += tokenizer(tok, add_special_tokens=False)["input_ids"]
    assert refs[1:-1] == check

    count = 0
    for tok in tokens:
        tokenized = tokenizer(tok, add_special_tokens=False)["input_ids"]
        label = model.config.id2label[preds[count]]
        count += len(tokenized)
        results.append({"token": tok, "label": label,})
        if do_print:
            print("Token:", tok, "| Label:", label)

    return results  # Remove the first empty token


In [44]:
n_samples_to_check = 500 # len(ref_ds)
all_results = []
for idx in tqdm(range(n_samples_to_check)):
    result = get_results(ds, ref_ds, idx)
    all_results.append(result)


  0%|          | 0/500 [00:00<?, ?it/s]

In [45]:
total = 0
metrics = {"n": 0, "p": 0, "tp": 0, "tn": 0, "fp": 0, "fn": 0, "List Acc": 0, "List Acc @1": 0, "List Acc @2": 0}
for idx in tqdm(range(n_samples_to_check)):
    results = all_results[idx]
    full_list = ref_ds[idx]["full_list"]
    clean_results = ref_ds[idx]["clean_list"]
    total += len(full_list)
    metrics["p"] += len(clean_results)
    metrics["n"] += len(full_list) - len(clean_results)
    metrics["tp"] += sum(1 for r in results if r["label"] != "O" and r["token"] in clean_results)
    metrics["tn"] += sum(1 for r in results if r["label"] == "O" and r["token"] not in clean_results)
    metrics["fn"] += sum(1 for r in results if r["label"] == "O" and r["token"] in clean_results)
    metrics["fp"] += sum(1 for r in results if r["label"] != "O" and r["token"] not in clean_results)
    matches = [i["token"] for i in results if i["label"] != "O"]
    if matches == clean_results:
        metrics["List Acc"] += 1
    if sum(1 for i in matches if i not in clean_results) + sum(1 for i in clean_results if i not in matches) <= 1:
        metrics["List Acc @1"] += 1
    if sum(1 for i in matches if i not in clean_results) + sum(1 for i in clean_results if i not in matches) <= 2:
        metrics["List Acc @2"] += 1

assert total == (metrics["tp"] + metrics["tn"] + metrics["fp"] + metrics["fn"])
assert total == (metrics["p"] + metrics["n"])
assert metrics["p"] == (metrics["tp"] + metrics["fn"])
assert metrics["n"] == (metrics["tn"] + metrics["fp"])

  0%|          | 0/500 [00:00<?, ?it/s]

In [46]:
def print_metrics(metrics):
    total = metrics["tp"] + metrics["tn"] + metrics["fp"] + metrics["fn"]
    print("==== Word Level Metrics ====")
    print()
    print("Precision:".ljust(15), round(metrics["tp"] / (metrics["tp"] + metrics["fp"]), 3))
    print("Recall:".ljust(15), round(metrics["tp"] / (metrics["tp"] + metrics["fn"]), 3))
    print("F1:".ljust(15), round(2 * metrics["tp"] / (2 * metrics["tp"] + metrics["fp"] + metrics["fn"]), 3))
    print("Accuracy:".ljust(15), round((metrics["tp"] + metrics["tn"]) / (metrics["tp"] + metrics["tn"] + metrics["fp"] + metrics["fn"]), 3))
    print('-'*20)
    print("Positives:", metrics["p"], "Negatives:", metrics["n"], "Total:", total)
    print("Positive Ratio:", round(metrics["p"] / total, 3), "Negative Ratio:", round(metrics["n"] / total, 3))
    print()
    print("==== List Level Metrics ====")
    print()
    print("Accuracy:".ljust(15), round(metrics["List Acc"] / len(ref_ds), 3))
    print("Accuracy @1:".ljust(15), round(metrics["List Acc @1"] / len(ref_ds), 3))
    print("Accuracy @2:".ljust(15), round(metrics["List Acc @2"] / len(ref_ds), 3))
    print('-'*20)
    print("Total Lists:", len(ref_ds))



print_metrics(metrics)

==== Word Level Metrics ====

Precision:      1.0
Recall:         0.233
F1:             0.378
Accuracy:       0.233
--------------------
Positives: 3734 Negatives: 0 Total: 3734
Positive Ratio: 1.0 Negative Ratio: 0.0

==== List Level Metrics ====

Accuracy:       0.001
Accuracy @1:    0.004
Accuracy @2:    0.013
--------------------
Total Lists: 5943


In [47]:
def print_results(idx):
    print("Full List:".ljust(20), ref_ds[idx]["full_list"])
    # print("Cleaned List:".ljust(20), ref_ds[idx]["clean_list"])
    predicted_list = []
    for i in get_results(ds, ref_ds, idx):
        if i["label"] != "O":
            predicted_list.append(i["token"])
    print("Predicted List:".ljust(20), predicted_list)

In [125]:
print_results(269)

Full List:           ['fox', 'volpe', 'mammifero']
Predicted List:      []
