In [9]:
from datasets import load_dataset
import transformers
import torch
from tqdm import tqdm
import math
import re

In [10]:
# Load the tokenizer and the causal LM from one checkpoint name so the two cannot drift apart
MODEL_NAME = 'gpt2'
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Maximum sequence length reported by the tokenizer; used below as the window size
max_length = tokenizer.model_max_length
# Token id used as the single context token in front of the first window
bos_token_id = tokenizer.bos_token_id


In [11]:
# Fetch the test split of the document-level WikiText-2 (raw) dataset from the Hugging Face Hub
dataset = load_dataset(
    path="EleutherAI/wikitext_document_level",
    name="wikitext-2-raw-v1",
    split="test",
)

In [12]:
# Look at the raw text of the first document in the test split (each record stores its text under 'page')
sample_data = dataset[0]['page']
print(sample_data)

 = Robert Boulter = 
 
 Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as " Craig " in the episode " Teddy 's Story " of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the Menier Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . 
 In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 2006

In [13]:
# Return a list of tuples that are (context, continuation)
# This function specifies the context as 1 token at the start of each window
# Note: for the last window, lm-eval-harness appears to assign the "extra space" (left over when the last window is shorter than max_length)
# to additional context tokens. In this exercise we follow the simpler pattern of having only one context token per window.
def build_rolling_requests(encoded_tokens, max_length, bos_id):
  """Split a token sequence into (context, continuation) windows for rolling scoring.

  Every window carries exactly one context token: `bos_id` for the first
  window, and the last token of the previous window for every later one.
  Each continuation holds at most `max_length` tokens, and the continuations
  concatenated in order reproduce `encoded_tokens`.

  Args:
    encoded_tokens: Sequence of token ids for one document.
    max_length: Maximum number of tokens to predict per window (must be >= 1).
    bos_id: Token id used as the context of the first window.

  Returns:
    A list of `(context, continuation)` tuples, where `context` is a
    one-element list and `continuation` is a slice of `encoded_tokens`.
    An empty `encoded_tokens` yields an empty list.

  Raises:
    ValueError: If `max_length` is smaller than 1.
  """
  if max_length < 1:
    # A non-positive window size never advances through the tokens,
    # so without this guard the windowing would never terminate.
    raise ValueError(f"max_length must be >= 1, got {max_length}")

  requests = []
  # Step through the tokens one window at a time; slicing past the end simply
  # truncates, so the last window is automatically the shorter remainder.
  for start in range(0, len(encoded_tokens), max_length):
    if start == 0:
      # Nothing precedes the first window, so condition it on the BOS token
      context = [bos_id]
    else:
      # Condition on the last token that the previous window already predicted
      context = [encoded_tokens[start - 1]]
    requests.append((context, encoded_tokens[start:start + max_length]))
  return requests



