In [None]:
import re
import torch
import numpy as np
from datasets import load_dataset
# from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

def wikitext_detokenizer(text):
    """Undo WikiText's space-separated tokenization and return plain text.

    The rewrite rules are applied strictly in the order listed below; several
    of them interact (e.g. the heading rules go from longest to shortest), so
    the order must not be changed. The original author intended this to mirror
    the detokenizer used by lm-eval for WikiText.

    Args:
        text: Raw WikiText page/string.

    Returns:
        The detokenized string.
    """
    # Contractions.
    out = text.replace("s '", "s'")
    # NOTE(review): the slashes here are literal characters and the
    # replacement inserts the literal text "[0-9]"; kept as-is on purpose so
    # the output stays identical to the reference behaviour.
    out = re.sub(r"/' [0-9]/", r"/'[0-9]/", out)

    # Number separators (" @-@ ", " @,@ ", " @.@ ").
    for marker, glyph in ((" @-@ ", "-"), (" @,@ ", ","), (" @.@ ", ".")):
        out = out.replace(marker, glyph)

    # Punctuation: drop the space before the mark, keep the one after it.
    for mark in ":;.!?,":
        out = out.replace(f" {mark} ", f"{mark} ")

    # Paired delimiters: strip padding just inside brackets and quotes.
    for pattern, repl in (
        (r"\(\s*([^\)]*?)\s*\)", r"(\1)"),
        (r"\[\s*([^\]]*?)\s*\]", r"[\1]"),
        (r"{\s*([^}]*?)\s*}", r"{\1}"),
        (r"\"\s*([^\"]*?)\s*\"", r'"\1"'),
        (r"'\s*([^']*?)\s*'", r"'\1'"),
    ):
        out = re.sub(pattern, repl, out)

    # Miscellaneous: headings, degree sign, line-edge spaces, " N ", possessive.
    for old, new in (
        ("= = = =", "===="),
        ("= = =", "==="),
        ("= =", "=="),
        (" " + chr(176) + " ", chr(176)),
        (" \n", "\n"),
        ("\n ", "\n"),
        (" N ", " 1 "),
        (" 's", "'s"),
    ):
        out = out.replace(old, new)

    return out

def _scoring_windows(tokens, max_length):
    """Yield windows of at most ``max_length`` tokens that overlap by one token.

    A window only scores positions 1..N-1 (position 0 is pure context), so
    each window starts on the LAST token of the previous one. That way every
    token after the very first is scored exactly once and no token is lost at
    a window boundary. Every yielded window has at least two tokens.
    """
    for start in range(0, len(tokens) - 1, max_length - 1):
        yield tokens[start : start + max_length]


def compute_perplexity(model, tokenizer, dataset_split='test', max_length=2048):
    """
    Compute word-level perplexity on document-level WikiText-2.

    Each document is detokenized, tokenized, prefixed with the tokenizer's EOS
    token and scored in windows of at most ``max_length`` tokens. Consecutive
    windows share exactly one token (used only as context), so the prediction
    targets are non-overlapping and every document token is scored exactly
    once. The summed log-likelihood is normalized by the number of
    whitespace-separated words in the ORIGINAL (non-detokenized) text.

    This is intended to approximate lm-evaluation-harness's WikiText word
    perplexity (rolling log-likelihood with a one-token context overlap).

    Args:
        model: HuggingFace causal LM (call model.eval() and model.to(device)
            beforehand); its output must expose ``.logits``.
        tokenizer: HuggingFace tokenizer providing ``encode`` and
            ``eos_token_id``.
        dataset_split: 'test', 'validation', or 'train'.
        max_length: Maximum number of tokens fed to the model per forward
            pass (the model's context window). Must be >= 2.

    Returns:
        Word perplexity: ``exp(-total_loglikelihood / total_words)``.

    Raises:
        ValueError: If ``max_length`` is smaller than 2 (a window needs one
            context token and at least one target token).
    """
    if max_length < 2:
        raise ValueError("max_length must be at least 2")

    device = next(model.parameters()).device

    # Load WikiText dataset
    dataset = load_dataset("EleutherAI/wikitext_document_level",
                          "wikitext-2-raw-v1",
                          split=dataset_split)

    total_loglikelihood = 0.0
    total_words = 0

    print(f"Computing perplexity on {len(dataset)} documents using disjoint windows...")

    for doc in tqdm(dataset):
        original_text = doc['page']
        detokenized_text = wikitext_detokenizer(original_text)

        # Word count is taken on the ORIGINAL text; the model sees the
        # DETOKENIZED text.
        words = len(re.split(r"\s+", original_text))
        tokens = tokenizer.encode(detokenized_text, add_special_tokens=False)

        # Documents that tokenize to nothing contribute neither words nor
        # log-likelihood.
        if not tokens:
            continue

        # EOS acts as the context for predicting the first real token.
        tokens = [tokenizer.eos_token_id] + tokens

        doc_loglikelihood = 0.0
        for window_tokens in _scoring_windows(tokens, max_length):
            input_ids = torch.tensor([window_tokens], device=device)

            with torch.no_grad():
                logits = model(input_ids).logits

            # Shift: predict tokens 1..N-1 given tokens 0..N-2.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)

            # Pick the log-probability assigned to each actual next token.
            token_log_probs = torch.gather(
                log_probs,
                2,
                shift_labels.unsqueeze(-1)
            ).squeeze(-1)

            # Sum in float32 so half-precision models do not round the
            # (large-magnitude) window total.
            doc_loglikelihood += token_log_probs.float().sum().item()

        total_loglikelihood += doc_loglikelihood
        total_words += words

    word_perplexity = np.exp(-total_loglikelihood / total_words)

    return word_perplexity

