# core

> Per-token perplexity of a causal language model over a dataset

In [None]:
#| default_exp core

In [None]:
#| export
import datasets
import torch
import transformers

from collections import Counter, defaultdict
from rich.progress import track
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import default_data_collator

In [None]:
#| hide
from datasets import load_dataset
from nbdev.showdoc import *
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
#| export
def loss_func(logits, labels):
    """Return the unreduced next-token cross-entropy loss.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``, so the last logit and the first label along the sequence axis
    are dropped. The result has the same shape as ``labels[..., 1:]``.
    """
    predictions = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    vocab_size = predictions.size(-1)
    criterion = CrossEntropyLoss(reduction="none")
    # The criterion wants (N, vocab) scores and (N,) targets, so flatten first ...
    flat_loss = criterion(predictions.view(-1, vocab_size), targets.view(-1))
    # ... then restore the shape of the shifted labels.
    return flat_loss.view(targets.shape)

In [None]:
#| export
def get_counts(model, tokenizer, batch, semantic_column: str, return_distributions: bool):
    """Accumulate per-token losses and occurrence counts for one batch.

    The model is run without gradients on ``batch["input_ids"]`` and
    ``loss_func`` provides one loss per predicted position (position ``j`` of
    the loss corresponds to token ``j + 1`` of the input). Positions whose
    attention mask is 0 (padding) are skipped so they do not distort the
    statistics.

    Returns a ``(loss_cnt, token_cnt)`` tuple. ``token_cnt`` is a ``Counter``
    of how often each key was seen. ``loss_cnt`` maps each key to the list of
    its losses when ``return_distributions`` is true, otherwise to the sum of
    its losses. Keys are the decoded token strings and, when
    ``semantic_column`` is not ``None``, additionally the value of
    ``batch[semantic_column][i]`` for the sequence ``i`` the token belongs to.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids, return_dict=True)
    loss = loss_func(outputs.logits, input_ids)

    # Pull everything to Python in one go instead of calling .item() per token.
    # The first token of every sequence is never predicted, hence the [1:] shift
    # that lines ids and mask up with the loss positions.
    loss_rows = loss.tolist()
    id_rows = input_ids[..., 1:].tolist()
    mask_rows = attention_mask[..., 1:].tolist()

    loss_cnt = defaultdict(list) if return_distributions else Counter()
    token_cnt = Counter()
    for i, (ids, mask, losses) in enumerate(zip(id_rows, mask_rows, loss_rows)):
        # NOTE(review): the semantic label is looked up per sequence, exactly as
        # before; confirm whether a per-token lookup was intended.
        semantic = batch[semantic_column][i] if semantic_column is not None else None
        for token_id, is_real, token_loss in zip(ids, mask, losses):
            if not is_real:
                # Padding position: carries no information about the model.
                continue
            keys = [tokenizer.decode(token_id)]
            if semantic_column is not None:
                keys.append(semantic)
            for key in keys:
                # Use the scalar loss of this position; calling .item() on a
                # whole loss row raises for rows with more than one element.
                loss_cnt[key] += [token_loss] if return_distributions else token_loss
                token_cnt[key] += 1
    return loss_cnt, token_cnt

In [None]:
#| export
def perplexed(
    model: transformers.PreTrainedModel, # The model to calculate the perplexity of.
    dataset: datasets.Dataset, # The dataset to calculate the perplexity on.
    tokenizer: transformers.PreTrainedTokenizer = None, # The tokenizer to use to tokenize the dataset. If not provided, it is loaded from the checkpoint name stored in `model.config.name_or_path`.
    column: str = "text", # The column of the dataset to calculate the perplexity on.
    semantic_column: str = None, # The column of the dataset to calculate the semantic perplexity on such as NER tags.
    n_gram: int = 1, # The n-gram to calculate the perplexity on (currently not used by the implementation).
    batch_size: int = 1, # The batch size to use when calculating the perplexity.
    device: str = "cuda", # The device to use when calculating the perplexity.
    collate_fn = default_data_collator, # The collate function to use when calculating the perplexity.
    return_tokens: bool = False, # Whether to return the tokens counts along with the perplexity.
    return_distributions: bool = False, # Whether to return the perplexity distributions instead of the perplexity.
): # The perplexity of the model on the dataset or a tuple of the perplexity and the token counts.
    """
    Calculate the perplexity of a model on a dataset.

    The dataset is tokenized, fed to the model batch by batch, and the
    per-token losses from `get_counts` are aggregated per key. The result maps
    each key to `exp(mean loss)`, or, with `return_distributions`, to the list
    of `exp(loss)` values of every occurrence.
    """
    if tokenizer is None:
        # `config.tokenizer_class` is only a class *name* (or None), so it cannot
        # be instantiated directly; resolve the tokenizer from the checkpoint
        # identifier the config was loaded from instead.
        tokenizer = transformers.AutoTokenizer.from_pretrained(model.config.name_or_path)
        if tokenizer.pad_token is None:
            # Padding below needs a pad token; fall back to EOS on this
            # internally created tokenizer.
            tokenizer.pad_token = tokenizer.eos_token

    # Tokenize the dataset
    batched = batch_size > 1
    tokenized_dataset = dataset.map(
        lambda x: tokenizer(x[column], truncation=True, padding="max_length"),
        batched=batched,
        batch_size=batch_size,
        remove_columns=dataset.column_names,
    )

    # Create a dataloader for the dataset
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Calculate the perplexity of the model on the dataset
    total_loss_cnt = defaultdict(list) if return_distributions else Counter()
    total_token_cnt = Counter()
    for batch in track(dataloader, description="Calculating perplexity"):
        # Move the tensors of the batch to the device; anything a custom
        # collate function returns that is not a tensor is passed through as is.
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        loss_cnt, token_cnt = get_counts(model, tokenizer, batch, semantic_column, return_distributions)
        for token, loss in loss_cnt.items():
            # `+=` extends the list of losses or adds to the running sum,
            # depending on the container chosen above.
            total_loss_cnt[token] += loss
        total_token_cnt += token_cnt

    # Calculate the perplexity. The losses are natural-log cross-entropies, so
    # both branches must exponentiate with base e.
    perplexity = defaultdict(list) if return_distributions else Counter()
    for token, loss in total_loss_cnt.items():
        if return_distributions:
            perplexity[token] = torch.tensor(loss, dtype=torch.float64).exp().tolist()
        else:
            perplexity[token] = torch.exp(torch.tensor(loss / total_token_cnt[token])).item()

    if return_tokens:
        return perplexity, total_token_cnt

    return perplexity

In [None]:
# Load a small causal LM on CPU and reuse its EOS token as the padding token.
checkpoint = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.to("cpu")

# Take the first ten rows of the WikiText-2 test split and drop the empty ones.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test").select(range(10))
dataset = dataset.filter(lambda row: len(row["text"]) > 0)

perplexity_cnt, token_cnt = perplexed(
    model, dataset, tokenizer=tokenizer, column="text", batch_size=2, device="cpu", return_tokens=True
)

Found cached dataset wikitext (/home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-539070ede7844002.arrow
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-62d1e97981c3439b.arrow


Output()

In [None]:
perplexity_cnt.most_common(10)

[(' @', 58884616.0),
 (' Career', 4064274.5),
 (' Citizenship', 2621984.0),
 (' Kier', 998199.375),
 (' Daylight', 385076.9375),
 (' Craig', 367194.25),
 (' Teddy', 301091.5),
 (' Bush', 247001.234375),
 (' Donkey', 122835.953125),
 (' Firm', 120278.03125)]

In [None]:
token_cnt.most_common(10)

[('<|endoftext|>', 9843),
 (' the', 19),
 (' in', 18),
 (',', 15),
 ('.', 14),
 (' by', 10),
 (' =', 9),
 (' a', 9),
 (' television', 8),
 (' B', 7)]

In [None]:
# from code_tokenizers.core import CodeTokenizer

# model_name = "codeparrot/codeparrot-small"
# py_tokenizer = CodeTokenizer.from_pretrained(model_name, "python")
# py_tokenizer.node_types.append("as_pattern_target")
# py_tokenizer.tokenizer.pad_token = py_tokenizer.tokenizer.eos_token
# model = AutoModelForCausalLM.from_pretrained(model_name)

# dataset = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(range(10))
# perplexity_cnt, token_cnt = perplexed(
#     model,
#     dataset,
#     tokenizer=py_tokenizer,
#     column="content",
#     semantic_column="merged_ast",
#     batch_size=1,
#     device="cpu",
#     return_tokens=True
# )

In [None]:
perplexity_cnt.most_common(10)

In [None]:
token_cnt.most_common(10)

In [None]:
# Print the perplexity of each of the ten most frequent tokens.
tokens = [pair[0] for pair in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")

In [None]:
# For the ten tokens seen most often, show the perplexity recorded for each.
top_entries = token_cnt.most_common(10)
tokens = [token for token, _count in top_entries]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()