# core

> This module contains the core functions for calculating the perplexity of a language model per token

In [None]:
# | default_exp core

In [None]:
# | export
import datasets
import os
import torch
import transformers

import torch.nn.functional as F

from collections import Counter, defaultdict
from rich.progress import track
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import default_data_collator

datasets.logging.set_verbosity_error()

In [None]:
# | hide
from datasets import load_dataset
from nbdev.showdoc import *
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# | export
def loss_func(
    logits,  # the model's output
    labels,  # the labels to calculate the cross entropy loss against
):  # the loss per token of shape (batch_size, seq_len - 1)
    """
    Compute the unreduced next-token cross entropy between `logits` and `labels`.

    The logits at position `t` are scored against the label at position `t + 1`,
    so the result has one entry fewer along the sequence axis than `labels`.
    """
    # The last prediction has no following label and the first label has no
    # preceding prediction; dropping both lines predictions up with targets.
    predictions = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    # Flatten to (num_tokens, vocab) / (num_tokens,) for the loss, then restore
    # the original leading dimensions so every token keeps its own loss value.
    vocab_size = predictions.size(-1)
    flat_loss = F.cross_entropy(
        predictions.view(-1, vocab_size), targets.view(-1), reduction="none"
    )
    return flat_loss.view(*targets.size())

In [None]:
# test loss function
# Run GPT-2 on two identical sentences and compute the per-token loss, using
# the input ids themselves as the labels.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(
    ["Hello, my dog is cute", "Hello, my dog is cute"], return_tensors="pt"
)
outputs = model(**inputs)
logits = outputs.logits
labels = inputs.input_ids
# One row per sentence; the rows should match because the inputs are identical.
loss_func(logits, labels)

tensor([[2.3432, 3.7964, 6.6038, 1.7265, 5.4809],
        [2.3432, 3.7964, 6.6038, 1.7265, 5.4809]], grad_fn=<ViewBackward0>)

