In [20]:
from transformers import AutoTokenizer, AutoModelForCausalLM


device = "mps"
model_id = "openai-community/gpt2-large"
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [21]:
from datasets import load_dataset


test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors


In [22]:
import torch
from tqdm import tqdm


max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

max_length, stride, seq_len

(1024, 512, 287644)

In [23]:
# Expected number of iterations
seq_len // 512

561

In [24]:
# Verify
for begin_loc in tqdm(range(0, seq_len, stride)):
    pass

100%|██████████| 562/562 [00:00<00:00, 4357114.32it/s]


In [25]:
# Calculate Perplexity
nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc+max_length, seq_len)
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)

    target_ids = input_ids.clone()
    target_len = end_loc - prev_end_loc
    target_ids[:, :-target_len] = -100  # first tokens of size stride will be -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over target_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())

100%|█████████▉| 560/562 [07:37<00:01,  1.22it/s]


In [27]:
print(f"Perpexity of GPT2-Large: {ppl}")
# https://huggingface.co/docs/transformers/en/perplexity

Perpexity of GPT2-Large: 16.45410919189453
