In [None]:
import torch
import torch.nn.functional as F

import math
from tqdm import tqdm
import argparse

import json
import glob
import os

from open_lm.params import parse_args
from open_lm.model import test_perplexity_model

from transformers import GPTNeoXTokenizerFast

In [None]:
# Build the argument namespace by parsing an empty argument list (no
# command-line flags are read), then select the model configuration by name.
args = parse_args([])
args.model = "open_lm_160m"

In [None]:
############################ SET THESE VALUES #################################

# Path of the pretrained model checkpoint to be evaluated
args.classif_model_path = "pretrained_models/C4.pt"

# Device to run on: use the GPU when one is available, otherwise fall back to
# the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Directory containing the evaluation .jsonl files (one or more; every file in
# the directory is evaluated). Uncomment one of the alternatives to switch.
input_dir = "cross_dataset"
#input_dir = "paloma"
#input_dir = "wikitext_103"
 

In [None]:
# Evaluation hyper-parameters: documents per batch and the truncation length
# (in tokens) applied by the tokenizer.
batch_size = 32
max_length = 256

# Tokenizer; when it ships without a pad token, reuse EOS for padding.
tokenizer = GPTNeoXTokenizerFast.from_pretrained('EleutherAI/gpt-neox-20b')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Every .jsonl file directly inside `input_dir`, in sorted order.
jsonl_pattern = os.path.join(input_dir, "*.jsonl")
input_files = sorted(glob.glob(jsonl_pattern))

# Build the model, place it on the target device and switch to eval mode.
model = test_perplexity_model(args).to(device)
model.eval()

In [None]:
def calculate_perplexity(model, tokenizer, texts, batch_size, max_length, device):
    """Compute token-level perplexity of ``model`` over a list of documents.

    Each batch is tokenized (padded to the longest item, truncated to
    ``max_length``), run through the model, and scored with a summed
    next-token cross-entropy. Padding positions are identified with the
    tokenizer's attention mask rather than by comparing against
    ``pad_token_id``, so a real token that shares the pad id (e.g. EOS when
    the pad token is aliased to EOS) is still scored and counted.

    Args:
        model: Callable module; ``model(input_ids)[0]`` must be logits of
            shape ``(batch, seq_len, vocab)``.
        tokenizer: Callable returning an object with ``.to(device)`` and the
            keys ``"input_ids"`` and ``"attention_mask"``.
        texts: Sequence of dicts, each with a ``"text"`` entry.
        batch_size: Number of documents per forward pass.
        max_length: Maximum number of tokens kept per document.
        device: Device the model and the encoded batch are moved to.

    Returns:
        Tuple ``(perplexity, total_log_likelihood, total_tokens)``. When no
        token could be scored the result is ``(inf, 0.0, 0)``.
    """
    # Dedicated sentinel for positions excluded from the loss; it can never
    # collide with a real vocabulary id, unlike the pad token id.
    ignore_index = -100

    model.to(device).eval()
    total_log_likelihood = 0.0
    total_tokens = 0

    loss_fct = torch.nn.CrossEntropyLoss(
        ignore_index=ignore_index,
        reduction='sum'
    )

    for i in tqdm(range(0, len(texts), batch_size), desc="Batches"):
        batch = texts[i : i + batch_size]
        batch_texts = [t["text"] for t in batch]

        enc = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        # Fewer than two positions means there is no next token to predict
        # (and a zero-width batch cannot be indexed at all), so skip it.
        if input_ids.size(1) < 2:
            continue

        # Shift left by one so position t holds the token at t + 1; the last
        # position has no successor and is always ignored.
        targets = input_ids.clone()
        targets[:, :-1] = input_ids[:, 1:]
        targets[:, -1] = ignore_index

        # Shift the attention mask the same way and drop every target that
        # falls on a padding position.
        target_mask = attention_mask.clone()
        target_mask[:, :-1] = attention_mask[:, 1:]
        target_mask[:, -1] = 0
        targets[target_mask == 0] = ignore_index

        with torch.no_grad():
            logits = model(input_ids)[0]

        B, L, V = logits.size()
        logits_flat = logits.view(-1, V)
        targets_flat = targets.view(-1)

        loss_sum = loss_fct(logits_flat, targets_flat).item()
        scored = (targets_flat != ignore_index).sum().item()

        # The summed cross-entropy is the negative log-likelihood.
        total_log_likelihood += -loss_sum
        total_tokens += scored

    if total_tokens == 0:
        return float("inf"), 0.0, 0

    ppl = math.exp(-total_log_likelihood / total_tokens)
    return ppl, total_log_likelihood, total_tokens

# Compute perplexity for each evaluation file, then aggregate the
# log-likelihoods and token counts into one overall perplexity.
overall_ll = 0.0
overall_tok = 0

print("Perplexity by file:")
for fname in input_files:
    # Decode explicitly as UTF-8 so reading does not depend on the platform's
    # default encoding, and skip blank lines, which json.loads cannot parse.
    with open(fname, 'r', encoding='utf-8') as f:
        texts = [json.loads(line) for line in f if line.strip()]

    ppl, ll, tok = calculate_perplexity(
        model, tokenizer, texts,
        batch_size=batch_size,
        max_length=max_length,
        device=device
    )
    print(f"  {os.path.basename(fname)}: {ppl:.2f}")
    overall_ll += ll
    overall_tok += tok

# The overall figure is token-weighted: it is computed from the summed
# log-likelihood and token count, not by averaging per-file perplexities.
if overall_tok > 0:
    overall_ppl = math.exp(-overall_ll / overall_tok)
    print(f"\nOverall Perplexity: {overall_ppl:.2f}")
else:
    print("\nNo tokens processed; cannot compute overall perplexity.")

print(f"Total tokens: {overall_tok:,}")