In [None]:
# | export
def get_counts(
    model,  # the model to use for predictions
    tokenizer,  # the tokenizer to use for encoding
    batch,  # the batch to use for predictions
    semantic_column: str,  # the column to use for semantic predictions
    stop_word_column: str,  # the column to use for stop word predictions
    return_distributions: bool,  # whether to return the distributions
):  # the counts for the losses and tokens
    """
    Accumulate the per-token losses and occurrence counts for one batch.

    Returns `(loss_cnt, token_cnt)`. Keys are decoded token strings and, when
    `semantic_column` is given, the semantic labels read from that column.
    `loss_cnt` maps each key to the summed loss (a `Counter`) or, when
    `return_distributions` is true, to the list of individual losses.
    `token_cnt` maps each key to the number of losses recorded for it.

    Padding positions are not filtered out here: they are counted under the
    decoded pad token, but never under a semantic label.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    # Only the logits are needed, so `labels` is not passed: the model would
    # otherwise compute a reduced loss that is never read.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)

    # Move the losses and token ids to Python objects once, instead of paying
    # for one tensor-to-scalar conversion per token inside the loop.
    losses = loss_func(outputs.logits, input_ids).tolist()
    next_token_ids = input_ids[:, 1:].tolist()

    loss_cnt = defaultdict(list) if return_distributions else Counter()
    token_cnt = Counter()

    def record(key, value):
        # Append to the distribution or add to the running sum, then count it.
        if return_distributions:
            loss_cnt[key].append(value)
        else:
            loss_cnt[key] += value
        token_cnt[key] += 1

    # NOTE(review): `j` indexes the ids with the first token dropped, so the
    # decoded token sits at position `j + 1` of the input, while the stop-word
    # and semantic columns are read at position `j`. If those columns are
    # aligned one-to-one with `input_ids`, this reads the annotation of the
    # preceding token -- confirm the column layout before changing the index.
    for i, row in enumerate(next_token_ids):
        for j, token_id in enumerate(row):
            # Skip the stop words
            if stop_word_column is not None and batch[stop_word_column][i][j]:
                continue

            token = tokenizer.decode(token_id)
            value = losses[i][j]
            record(token, value)

            # Padding carries no semantic label, so it is only counted as a token.
            if semantic_column is not None and token != tokenizer.pad_token:
                record(batch[semantic_column][i][j], value)

    return loss_cnt, token_cnt

In [None]:
# | export
def perplexed(
    model: transformers.PreTrainedModel,  # The model to calculate the perplexity of.
    dataset: datasets.Dataset,  # The dataset to calculate the perplexity on.
    tokenizer: transformers.PreTrainedTokenizer = None,  # The tokenizer to use to tokenize the dataset. If not provided, the tokenizer associated with the model will be used.
    column: str = "text",  # The column of the dataset to calculate the perplexity on.
    semantic_column: str = None,  # The column of the dataset to calculate the semantic perplexity on such as NER tags.
    stop_word_column: str = None,  # The column of the dataset that contains boolean values indicating whether the token is a stop word.
    n_gram: int = 1,  # The n-gram to calculate the perplexity on.
    batch_size: int = 1,  # The batch size to use when calculating the perplexity.
    num_proc: int = os.cpu_count(),  # The number of processes to use when tokenizing the dataset.
    device: str = "cuda",  # The device to use when calculating the perplexity.
    collate_fn=default_data_collator,  # The collate function to use when calculating the perplexity.
    pass_row: bool = False,  # Whether to pass the row to the tokenizer.
    return_tokens: bool = False,  # Whether to return the tokens counts along with the perplexity.
    return_distributions: bool = False,  # Whether to return the perplexity distributions instead of the perplexity.
    compute_perplexity: bool = True,  # Whether to compute the perplexity. If False, the cross entropy will be returned instead.
):  # The perplexity of the model on the dataset or a tuple of the perplexity and the token counts.
    """
    Calculate the perplexity of a model on a dataset.

    The dataset is tokenized (truncated and padded to the maximum length), run
    through the model batch by batch, and the per-token losses are grouped by
    decoded token (and by semantic label when `semantic_column` is given).
    Each key maps to `exp(mean loss)`, or to the mean loss itself when
    `compute_perplexity` is false. With `return_distributions` each key maps
    to the list of individual values instead of their mean.

    `n_gram` is accepted for interface compatibility but is currently unused.
    The model is not moved to `device`; only the batches are.

    Raises `ValueError` when no tokenizer is given and the model config does
    not record the checkpoint name it was loaded from.
    """
    if tokenizer is None:
        # `config.tokenizer_class` is a class *name* (or None), not a class, and
        # the config has no `pretrained_model_name_or_path` attribute. Resolve
        # the tokenizer from the checkpoint name stored on the config instead.
        name_or_path = getattr(model.config, "name_or_path", "")
        if not name_or_path:
            raise ValueError(
                "Cannot infer a tokenizer from the model config; "
                "pass `tokenizer` explicitly."
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path)

    # Tokenize the dataset
    def tokenize(row):
        # Either hand the whole row to the tokenizer or only the text column.
        text = row if pass_row else row[column]
        return tokenizer(text, truncation=True, padding="max_length")

    tokenized_dataset = dataset.map(
        tokenize,
        batched=batch_size > 1,
        batch_size=batch_size,
        remove_columns=dataset.column_names,
        num_proc=num_proc,
        desc="Tokenizing dataset",
    )

    # Create a dataloader for the dataset
    dataloader = DataLoader(
        tokenized_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    # Accumulate the losses and counts over all batches
    total_loss_cnt = defaultdict(list) if return_distributions else Counter()
    total_token_cnt = Counter()
    for batch in track(dataloader, description="Calculating perplexity"):
        # Move the tensors to the device; non-tensor entries (for example lists
        # of semantic labels) are passed through unchanged.
        batch = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        loss_cnt, token_cnt = get_counts(
            model,
            tokenizer,
            batch,
            semantic_column,
            stop_word_column,
            return_distributions,
        )
        # `+=` extends the list in distribution mode and adds to the sum otherwise.
        for token, loss in loss_cnt.items():
            total_loss_cnt[token] += loss
        total_token_cnt.update(token_cnt)

    # Turn the accumulated losses into the requested statistic
    perplexity = defaultdict(list) if return_distributions else Counter()
    for token, loss in total_loss_cnt.items():
        # Distributions keep every individual loss; otherwise use the mean.
        value = loss if return_distributions else loss / total_token_cnt[token]
        if compute_perplexity:
            # `tolist()` yields a float for a scalar and a list for a distribution.
            value = torch.exp(torch.tensor(value)).tolist()
        perplexity[token] = value

    if return_tokens:
        return perplexity, total_token_cnt

    return perplexity

# Perplexity per token

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
# perplexed tokenizes with padding="max_length", so a pad token must be set;
# the end-of-text token is reused for that purpose.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
model.to(device)

# Keep the demo small: only the first 50 rows of the wikitext-2 test split.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test").select(range(50))
# filter out empty strings
dataset = dataset.filter(lambda x: len(x["text"]) > 0)

# Per-token perplexity together with the number of occurrences of each token.
perplexity_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=1,
    device=device,
    num_proc=1,
    return_tokens=True,
)
# Every key that has a perplexity must also have a count, and vice versa.
assert len(perplexity_cnt) == len(token_cnt)
assert perplexity_cnt.keys() == token_cnt.keys()

Output()

In [None]:
# Same run as above, but with compute_perplexity=False the values are the mean
# cross entropy per token instead of the perplexity.
cross_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=1,
    device=device,
    num_proc=1,
    return_tokens=True,
    compute_perplexity=False,
)
# Every key that has a cross entropy must also have a count, and vice versa.
assert len(cross_cnt) == len(token_cnt)
assert cross_cnt.keys() == token_cnt.keys()

Output()

In [None]:
# The ten tokens with the highest mean cross entropy.
cross_cnt.most_common(10)

[(' wired', 17.92612648010254),
 (' shatter', 16.32363510131836),
 (' Career', 15.21772575378418),
 (' Early', 14.70047664642334),
 (' Television', 14.659582138061523),
 (' Daylight', 14.56997299194336),
 (' unrecogn', 14.364179611206055),
 (' @', 14.307954322208058),
 (' Chou', 14.180266380310059),
 (' advisers', 13.927596092224121)]

In [None]:
# The ten tokens with the lowest mean cross entropy (tail of the sorted list).
cross_cnt.most_common()[-10:]

[('mers', 0.03539723251014948),
 ('mith', 0.018193976022303104),
 ('t', 0.016906073316931725),
 (' than', 0.009314415045082569),
 ('jiang', 0.005416479427367449),
 ('ian', 0.004262291360646486),
 ('aire', 0.002999095479026437),
 ('el', 0.0017088347813114524),
 ('ights', 0.001490435330197215),
 ('sworth', 0.0009158230968751013)]

In [None]:
# cross entropy of the most common tokens
# (tokens are ranked by how often they occur, not by their cross entropy)
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {cross_cnt[token]}")

'<|endoftext|>': 10.327683209001043
' the': 1.5023754525995046
',': 2.799564078589466
'.': 2.2654987903962653
' "': 2.2530801612883806
' in': 2.0132113315057065
' of': 1.2379778898500193
' a': 2.107695746828209
' =': 3.9336307379530697
' and': 1.6605487003922463


# Perplexity per semantic type

The following cells calculate the perplexity per semantic type, using a tokenizer that aligns the AST of a program with the BPE tokens of a language model's tokenizer.

In [None]:
!pip install -U code_tokenizers
!download_grammars

In [None]:
from code_tokenizers.core import CodeTokenizer

def code_collator(batch):
    """
    Collate a list of examples while keeping `merged_ast` as plain Python lists.

    The `merged_ast` entry is removed from every example before the remaining
    fields are handed to `default_data_collator`, and is re-attached to the
    collated batch afterwards as a list with one entry per example.
    """
    # Note: `pop` removes the key from the example dicts passed in.
    ast_per_example = [example.pop("merged_ast") for example in batch]
    collated = default_data_collator(batch)
    collated["merged_ast"] = ast_per_example
    return collated

In [None]:
model_name = "codeparrot/codeparrot-small"
py_tokenizer = CodeTokenizer.from_pretrained(model_name, "python")
# Reuse the end-of-text token for padding on the wrapped tokenizer, and mirror
# it on the wrapper because get_counts reads `tokenizer.pad_token` directly.
py_tokenizer.tokenizer.pad_token = py_tokenizer.tokenizer.eos_token
py_tokenizer.pad_token = py_tokenizer.tokenizer.pad_token
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

# Keep the demo small: only the first 15 files of the dataset.
dataset = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(
    range(15)
)
# Mean cross entropy per token and per semantic label ("merged_ast"), skipping
# positions flagged in "is_builtins". code_collator keeps "merged_ast" as lists.
cross_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=py_tokenizer,
    column="content",
    semantic_column="merged_ast",
    stop_word_column="is_builtins",
    batch_size=1,
    num_proc=1,
    device=device,
    collate_fn=code_collator,
    return_tokens=True,
    compute_perplexity=False,
)

# Every key that has a cross entropy must also have a count, and vice versa.
assert len(cross_cnt) == len(token_cnt)
assert cross_cnt.keys() == token_cnt.keys()

Output()

In [None]:
# The ten keys (tokens or semantic labels) with the highest mean cross entropy.
cross_cnt.most_common(10)

[('reports', 15.318881034851074),
 ('Double', 15.236268043518066),
 ('BLANK', 15.137480735778809),
 ('148', 14.469829559326172),
 ('BD', 13.819499969482422),
 ('year', 13.65689468383789),
 (' filesystem', 13.625283241271973),
 ('CO', 13.59871768951416),
 ('Pure', 13.172009468078613),
 ('customize', 13.098344802856445)]

In [None]:
# The ten most frequent keys; tokens and semantic labels share this counter.
token_cnt.most_common(10)

[('<|endoftext|>', 3951),
 ('<module -> comment>', 1479),
 ('< N/A >', 1123),
 ('<attribute -> identifier>', 1019),
 ('<argument_list -> string>', 728),
 ('<expression_statement -> string>', 677),
 ('.', 608),
 ('<dotted_name -> identifier>', 608),
 ('_', 434),
 ('\n', 391)]

In [None]:
# cross entropy of the most common tokens and semantic types
# The run above used compute_perplexity=False, so its values live in
# `cross_cnt`. `perplexity_cnt` still holds the earlier wikitext run and,
# being a Counter, silently reports 0 for every key it has never seen.
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {cross_cnt[token]}")

'<|endoftext|>': 30567.21875
'<module -> comment>': 0
'< N/A >': 0
'<attribute -> identifier>': 0
'<argument_list -> string>': 0
'<expression_statement -> string>': 0
'.': 9.635930061340332
'<dotted_name -> identifier>': 0
'_': 0
'
': 3.0456223487854004


In [None]:
# | hide
import nbdev

# Run nbdev's export step for this project.
nbdev.nbdev_export()