# core

> Per-token (and optional per-semantic-tag) perplexity of a causal language model on a dataset.

In [None]:
#| default_exp core

In [None]:
#| export
import datasets
import torch
import transformers

from collections import Counter, defaultdict
from torch.nn import CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| hide
from datasets import load_dataset
from nbdev.showdoc import *
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
#| export
def loss_func(logits, labels):
    """Unreduced next-token cross-entropy.

    The logits at position `t` are scored against the label at position
    `t + 1`, so the final logit and the first label are discarded. The
    result is a flat 1-D tensor with one loss value per scored position.
    """
    vocab_size = logits.size(-1)
    # The last position has no following token to be compared with.
    predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    # The first label is never predicted by an earlier position.
    targets = labels[..., 1:].contiguous().view(-1)
    criterion = CrossEntropyLoss(reduction="none")
    return criterion(predictions, targets)

In [None]:
#| export
def get_counts(model, tokenizer, batch, semantic_column: str, return_distributions: bool):
    """Collect next-token losses for one tokenized example, keyed by token and by semantic tag.

    Parameters
    ----------
    model : callable returning an object with a `.logits` attribute when called as
        `model(input_ids, attention_mask=..., return_dict=True)`.
    tokenizer : object whose `decode` turns a single token id into the key used
        in the returned counters.
    batch : mapping with `"input_ids"` and `"attention_mask"` and, when
        `semantic_column` is given, that column as well.
        NOTE(review): assumed to be ONE un-batched example (a flat list of ids),
        since the loop below walks `input_ids[1:]` element by element -- confirm.
    semantic_column : name of a per-token tag column in `batch`, or None to skip tags.
        NOTE(review): assumed to be aligned one-to-one with `input_ids` -- confirm.
    return_distributions : if True, `loss_cnt` maps each key to the list of its
        individual losses; otherwise to the sum of its losses.

    Returns
    -------
    (loss_cnt, token_cnt) : losses per key and number of occurrences per key.
        Decoded tokens and semantic tags share the same two containers.
    """
    loss_cnt = defaultdict(list) if return_distributions else Counter()
    token_cnt = Counter()

    input_ids = torch.tensor(batch["input_ids"])
    attention_mask = torch.tensor(batch["attention_mask"])
    # Nothing can be predicted from fewer than two tokens; skip the forward pass
    # instead of feeding an empty tensor to the model.
    if input_ids.numel() < 2:
        return loss_cnt, token_cnt

    # Put the inputs where the model is (if it says where), so a model that was
    # moved to an accelerator does not fail on CPU inputs.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        # `labels=` is not passed: the per-token loss is computed by loss_func below.
        outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
    # losses[i] is the loss for predicting the token at position i + 1.
    losses = loss_func(outputs.logits, input_ids).tolist()

    semantic_tags = batch[semantic_column] if semantic_column is not None else None

    def record(key, value):
        # Either keep every individual loss or accumulate their sum.
        if return_distributions:
            loss_cnt[key].append(value)
        else:
            loss_cnt[key] += value
        token_cnt[key] += 1

    for i, token_id in enumerate(input_ids[1:]):
        value = losses[i]
        record(tokenizer.decode(token_id), value)
        if semantic_tags is not None:
            # The scored token sits at position i + 1, so its tag does too.
            record(semantic_tags[i + 1], value)
    return loss_cnt, token_cnt

In [None]:
#| export
def perplexed(
    model: transformers.PreTrainedModel, # The model to calculate the perplexity of.
    dataset: datasets.Dataset, # The dataset to calculate the perplexity on.
    tokenizer: transformers.PreTrainedTokenizer = None, # The tokenizer to use to tokenize the dataset. If not provided, the tokenizer associated with the model will be used.
    column: str = "text", # The column of the dataset to calculate the perplexity on.
    semantic_column: str = None, # The column of the dataset to calculate the semantic perplexity on such as NER tags.
    n_gram: int = 1, # The n-gram to calculate the perplexity on. Currently unused.
    batch_size: int = 1, # The batch size to use when tokenizing the dataset.
    device: str = "cuda", # The device to use when calculating the perplexity. Currently unused by this function.
    return_tokens: bool = False, # Whether to return the tokens counts along with the perplexity.
    return_distributions: bool = False, # Whether to return the perplexity distributions instead of the perplexity.
): # The perplexity of the model on the dataset or a tuple of the perplexity and the token counts.
    """
    Calculate the perplexity of a model on a dataset.

    The result maps every key seen by `get_counts` (decoded tokens and, when
    `semantic_column` is set, semantic tags) to `exp(mean loss)` of that key,
    or to the list of `exp(loss)` values when `return_distributions` is True.
    """
    if tokenizer is None:
        # `config.tokenizer_class` holds a class *name* (or None), not a class,
        # so resolve the tokenizer through AutoTokenizer from the checkpoint
        # name recorded on the model config.
        tokenizer = transformers.AutoTokenizer.from_pretrained(model.config.name_or_path)

    # Tokenize the dataset. `batch_size` only controls how many rows are
    # tokenized per `map` call; the result is still iterated row by row below.
    batched = batch_size > 1
    tokenized_dataset = dataset.map(
        lambda x: tokenizer(x[column], truncation=True),
        batched=batched,
        batch_size=batch_size,
        remove_columns=dataset.column_names,
    )

    # Accumulate the per-key losses and occurrence counts over all examples
    total_loss_cnt = defaultdict(list) if return_distributions else Counter()
    total_token_cnt = Counter()
    for example in tokenized_dataset:
        loss_cnt, token_cnt = get_counts(model, tokenizer, example, semantic_column, return_distributions)
        for token, loss in loss_cnt.items():
            # Extends the list of losses (distributions) or adds to the running sum.
            total_loss_cnt[token] += loss
        total_token_cnt += token_cnt

    # Calculate the perplexity. The loss is a natural-log cross-entropy, so both
    # branches exponentiate with base e.
    perplexity = defaultdict(list) if return_distributions else Counter()
    for token, loss in total_loss_cnt.items():
        if return_distributions:
            perplexity[token] = torch.exp(torch.tensor(loss)).tolist()
        else:
            perplexity[token] = torch.exp(torch.tensor(loss / total_token_cnt[token])).item()

    if return_tokens:
        return perplexity, total_token_cnt

    return perplexity

