# core

> This module contains the core functions for calculating the per-token perplexity of a language model.

In [None]:
#| default_exp core

In [2]:
#| export
import datasets
import os
import torch
import transformers

import torch.nn.functional as F

from collections import Counter, defaultdict
from rich.progress import track
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import default_data_collator

In [3]:
#| hide
from datasets import load_dataset
from nbdev.showdoc import *
from transformers import AutoTokenizer, AutoModelForCausalLM

In [4]:
#| export
def loss_func(
    logits,                 # raw model scores; the last dimension indexes the vocabulary
    labels,                 # target token ids, aligned position-for-position with `logits`
    use_custom_loss=False   # currently not read by the function body
):                          # unreduced loss with the shape of `labels[..., 1:]`
    """
    Per-token cross entropy for next-token prediction.

    The scores at position `t` are compared with the label at position `t + 1`,
    so the result has one entry fewer along the sequence axis than `labels`.
    No reduction is applied: every predicted position keeps its own loss value.
    """
    # Drop the final prediction (no label follows it) and the first label
    # (no prediction precedes it) so that scores and targets line up.
    predictions = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    vocab_size = predictions.size(-1)
    flat_loss = F.cross_entropy(
        predictions.view(-1, vocab_size),
        targets.view(-1),
        reduction="none",
    )
    # Restore the (..., seq_len - 1) layout of the shifted labels.
    return flat_loss.view(*targets.size())

In [5]:
# test loss function
# Smoke test: run GPT-2 on two identical sentences and inspect the per-token loss.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Both inputs are the same string and no padding option is passed to the tokenizer.
inputs = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute"], return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# The targets are the input ids themselves; `loss_func` performs the one-position shift.
labels = inputs.input_ids
loss_func(logits, labels)

tensor([[2.3432, 3.7964, 6.6038, 1.7265, 5.4809],
        [2.3432, 3.7964, 6.6038, 1.7265, 5.4809]], grad_fn=<ViewBackward0>)

