# The Important Part

In [2]:
from typing import cast
from dataclasses import dataclass
import pickle
import numpy as np
import pandas as pd
from collections import defaultdict

from datasets import load_dataset, Dataset

from delphi.eval import constants
from delphi.eval.vis_per_token_model import visualize_per_token_category

    https://beartype.readthedocs.io/en/latest/api_roar/#pep-585-deprecations
  warn(


In [3]:
def calc_model_group_stats(
    tokenized_corpus_dataset: list,
    logprob_datasets: dict[str, Dataset],
    token_groups: dict[int, dict[str, bool]],
    models: list[str],
    token_labels: list[str]
) -> dict[tuple[str, str], dict[str, float]]:
    """
    For each (model, token group) pair, calculate useful stats (for visualization)

    args:
    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
    - logprob_datasets: a dict of logprob datasets, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
    - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
    - models: a list of model names, e.g. ["llama2", "gpt2", ...]
    - token_labels: token group descriptions, e.g. ["Is Noun", "Is Verb", ...]; any iterable is
        accepted (it is copied into a list once, so dict views and generators work too)

    returns: a dict of (model, token group) pairs to a dict of stats,
        e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
        Keys follow the order of `models`, then the order of `token_labels`. A pair is omitted
        when no predicted token of that group occurs for the model.

    raises:
    - KeyError if a corpus token id is missing from `token_groups`, or a label is missing from
      that token's group dict
    - IndexError if a document's logprobs are shorter than its token list

    Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`,
    but it's better to be explicit

    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
    """
    # The labels are walked many times (per new token id and once per model), so copy them
    # into a list up front; a one-shot iterable would otherwise be exhausted on first use.
    labels = list(token_labels)

    # token id -> labels whose flag is truthy. Membership does not depend on the model or on
    # the position in the corpus, so it is resolved once per distinct token id (lazily, so
    # only token ids that actually occur in the corpus are looked up).
    labels_by_token: dict[int, tuple[str, ...]] = {}

    model_group_stats = {}
    for model in models:
        print(f"Processing model {model}")
        group_logprobs: dict[str, list] = {}
        for ix_doc_lp, document_lps in enumerate(logprob_datasets[model]):
            tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
            for ix_token, token in enumerate(tokens):
                if ix_token == 0:  # skip the first token, which isn't predicted
                    continue
                logprob = document_lps[ix_token]
                member_labels = labels_by_token.get(token)
                if member_labels is None:
                    flags = token_groups[token]
                    member_labels = tuple(label for label in labels if flags[label])
                    labels_by_token[token] = member_labels
                for label in member_labels:
                    group_logprobs.setdefault(label, []).append(logprob)

        for label in labels:
            values = group_logprobs.get(label)
            if values is None:
                # no predicted token of this group was seen for this model
                continue
            # convert once instead of letting every numpy call re-convert the list
            values_arr = np.asarray(values)
            model_group_stats[(model, label)] = {
                "mean": np.mean(values_arr),
                "median": np.median(values_arr),
                "min": np.min(values_arr),
                "max": np.max(values_arr),
                "25th": np.percentile(values_arr, 25),
                "75th": np.percentile(values_arr, 75),
            }
        print()
    return model_group_stats

In [4]:
# Load the tokenized corpus and keep its "validation" split.
# `cast` only informs the type checker; it does nothing at runtime.
corpus_splits = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))
tokenized_corpus_dataset = corpus_splits["validation"]

In [6]:
#TODO: convert to use static paths
# NOTE(review): unpickling can execute arbitrary code, so this file must come from a trusted source.
# The path is relative, so it is resolved against the notebook's current working directory.
token_groups_path = "../src/delphi/eval/labelled_token_ids_dict.pkl"
with open(token_groups_path, "rb") as token_groups_file:
    token_groups = pickle.load(token_groups_file)

In [7]:
# `token_groups` maps token id -> {group name -> bool}, where the bool says whether the token
# belongs to that group. Invert it into `token_group_members`: group name -> list of the token
# ids that belong to the group (groups with no members get no entry).

token_group_members = {}
for token_id, group_flags in token_groups.items():
    member_groups = [name for name, is_member in group_flags.items() if is_member]
    for group_name in member_groups:
        token_group_members.setdefault(group_name, []).append(token_id)


In [8]:
# model name -> the "logprobs" column of that model's validation split
logprob_datasets = {}
for model in constants.LLAMA2_MODELS:
    repo_id = f"transcendingvictor/{model}-validation-logprobs"
    validation_split = cast(Dataset, load_dataset(repo_id))["validation"]
    # the casts are for the type checker only; they do not convert anything at runtime
    logprob_datasets[model] = cast(dict, validation_split)["logprobs"]

In [9]:
# Group names are taken from the entry for token id 0.
# NOTE(review): assumes every token id carries the same set of group names — confirm.
token_group_descriptions = list(token_groups[0])

In [10]:
# `tokenized_corpus_dataset` and `token_group_descriptions` are defined in earlier cells; reuse
# them rather than loading the corpus a second time and re-deriving the group names.
# Passing the list (not a dict_keys view) also matches the function's declared `list[str]`.
model_group_stats = calc_model_group_stats(
    tokenized_corpus_dataset,
    logprob_datasets,
    token_groups,
    constants.LLAMA2_MODELS,
    token_group_descriptions,
)

Processing model delphi-llama2-100k

Processing model delphi-llama2-200k

Processing model delphi-llama2-400k

Processing model delphi-llama2-800k

Processing model delphi-llama2-1.6m

Processing model delphi-llama2-3.2m

Processing model delphi-llama2-6.4m

Processing model delphi-llama2-12.8m

Processing model delphi-llama2-25.6m



In [12]:

# Build model -> {token group -> (negated median, negated 75th pct, negated 25th pct)} from the
# logprob stats. (model, group) pairs without stats are left out.
performance_data = defaultdict(dict)
for model in constants.LLAMA2_MODELS:
    for token_group_desc in token_group_descriptions:
        stats = model_group_stats.get((model, token_group_desc))
        if stats is None:
            continue
        # Negation reverses ordering, so the negated 75th percentile is the smaller of the
        # two quartile values and the negated 25th percentile the larger.
        negated_median = -stats["median"]
        negated_upper_quartile = -stats["75th"]
        negated_lower_quartile = -stats["25th"]
        performance_data[model][token_group_desc] = (
            negated_median,
            negated_upper_quartile,
            negated_lower_quartile,
        )

visualize_per_token_category(
    performance_data,
    log_scale=True,
    bg_color='LightGrey',
    line_color="Red",
    marker_color='Orange',
    bar_color='Green',
)

VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…

In [13]:
# with open("../data/model_group_stats.pkl", "wb") as f:
#     pickle.dump(model_group_stats, f)

In [14]:
# import pickle
# with open("../data/model_group_stats.pkl", "rb") as f:
#     model_group_stats = pickle.load(f)