# core

> Per-token (and optional per-semantic-label) perplexity of a causal language model over a Hugging Face dataset.

In [None]:
#| default_exp core

In [None]:
#| export
import datasets
import os
import torch
import transformers

from collections import Counter, defaultdict
from rich.progress import track
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import default_data_collator

In [None]:
#| hide
from datasets import load_dataset
from nbdev.showdoc import *
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
#| export
def loss_func(logits, labels):
    """
    Per-position cross-entropy between next-token predictions and their targets.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``: the final logit and the first label are dropped. No reduction
    is applied, so the result has the same shape as ``labels[..., 1:]``.
    """
    vocab_size = logits.size(-1)
    # Align prediction t with target t + 1.
    predictions = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()
    # Flatten to (N, vocab) / (N,) for the criterion, keeping one loss per position.
    criterion = CrossEntropyLoss(reduction="none")
    flat_loss = criterion(predictions.view(-1, vocab_size), targets.view(-1))
    # Restore the (batch_size, sequence_length - 1) layout of the targets.
    return flat_loss.view(*targets.size())

In [None]:
#| export
def get_counts(model, tokenizer, batch, semantic_column: str, return_distributions: bool):
    """
    Accumulate per-token losses and occurrence counts for one batch.

    Runs `model` on ``batch["input_ids"]`` / ``batch["attention_mask"]`` without
    gradients, computes the unreduced next-token loss with `loss_func`, and
    groups it by the decoded text of each predicted token. When
    `semantic_column` is not None, the same loss is also recorded under the
    label found in ``batch[semantic_column]``.

    Positions whose attention mask is 0 (padding) are skipped so that padding
    does not dominate the counts.

    Returns ``(loss_cnt, token_cnt)``:
    - `loss_cnt`: a ``defaultdict(list)`` of individual losses per key when
      `return_distributions` is true, otherwise a ``Counter`` of summed losses.
    - `token_cnt`: a ``Counter`` of how many positions contributed to each key.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    # Only the logits are used: the per-position loss is recomputed below, so
    # `labels` is not passed and the model does not compute its own loss.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)

    # Convert once to nested Python lists instead of calling `.item()` per token.
    loss = loss_func(outputs.logits, input_ids).tolist()
    # Mask aligned with the predicted tokens (`ids[1:]`).
    predicted_mask = attention_mask[..., 1:].tolist()

    loss_cnt = defaultdict(list) if return_distributions else Counter()
    token_cnt = Counter()

    def record(key, value):
        # Append for distributions, sum otherwise; always count the occurrence.
        if return_distributions:
            loss_cnt[key].append(value)
        else:
            loss_cnt[key] += value
        token_cnt[key] += 1

    for i, ids in enumerate(input_ids):
        # loss[i][j] is the loss for predicting ids[j + 1].
        for j, token_id in enumerate(ids[1:]):
            if not predicted_mask[i][j]:
                continue  # padding position
            token_loss = loss[i][j]
            record(tokenizer.decode(token_id), token_loss)

            if semantic_column is not None:
                # NOTE(review): the label is read at index j while the token is
                # ids[j + 1]; kept as in the original — confirm the alignment
                # of the semantic column with the input ids.
                record(batch[semantic_column][i][j], token_loss)

    return loss_cnt, token_cnt

In [None]:
#| export
def perplexed(
    model: transformers.PreTrainedModel, # The model to calculate the perplexity of.
    dataset: datasets.Dataset, # The dataset to calculate the perplexity on.
    tokenizer: transformers.PreTrainedTokenizer = None, # The tokenizer to use to tokenize the dataset. If not provided, the tokenizer associated with the model will be used.
    column: str = "text", # The column of the dataset to calculate the perplexity on.
    semantic_column: str = None, # The column of the dataset to calculate the semantic perplexity on such as NER tags.
    n_gram: int = 1, # The n-gram to calculate the perplexity on.
    batch_size: int = 1, # The batch size to use when calculating the perplexity.
    num_proc: int = os.cpu_count(), # The number of processes to use when tokenizing the dataset.
    device: str = "cuda", # The device to use when calculating the perplexity.
    collate_fn = default_data_collator, # The collate function to use when calculating the perplexity.
    return_tokens: bool = False, # Whether to return the tokens counts along with the perplexity.
    return_distributions: bool = False, # Whether to return the perplexity distributions instead of the perplexity.
): # The perplexity of the model on the dataset or a tuple of the perplexity and the token counts.
    """
    Calculate the perplexity of a model on a dataset.

    The dataset is tokenized (truncated and padded to the tokenizer's maximum
    length), fed to the model batch by batch, and the per-position losses from
    `get_counts` are aggregated per key. The result maps each key to
    ``exp(mean loss)`` or, when `return_distributions` is true, to the list of
    ``exp(loss)`` values for every occurrence.

    Only the batches are moved to `device`; the model is used as passed in, so
    it must already live on that device. `n_gram` is accepted but not used by
    this implementation.

    Raises `ValueError` when `tokenizer` is None and the model configuration
    does not record a checkpoint name to load one from.
    """
    if tokenizer is None:
        # `config.tokenizer_class` is a class name (or None), not a class, so
        # the tokenizer is resolved from the checkpoint the model came from.
        name_or_path = getattr(model.config, "name_or_path", None)
        if not name_or_path:
            raise ValueError(
                "Cannot infer a tokenizer: the model config has no `name_or_path`. "
                "Pass `tokenizer` explicitly."
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path)

    # Tokenize the dataset; the original columns are dropped, so a semantic
    # column has to be produced by the tokenizer itself.
    batched = batch_size > 1
    tokenized_dataset = dataset.map(
        lambda x: tokenizer(x[column], truncation=True, padding="max_length"),
        batched=batched,
        batch_size=batch_size,
        remove_columns=dataset.column_names,
        num_proc=num_proc,
    )

    # Create a dataloader for the dataset
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Aggregate losses and counts over all batches
    total_loss_cnt = defaultdict(list) if return_distributions else Counter()
    total_token_cnt = Counter()
    for batch in track(dataloader, description="Calculating perplexity"):
        # Move tensors to the device; non-tensor entries (e.g. label lists) stay as they are
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        loss_cnt, token_cnt = get_counts(model, tokenizer, batch, semantic_column, return_distributions)
        for token, loss in loss_cnt.items():
            # List concatenation for distributions, float sum otherwise
            total_loss_cnt[token] += loss
        total_token_cnt.update(token_cnt)

    # The loss is a natural-log cross-entropy, so perplexity is exp(loss) in
    # both modes (the distribution branch previously used base 2).
    perplexity = defaultdict(list) if return_distributions else Counter()
    for token, loss in total_loss_cnt.items():
        if return_distributions:
            perplexity[token] = torch.exp(torch.tensor(loss)).tolist()
        else:
            perplexity[token] = torch.exp(torch.tensor(loss / total_token_cnt[token])).item()

    if return_tokens:
        return perplexity, total_token_cnt

    return perplexity

In [None]:
# Example: per-token perplexity of GPT-Neo 125M on a small slice of WikiText-2.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
# `perplexed` tokenizes with padding="max_length", which needs a pad token;
# the EOS token is reused for that here.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
# Keep the model on the same device that is passed to `perplexed` below.
model.to("cpu")

# First 50 rows of the test split
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test").select(range(50))
# filter out empty strings
dataset = dataset.filter(lambda x: len(x["text"]) > 0)

# return_tokens=True yields both the perplexity Counter and the token-count Counter
perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=2,
    device="cpu",
    return_tokens=True
)

Found cached dataset wikitext (/home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-9769e73f0000d75f.arrow
num_proc must be <= 27. Reducing num_proc to 27 for dataset of size 27.


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-50db7d0aa7f69ed9.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-c41d930814755e48.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-bed28d7ce7e2358a.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-a83f3e59f7f0fb48.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-aa3afef31db128bc.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-d4bb736cbb4f879c.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-177001bdb192bb94.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6c952d5f40cbdc4f.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-7f5411d1fccb8a4c.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-9f978365f30b95e9.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-3c0db1443a93f072.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-9b895669be345e18.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-63b345793f9998fb.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-56ff420c0cbf561e.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-7571622f774d4369.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-adb710b9646fbd88.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-acafd3367767f91a.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-d0d2ba5f9fc3e083.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-aa680afd71a072fe.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-87a0014e9435a2b2.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-459554c2368f9f0e.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-fd860e1feeb25e0f.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ad9b53a78d118f45.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0ef457c6e6ab05bf.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-16acfa1618f94d3b.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-05ea07e37bf52f7c.arrow


 

Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-65202076bc760807.arrow


Output()

KeyboardInterrupt: 

In [None]:
# The ten tokens with the highest perplexity
perplexity_cnt.most_common(10)

[(' wired', 60983220.0),
 (' shatter', 12281874.0),
 (' Career', 4064274.5),
 (' Early', 2422943.75),
 (' Television', 2325893.0),
 (' Daylight', 2126348.5),
 (' unrecogn', 1731038.5),
 (' @', 1636278.125),
 (' Chou', 1440191.125),
 (' advisers', 1118558.375)]

In [None]:
# The ten most frequently scored tokens
token_cnt.most_common(10)

[('<|endoftext|>', 52832),
 (' the', 114),
 (',', 107),
 ('.', 83),
 (' "', 72),
 (' in', 69),
 (' of', 52),
 (' a', 44),
 (' =', 41),
 (' and', 40)]

In [None]:
# Perplexity of the ten most frequent tokens, printed as 'token': perplexity
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")

In [None]:
def code_collator(batch):
    """
    Collate a list of examples that carry a ``"merged_ast"`` entry.

    The ``"merged_ast"`` values are collected into a plain Python list (one
    item per example), while every other field is collated with
    `default_data_collator`. The list is then attached to the collated batch
    under ``"merged_ast"``.

    The input examples are not modified: filtered copies are collated instead
    of popping the key from the caller's dicts.

    Raises `KeyError` if an example has no ``"merged_ast"`` entry.
    """
    merged_ast = [example["merged_ast"] for example in batch]
    features = [
        {key: value for key, value in example.items() if key != "merged_ast"}
        for example in batch
    ]
    collated = default_data_collator(features)
    collated["merged_ast"] = merged_ast
    return collated

In [None]:
# Example: token-level and AST-label-level perplexity of CodeParrot on Python code.
from code_tokenizers.core import CodeTokenizer

model_name = "codeparrot/codeparrot-small"
py_tokenizer = CodeTokenizer.from_pretrained(model_name, "python")
# NOTE(review): presumably this node type is missing from the tokenizer's
# built-in list and is needed for this data — confirm against code_tokenizers.
py_tokenizer.node_types.append("as_pattern_target")
# Reuse EOS as the pad token so padding="max_length" inside `perplexed` works.
py_tokenizer.tokenizer.pad_token = py_tokenizer.tokenizer.eos_token
# Rebinds `model` (and `dataset` below) from the WikiText example above.
model = AutoModelForCausalLM.from_pretrained(model_name)

# First 10 files of the validation set
dataset = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(range(10))
# `code_collator` keeps the "merged_ast" labels as Python lists next to the tensors;
# semantic_column="merged_ast" makes `perplexed` also aggregate losses per label.
perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=py_tokenizer,
    column="content",
    semantic_column="merged_ast",
    batch_size=1,
    num_proc=1,
    device="cpu",
    collate_fn=code_collator,
    return_tokens=True
)

Using custom data configuration codeparrot--codeparrot-clean-valid-826c6fd8b27e5523
Found cached dataset json (/home/nathan/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/10 [00:00<?, ?ex/s]

Output()

In [None]:
# The ten keys (tokens or semantic labels) with the highest perplexity
perplexity_cnt.most_common(10)

[('reports', 4496803.0),
 ('Double', 4140204.0),
 ('BLANK', 3750757.25),
 ('BD', 1003989.75),
 ('CO', 805091.3125),
 ('Pure', 525452.5625),
 ('customize', 488133.78125),
 (' inte', 449579.28125),
 (' ways', 434788.71875),
 (' filenames', 432984.59375)]

In [None]:
# The ten most frequent keys; tokens and semantic labels share one Counter
token_cnt.most_common(10)

[('<module -> comment>', 3023),
 ('<|endoftext|>', 2557),
 ('<import_from_statement -> from>', 819),
 ('< N/A >', 667),
 ('<argument_list -> string>', 575),
 ('<attribute -> identifier>', 551),
 ('<expression_statement -> string>', 489),
 ('<dotted_name -> identifier>', 463),
 ('_', 361),
 ('.', 355)]

In [None]:
# Perplexity of the ten most frequent keys (tokens or semantic labels),
# printed as 'key': perplexity
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")

'<module -> comment>': 403.2505187988281
'<|endoftext|>': 9550.5166015625
'<import_from_statement -> from>': 5880.80029296875
'< N/A >': 4.412940502166748
'<argument_list -> string>': 8.250838279724121
'<attribute -> identifier>': 1.6259663105010986
'<expression_statement -> string>': 9.665909767150879
'<dotted_name -> identifier>': 2.7604565620422363
'_': 1.3683202266693115
'.': 1.3909730911254883


In [None]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out to the library module.
import nbdev; nbdev.nbdev_export()