In [6]:
#| export
def get_counts(
    model,                      # the model to use for predictions
    tokenizer,                  # the tokenizer used to decode token ids and to look up the pad token
    batch,                      # mapping with "input_ids", "attention_mask" and, if used, the semantic column
    semantic_column: str,       # key of `batch` holding per-token semantic labels, or None to skip them
    return_distributions: bool  # if True keep every loss value per key instead of summing them
):                              # tuple (loss_cnt, token_cnt), both keyed by decoded token / semantic label
    """
    Accumulate per-token losses and occurrence counts for a single batch.

    The model is run without gradients and `loss_func` yields one loss per
    predicted position: entry `j` of a row is the loss for predicting token
    `j + 1` of that row, which is why the loop below walks `ids[1:]`.

    `loss_cnt` maps each key to the sum of its losses (a `Counter`) or, when
    `return_distributions` is set, to the list of individual losses
    (a `defaultdict(list)`). `token_cnt` counts occurrences per key.

    When `semantic_column` is given, the same loss is additionally recorded
    under `batch[semantic_column][i][j]` for every token that does not decode
    to `tokenizer.pad_token`. Token-level counts are not filtered, so padded
    positions are included there.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
            return_dict=True
        )
    # Convert once to nested Python floats rather than calling `.item()`
    # on a tensor element several times per token.
    losses = loss_func(outputs.logits, input_ids).tolist()

    loss_cnt = defaultdict(list) if return_distributions else Counter()
    token_cnt = Counter()

    def record(key, value):
        # Append in distribution mode, otherwise keep a running sum.
        if return_distributions:
            loss_cnt[key].append(value)
        else:
            loss_cnt[key] += value
        token_cnt[key] += 1

    for i, ids in enumerate(input_ids):
        # The first token has no preceding context and therefore no loss.
        for j, token_id in enumerate(ids[1:]):
            token = tokenizer.decode(token_id)
            value = losses[i][j]
            record(token, value)

            if semantic_column is not None and token != tokenizer.pad_token:
                record(batch[semantic_column][i][j], value)

    return loss_cnt, token_cnt

In [7]:
#| export
def perplexed(
    model: transformers.PreTrainedModel, # The model to calculate the perplexity of.
    dataset: datasets.Dataset, # The dataset to calculate the perplexity on.
    tokenizer: transformers.PreTrainedTokenizer = None, # The tokenizer to use to tokenize the dataset. If not provided, it is loaded with `AutoTokenizer` from the checkpoint recorded in `model.config.name_or_path`.
    column: str = "text", # The column of the dataset to calculate the perplexity on.
    semantic_column: str = None, # The column of the dataset to calculate the semantic perplexity on such as NER tags.
    n_gram: int = 1, # The n-gram to calculate the perplexity on. Currently not used by the implementation.
    batch_size: int = 1, # The batch size to use when calculating the perplexity.
    num_proc: int = os.cpu_count(), # The number of processes to use when tokenizing the dataset.
    device: str = "cuda", # The device the batches are moved to. The model itself is not moved.
    collate_fn = default_data_collator, # The collate function to use when calculating the perplexity.
    pass_row: bool = False, # Whether to pass the row to the tokenizer.
    return_tokens: bool = False, # Whether to return the tokens counts along with the perplexity.
    return_distributions: bool = False, # Whether to return the perplexity distributions instead of the perplexity.
    compute_perplexity: bool = True, # Whether to compute the perplexity. If False, the cross entropy will be returned instead.
): # The perplexity of the model on the dataset or a tuple of the perplexity and the token counts.
    """
    Calculate the perplexity of a model on a dataset.

    The dataset is tokenized (truncated and padded to max length), fed to the
    model batch by batch, and the per-token losses from `get_counts` are
    aggregated over the whole dataset. For every key (decoded token or
    semantic label) the result is:

    * `exp(mean loss)` when `compute_perplexity` is True, the mean loss otherwise;
    * with `return_distributions`, the list of per-occurrence values instead
      of a single aggregate (each exponentiated when `compute_perplexity` is True).

    Returns a `Counter` (or a `defaultdict(list)` in distribution mode), and
    additionally the `Counter` of occurrences when `return_tokens` is True.

    Raises `ValueError` if no tokenizer is given and the model config does not
    record a checkpoint to load one from.
    """
    if tokenizer is None:
        # `config.tokenizer_class` is a class *name* (or None), not a class, and the
        # config has no `pretrained_model_name_or_path` attribute. The checkpoint
        # identifier is stored in `config.name_or_path`; AutoTokenizer resolves the class.
        checkpoint = getattr(model.config, "name_or_path", None)
        if not checkpoint:
            raise ValueError(
                "No tokenizer was provided and the model config does not record "
                "a checkpoint name or path to load one from."
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    def _tokenize(examples):
        # With `pass_row` the tokenizer receives the whole row (or batch of rows);
        # otherwise only the configured text column.
        source = examples if pass_row else examples[column]
        return tokenizer(source, truncation=True, padding="max_length")

    # Tokenize the dataset; batched mapping is only enabled for batch sizes above one.
    batched = batch_size > 1
    tokenized_dataset = dataset.map(
        _tokenize,
        batched=batched,
        batch_size=batch_size,
        remove_columns=dataset.column_names,
        num_proc=num_proc,
        desc="Tokenizing dataset"
    )

    # Create a dataloader for the dataset; order is preserved (no shuffling).
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Aggregate losses and occurrence counts over all batches.
    total_loss_cnt = defaultdict(list) if return_distributions else Counter()
    total_token_cnt = Counter()
    for batch in track(dataloader, description="Calculating perplexity"):
        # Only tensors are moved; non-tensor entries (e.g. semantic labels) stay as they are.
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        loss_cnt, token_cnt = get_counts(
            model,
            tokenizer,
            batch,
            semantic_column,
            return_distributions
        )
        # `+=` extends the list in distribution mode and adds to the sum otherwise.
        for token, loss in loss_cnt.items():
            total_loss_cnt[token] += loss
        total_token_cnt += token_cnt

    # Turn the aggregated losses into perplexities (or leave them as cross entropy).
    perplexity = defaultdict(list) if return_distributions else Counter()
    for token, loss in total_loss_cnt.items():
        if return_distributions:
            perplexity[token] = (
                [torch.exp(torch.tensor(value)).item() for value in loss]
                if compute_perplexity else loss
            )
        else:
            mean_loss = loss / total_token_cnt[token]
            perplexity[token] = (
                torch.exp(torch.tensor(mean_loss)).item()
                if compute_perplexity else mean_loss
            )

    if return_tokens:
        return perplexity, total_token_cnt

    return perplexity

# Perplexity per token

In [8]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
# `perplexed` tokenizes with padding="max_length", so a pad token is set here; EOS is reused.
# Token-level counts in `get_counts` are not filtered for padding, so padded
# positions are reported under the EOS token in the results.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
# `perplexed` moves only the batches to `device`; the model is placed explicitly.
model.to("cpu")

# Keep the demo small: first 50 rows of the WikiText-2 (raw) test split.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test").select(range(50))
# filter out empty strings
dataset = dataset.filter(lambda x: len(x["text"]) > 0)

# Per-token perplexity plus the number of occurrences of each token.
perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=1,
    device="cpu",
    num_proc=1,
    return_tokens=True
)

Found cached dataset wikitext (/transformers_cache/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /transformers_cache/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-9769e73f0000d75f.arrow
Loading cached processed dataset at /transformers_cache/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0b3bb22b8be942e8.arrow


Output()

In [9]:
# Ten tokens with the highest perplexity.
perplexity_cnt.most_common(10)

[(' wired', 60982756.0),
 (' shatter', 12281804.0),
 (' Career', 4064324.75),
 (' Early', 2422916.0),
 (' Television', 2325886.25),
 (' Daylight', 2126374.75),
 (' unrecogn', 1731017.0),
 (' @', 1636278.125),
 (' Chou', 1440177.375),
 (' advisers', 1118574.375)]

In [10]:
# Ten tokens with the lowest perplexity (tail of the descending ranking).
perplexity_cnt.most_common()[-10:]

[('mers', 1.0360190868377686),
 ('mith', 1.0183565616607666),
 ('t', 1.017022728919983),
 (' than', 1.0093393325805664),
 ('jiang', 1.0054292678833008),
 ('ian', 1.0042656660079956),
 ('aire', 1.0030004978179932),
 ('el', 1.0017069578170776),
 ('ights', 1.0014889240264893),
 ('sworth', 1.0009146928787231)]

In [11]:
# Ten most frequent tokens.
token_cnt.most_common(10)

[('<|endoftext|>', 52832),
 (' the', 114),
 (',', 107),
 ('.', 83),
 (' "', 72),
 (' in', 69),
 (' of', 52),
 (' a', 44),
 (' =', 41),
 (' and', 40)]

In [12]:
# Perplexity of the ten most frequent tokens, listed from most to least frequent.
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")

'<|endoftext|>': 30566.986328125
' the': 4.492286205291748
',': 16.437183380126953
'.': 9.6357421875
' "': 9.51689338684082
' in': 7.4871673583984375
' of': 3.4485690593719482
' a': 8.229151725769043
' =': 51.0925178527832
' and': 5.262114524841309


In [13]:
# Same run as above, but with `compute_perplexity=False` the mean cross entropy
# per token is returned instead of its exponential.
cross_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=1,
    device="cpu",
    num_proc=1,
    return_tokens=True,
    compute_perplexity=False,
)

Loading cached processed dataset at /transformers_cache/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0b3bb22b8be942e8.arrow


Output()

KeyboardInterrupt: 

In [44]:
# Ten tokens with the highest mean cross entropy.
cross_cnt.most_common(10)

[(' wired', 17.926109313964844),
 (' shatter', 16.32363510131836),
 (' Career', 15.217745780944824),
 (' Early', 14.700493812561035),
 (' Television', 14.659614562988281),
 (' Daylight', 14.569916725158691),
 (' unrecogn', 14.364232063293457),
 (' @', 14.307934696024114),
 (' Chou', 14.180286407470703),
 (' advisers', 13.92755126953125)]

In [45]:
# Ten tokens with the lowest mean cross entropy (tail of the descending ranking).
cross_cnt.most_common()[-10:]

[('mers', 0.035385108552873135),
 ('mith', 0.018190178088843822),
 ('t', 0.016879817470908165),
 (' than', 0.009296108968555927),
 ('jiang', 0.005414582323282957),
 ('ian', 0.0042567127384245396),
 ('aire', 0.0029960053507238626),
 ('el', 0.001705383649095893),
 ('ights', 0.00148781668394804),
 ('sworth', 0.000914393924176693)]

In [46]:
# Mean cross entropy of the ten most frequent tokens, listed from most to least frequent.
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {cross_cnt[token]}")

'<|endoftext|>': 10.327674743941088
' the': 1.5023615921667792
',': 2.799546029940944
'.': 2.2654792655663316
' "': 2.253068277819289
' in': 2.0131916974981627
' of': 1.2379600099744634
' a': 2.1076847396113654
' =': 3.933640957242105
' and': 1.6605336494743823


# Perplexity per semantic type

The following cells calculate the perplexity per semantic type. They rely on a tokenizer that aligns the AST of a program with the BPE tokens produced by a language model's tokenizer.

In [12]:
# Use the %pip magic so the package is installed into the environment of the
# running kernel rather than whichever `pip` happens to be first on PATH.
%pip install code_tokenizers
# Shell command run after the install above; by its name it downloads the
# grammars used by the code tokenizer (assumed to ship with code_tokenizers -- confirm).
!download_grammars



In [13]:
from code_tokenizers.core import CodeTokenizer

def code_collator(batch):
    """
    Collate tokenized code examples while keeping `merged_ast` as plain Python data.

    Every field except `merged_ast` goes through `default_data_collator`.
    The per-example `merged_ast` values are gathered into a list, one entry per
    example in batch order, and attached to the collated batch unchanged.
    The input examples are left unmodified.
    """
    merged_ast = [example["merged_ast"] for example in batch]
    # Build filtered copies instead of popping the key, so the caller's dicts stay
    # intact and the collator can be called again on the same examples.
    features = [
        {key: value for key, value in example.items() if key != "merged_ast"}
        for example in batch
    ]
    collated = default_data_collator(features)
    collated["merged_ast"] = merged_ast
    return collated

model_name = "codeparrot/codeparrot-small"
py_tokenizer = CodeTokenizer.from_pretrained(model_name, "python")
# NOTE(review): presumably this node type is missing from the tokenizer's default list -- confirm.
py_tokenizer.node_types.append("as_pattern_target")
# `perplexed` tokenizes with padding, so the wrapped tokenizer gets a pad token (EOS is reused).
# NOTE(review): `get_counts` reads `tokenizer.pad_token` on the object passed to `perplexed`;
# confirm that CodeTokenizer exposes the wrapped tokenizer's pad token under that name.
py_tokenizer.tokenizer.pad_token = py_tokenizer.tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Keep the demo small: first 10 files of the validation set.
dataset = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(range(10))
# `semantic_column` makes `perplexed` also aggregate losses per `merged_ast` label;
# `code_collator` keeps that column out of the default tensor collation.
perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=py_tokenizer,
    column="content",
    semantic_column="merged_ast",
    batch_size=1,
    num_proc=1,
    device="cpu",
    collate_fn=code_collator,
    return_tokens=True
)

Using custom data configuration codeparrot--codeparrot-clean-valid-826c6fd8b27e5523
Found cached dataset json (/home/nathan/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-826c6fd8b27e5523/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


Tokenizing dataset:   0%|          | 0/10 [00:00<?, ?ex/s]

Output()

In [14]:
# Ten keys with the highest perplexity; tokens and semantic labels share this counter.
perplexity_cnt.most_common(10)

[('reports', 3101943.25),
 ('Double', 2928004.0),
 ('BLANK', 2727696.0),
 ('BD', 912393.3125),
 ('CO', 745110.5625),
 ('Pure', 499221.3125),
 ('customize', 465415.96875),
 (' inte', 430242.40625),
 (' ways', 416678.46875),
 (' filenames', 415019.5625)]

In [15]:
# Ten most frequent keys; tokens and semantic labels share this counter.
token_cnt.most_common(10)

[('<module -> comment>', 3023),
 ('<|endoftext|>', 2557),
 ('<import_from_statement -> from>', 819),
 ('< N/A >', 667),
 ('<argument_list -> string>', 575),
 ('<attribute -> identifier>', 551),
 ('<expression_statement -> string>', 489),
 ('<dotted_name -> identifier>', 463),
 ('_', 361),
 ('.', 355)]

In [16]:
# Perplexity of the ten most frequent keys (decoded tokens and semantic labels
# are stored in the same counters), listed from most to least frequent.
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")

'<module -> comment>': 402.9506530761719
'<|endoftext|>': 9540.9033203125
'<import_from_statement -> from>': 5876.1806640625
'< N/A >': 4.411830425262451
'<argument_list -> string>': 8.249506950378418
'<attribute -> identifier>': 1.6259615421295166
'<expression_statement -> string>': 9.663031578063965
'<dotted_name -> identifier>': 2.7417614459991455
'_': 1.3683314323425293
'.': 1.3909822702407837


In [1]:
#| hide
# Export the cells marked `#| export` to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()