# Imports

In [2]:
import os
from typing import cast
import pickle
from collections import defaultdict

from datasets import load_dataset, Dataset

from delphi.constants import STATIC_ASSETS_DIR
from delphi.eval import utils
from delphi.eval import constants
from delphi.eval.vis_per_token_model import visualize_per_token_category

# from delphi.eval.calc_model_group_stats import calc_model_group_stats
from delphi.eval.spacy_token_labelling import TOKEN_LABELS

# Data

In [3]:
# Load the tokenized validation corpus. `cast` only informs the type checker;
# with an explicit `split`, `load_dataset` returns a single Dataset at runtime.
tokenized_corpus_dataset = cast(
    Dataset,
    load_dataset(constants.tokenized_corpus_dataset, split="validation"),
)

# Statistics keyed by (model, token-category) tuples, loaded from a precomputed
# pickle in STATIC_ASSETS_DIR rather than recomputed here.
# TODO: compute these live instead, via
#   calc_model_group_stats(tokenized_corpus_dataset, logprob_datasets,
#                          token_groups, token_groups[0].keys())
# where `token_groups` comes from `labelled_token_ids_dict.pkl`, once that file
# is available under STATIC_ASSETS_DIR. That call needs `logprob_datasets`, so
# it must run after `logprob_datasets` is loaded below.
# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# only acceptable because the pickle is assumed to be a trusted project asset.
with open(f"{STATIC_ASSETS_DIR}/model_group_stats.pkl", "rb") as f:
    model_group_stats = pickle.load(f)

# Log-prob datasets for the validation split (presumably one per model; confirm
# in delphi.eval.utils). Not used by the cells shown in this notebook; kept as
# the input the TODO above requires.
logprob_datasets = utils.load_logprob_datasets("validation")


# Visualization

In [4]:
# Build {model: {token_category: (-median, -75th, -25th)}} from the precomputed
# per-(model, category) statistics, flipping the sign of each quantile.
# Categories missing from `model_group_stats` are skipped, and a model with no
# matching categories gets no entry at all.
performance_data = defaultdict(dict)
for model in constants.LLAMA2_MODELS:
    # Iterate TOKEN_LABELS in order so category insertion order is preserved.
    category_quantiles = {
        category: tuple(
            -model_group_stats[(model, category)][quantile]
            for quantile in ("median", "75th", "25th")
        )
        for category in TOKEN_LABELS
        if (model, category) in model_group_stats
    }
    if category_quantiles:
        performance_data[model] = category_quantiles

# Last expression of the cell: the returned widget is rendered as cell output.
visualize_per_token_category(
    performance_data,
    log_scale=True,
    bg_color="LightGrey",
    line_color="Red",
    marker_color="Orange",
    bar_color="Green",
)

VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…