In [14]:
# Run everything on the CPU in this notebook
device = torch.device('cpu')
# Move the model's weights onto the selected device
model.to(device)
def get_log_likelihood(model, requests):
  """Score every (context, continuation) request and total the log-likelihoods.

  For each request the model is fed `context + continuation` without its last
  token, and the log-probabilities assigned to the continuation tokens are
  summed. Requests with an empty continuation contribute 0.0 and are not sent
  to the model.

  Args:
    model: A causal language model that is callable on a `[1, seq_len]` tensor
      of token ids and returns an object with a `.logits` attribute of shape
      `[batch_size, sequence_length, vocab]`. Expected to be a torch module
      (exposing `.device` or `.parameters()`).
    requests: Iterable of `(context, continuation)` pairs of token-id lists.

  Returns:
    A tuple `(summed_results, results)`: the total log-likelihood over all
    requests and the list of per-request log-likelihoods.
  """
  # Build tensors on the device the model actually lives on instead of relying
  # on a notebook-level global, so the function works wherever the model is placed.
  model_device = getattr(model, "device", None)
  if model_device is None:
    first_param = next(iter(model.parameters()), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

  results = []
  summed_results = 0
  for context, continuation in tqdm(requests):
    if len(continuation) == 0:
      # Nothing to predict: the input would be empty and a `-0:` slice would
      # select the whole sequence, so score the request as log(1) = 0 directly.
      results.append(0.0)
      continue

    # Drop the final token: the model predicts token t+1 from tokens <= t,
    # so the last continuation token is only ever a target, never an input.
    inp = torch.tensor(
        (context + continuation)[:-1],
        dtype=torch.long,
        device=model_device,
    ).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
      logits = model(inp).logits
    # Normalize over the vocab dimension to get log-probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    cont_toks = torch.tensor(continuation, dtype=torch.long, device=model_device).unsqueeze(0)
    # Keep only the trailing positions, one per continuation token
    log_probs_for_cont = log_probs[:, -cont_toks.shape[1]:, :]

    # Pick out the log-probability of each actual continuation token
    compare_log_probs = torch.gather(log_probs_for_cont, 2, cont_toks.unsqueeze(-1)).squeeze(-1)

    # Sum rather than multiply because log(a*b) = log(a) + log(b)
    log_likelihood = float(compare_log_probs.sum())
    summed_results += log_likelihood
    results.append(log_likelihood)
  return (summed_results, results)

In [15]:
def weighted_perplexity(a, b):
    """Return exp(-(sum(a) / sum(b))).

    In this notebook `a` holds per-page log-likelihoods and `b` holds the
    matching per-page unit counts (words or bytes), giving the perplexity
    per unit over all pages.
    """
    total_log_likelihood = sum(a)
    total_units = sum(b)
    return math.exp(-(total_log_likelihood / total_units))
def get_bits_per_byte(a, b):
    """Return -(sum(a) / sum(b)) converted from nats to bits.

    In this notebook `a` holds per-page log-likelihoods (natural log) and `b`
    the per-page byte counts; dividing by ln(2) changes the unit to bits.
    """
    nats_per_unit = -(sum(a) / sum(b))
    return nats_per_unit / math.log(2)
# Word/byte counting follows the process_results logic copied from the LM harness
def process_page(doc):
    """Return (word_count, byte_count) for the text stored under doc["page"].

    Words are the pieces produced by splitting on runs of whitespace (so
    leading or trailing whitespace contributes an empty piece that is still
    counted); bytes are the length of the UTF-8 encoding of the text.
    """
    text = doc["page"]
    word_count = len(re.split(r"\s+", text))
    byte_count = len(text.encode("utf-8"))
    return (word_count, byte_count)


In [16]:
# Evaluate only the first LIMIT documents of the test split
LIMIT = 2
limited_dataset = dataset.select(range(LIMIT))
# Collects one (log-likelihood, word count, byte count) tuple per page
results = []

for i in limited_dataset:
  page = i['page']
  # Tokenize the whole page at once, then cut it into windows of at most max_length predicted tokens.
  # Encoding a full page can trigger the tokenizer's "longer than the specified maximum sequence length"
  # warning; only the windows built here are passed to the model.
  requests = build_rolling_requests(tokenizer.encode(page), max_length, bos_token_id)
  # Only the page-level total is used; the per-window values are discarded
  page_loglikelihood, _ = get_log_likelihood(model, requests)
  page_words, page_bytes = process_page(i)
  results.append((page_loglikelihood, page_words, page_bytes))

# Transpose the per-page tuples into three parallel sequences
all_log_likelihoods, all_words, all_bytes = zip(*results)
# Same summed log-likelihood, normalized by total words or by total bytes
word_perplexity = weighted_perplexity(all_log_likelihoods, all_words)
byte_perplexity = weighted_perplexity(all_log_likelihoods, all_bytes)
bits_per_byte = get_bits_per_byte(all_log_likelihoods, all_bytes)
print(f"\n\nWord Perplexity: {word_perplexity}")
print(f"\nByte Perplexity: {byte_perplexity}")
print(f"\nBits Per Byte: {bits_per_byte}")

Token indices sequence length is longer than the specified maximum sequence length for this model (1299 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 2/2 [00:13<00:00,  6.54s/it]
100%|██████████| 6/6 [00:40<00:00,  6.70s/it]



Word Perplexity: 55.790871197851814

Byte Perplexity: 2.17720336666234

Bits Per Byte: 1.1224761720516812