# Example usage
# if __name__ == "__main__":
#     # Load your model
#     model_name = "gpt2"  # Replace with your model
#     print(f"Loading model: {model_name}")
    
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModelForCausalLM.from_pretrained(model_name)
    
#     model = model.to("cuda")
#     model.eval()
    
#     results = compute_perplexity(model, tokenizer, dataset_split='test')
    
#     print("\n" + "="*50)
#     print("WikiText-2 Perplexity Results")
#     print(f"Word Perplexity: {results:.2f}")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset

# Document-level WikiText-2 (raw), requested with split="all".
dataset = load_dataset(
    "EleutherAI/wikitext_document_level",
    "wikitext-2-raw-v1",
    split="all",
)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# NOTE: this rebinds `dataset` to MMLU's auxiliary_train split, replacing
# whatever `dataset` referred to before this cell ran.
dataset = load_dataset(
    "cais/mmlu",
    "all",
    split="auxiliary_train",
)

In [9]:
# Display a single record (index 100) of whatever `dataset` is currently bound to.
dataset[100]

{'question': 'Storekeeper, the owner of a large hardware store, sells power saws for both personal and commercial use. He often takes old power saws as trade-ins on new ones. The old power saws are then completely disassembled and rebuilt with new bearings by Storekeeper\'s employees and sold by Storekeeper as "reconditioned saws." Purchaser, the owner and operator of a cabinetmaking shop, informed Storekeeper that he wanted to buy a reconditioned circular saw for use in his cabinetmaking business. However, the blade that was on the saw he picked out had very coarse teeth for cutting rough lumber. Purchaser told Storekeeper that he wanted a saw blade that would cut plywood. Storekeeper exchanged the coarse blade for a new one with finer teeth that would cut plywood smoothly. The new blade was manufactured by Saw-Blade Company, which uses all available techniques to inspect its products for defects. The reconditioned saw had been manufactured by Power Saw Company. The week after the saw

In [40]:
# Collect one (original_text, detokenized_text, word_count) triple per page.
# Each record is read via doc['page'], so `dataset` must hold WikiText pages
# when this cell runs.
pairs = []
for doc in tqdm(dataset):
    original_text = doc['page']
    pairs.append((
        original_text,
        wikitext_detokenizer(original_text),
        # Word count uses the ORIGINAL text, split on whitespace runs.
        len(re.split(r"\s+", original_text)),
    ))

100%|██████████| 62/62 [00:00<00:00, 478.26it/s]


In [31]:
# Word count (third field of the triple) for one selected page.
page = 1
page_word_count = pairs[page][2]
print(page_word_count)

3934


In [34]:
# Mean / max / min of the per-page word counts stored as the third field of
# each entry in `pairs`.
# Uses the builtins instead of a hand-rolled accumulator named `sum`, which
# would shadow the builtin `sum` for every later cell in the kernel session.
word_counts = [entry[2] for entry in pairs]
mean_words = sum(word_counts) / len(word_counts)
max_words = max(word_counts)
min_words = min(word_counts)
print(mean_words)
print(max_words)
print(min_words)

3264.1780604133546
17708
10