In [None]:
from code_tokenizers.core import CodeTokenizer

# The model and the Python code tokenizer are loaded from the same checkpoint name.
model_name = "codeparrot/codeparrot-small"
model = AutoModelForCausalLM.from_pretrained(model_name)
py_tokenizer = CodeTokenizer.from_pretrained(model_name, "python")

# Keep the demo small: only the first five examples.
dataset = load_dataset("codeparrot/codeparrot-clean-valid", split="train")
dataset = dataset.select(range(5))

# Request `merged_ast` as the semantic column and ask for the token counts too.
perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=py_tokenizer,
    column="content",
    semantic_column="merged_ast",
    batch_size=1,
    device="cpu",
    return_tokens=True,
)

Using custom data configuration codeparrot--codeparrot-clean-valid-826c6fd8b27e5523
Found cached dataset json (/home/nathan/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 5/5 [00:27<00:00,  5.48s/ex]


In [None]:
# Ten keys with the highest mean perplexity (keys are decoded tokens and `merged_ast` tags)
perplexity_cnt.most_common(10)

[('reports', 4496803.0),
 ('Double', 4140204.0),
 ('Pure', 525452.5625),
 (' inte', 449579.28125),
 (' ways', 434788.71875),
 (' filenames', 432984.59375),
 ('segments', 418304.90625),
 ('FN', 415667.96875),
 (' Con', 381470.75),
 ('conflict', 329534.9375)]

In [None]:
# Ten most frequent keys; decoded tokens and `merged_ast` tags share this counter
token_cnt.most_common(10)

[('<module -> comment>', 945),
 ('<expression_statement -> string>', 413),
 ('< N/A >', 313),
 ('<attribute -> identifier>', 243),
 ('.', 196),
 ('<argument_list -> string>', 178),
 ('<dotted_name -> identifier>', 166),
 ('\n', 155),
 ('_', 149),
 (',', 147)]

In [None]:
# Perplexity of the most common tokens.
# `!r` prints each key's repr, so a whitespace key such as a newline token
# stays on a single output line instead of breaking the row in two.
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"{token!r}: {perplexity_cnt[token]}")

'<module -> comment>': 4.602316856384277
'<expression_statement -> string>': 13.976248741149902
'< N/A >': 4.582609176635742
'<attribute -> identifier>': 1.4728107452392578
'.': 1.4726667404174805
'<argument_list -> string>': 3.3904495239257812
'<dotted_name -> identifier>': 4.801618576049805
'
': 1.9747979640960693
'_': 1.451112985610962
',': 1.654050350189209


In [None]:
# Tokenizer and model come from the same checkpoint.
checkpoint = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# First 100 test rows of wikitext-2-raw-v1, minus the rows whose text is empty.
dataset = (
    load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    .select(range(100))
    .filter(lambda x: len(x["text"]) > 0)
)

perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=1,
    device="cpu",
    return_tokens=True,
)

Found cached dataset wikitext (/home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-fc80113c96285cf5.arrow
100%|██████████| 61/61 [00:00<00:00, 4654.40ex/s]


In [None]:
# Ten tokens with the highest mean perplexity in the wikitext run
perplexity_cnt.most_common(10)

[(' wired', 60983452.0),
 (' 768', 21569920.0),
 (' shatter', 12281851.0),
 (' unsett', 8289411.5),
 (' ignited', 6605284.5),
 (' Tanz', 4834806.5),
 (' Influence', 4153385.0),
 (' Career', 4064134.75),
 (' Television', 2325895.0),
 (' Moral', 2243555.25)]

In [None]:
# Ten most frequent tokens in the wikitext run
token_cnt.most_common(10)

[(',', 240),
 (' the', 231),
 ('.', 169),
 (' of', 142),
 (' "', 124),
 (' in', 123),
 (' and', 93),
 (' to', 91),
 (' his', 87),
 (' a', 81)]

In [None]:
# Perplexity of the most common tokens.
# `!r` prints each key's repr, so a whitespace key such as a newline token
# would stay on a single output line instead of breaking the row in two.
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"{token!r}: {perplexity_cnt[token]}")

',': 21.052648544311523
' the': 4.730246543884277
'.': 10.683691024780273
' of': 2.6912641525268555
' "': 17.342397689819336
' in': 9.113482475280762
' and': 6.846656322479248
' to': 2.7567787170410156
' his': 11.905125617980957
' a': 8.68340015411377


In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `default_exp`)
import nbdev; nbdev.nbdev_export()