# Perplexity Evaluation for Pythia-70M with KVPress

This notebook evaluates the perplexity of the Pythia-70M model on Wikitext-2 and PG19 datasets, both with and without KV cache compression using KnormPress.

In [15]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_from_disk
from tqdm import tqdm
import os
import math
import sys
import glob

# Add the parent directory to sys.path to allow importing from accelerated_inference
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))

from accelerated_inference.kvpress.presses.knorm_press import KnormPress
from accelerated_inference.attention.patch import patch_attention_functions

In [16]:
MODEL_NAME = "EleutherAI/pythia-70m"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATASET_DIR = "../dataset"

print(f"Loading model: {MODEL_NAME} on {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

# Apply attention patch
patch_attention_functions()

Loading model: EleutherAI/pythia-70m on cuda


In [17]:
def calculate_ppl(model, tokenizer, text, seq_len=2048, stride=512, device="cuda"):
    encodings = tokenizer(text, return_tensors="pt")
    max_length = model.config.max_position_embeddings
    seq_len = min(seq_len, max_length)
    
    nlls = []
    prev_end_loc = 0
    
    for begin_loc in tqdm(range(0, encodings.input_ids.size(1), stride), desc="Calculating PPL"):
        end_loc = min(begin_loc + seq_len, encodings.input_ids.size(1))
        trg_len = end_loc - prev_end_loc
        
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            nlls.append(outputs.loss)

        prev_end_loc = end_loc
        if end_loc == encodings.input_ids.size(1):
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl.item()

## Baseline Evaluation (No Compression)

In [18]:
wikitext_path = os.path.join(DATASET_DIR, "wikitext-2-raw-v1")
if os.path.exists(wikitext_path):
    print("Loading Wikitext-2 from local disk...")
    data = load_from_disk(wikitext_path)
    text = "\n\n".join(data["text"])
    print(f"Total text length: {len(text)} chars")
    print("Evaluating Baseline on Wikitext-2...")
    ppl = calculate_ppl(model, tokenizer, text, device=DEVICE)
    print(f"Baseline Wikitext PPL: {ppl:.2f}")
else:
    print(f"Wikitext dataset not found at {wikitext_path}")

Loading Wikitext-2 from local disk...
Total text length: 1294336 chars
Evaluating Baseline on Wikitext-2...


Calculating PPL:  99%|█████████▉| 560/564 [00:23<00:00, 23.85it/s]

Baseline Wikitext PPL: 39.43





## Evaluation with KnormPress (Compression)

In [19]:
compression_ratio = 0.2
print(f"Evaluating with KnormPress (Compression Ratio: {compression_ratio})...")

press = KnormPress(compression_ratio=compression_ratio)

if os.path.exists(wikitext_path):
    with press(model):
        ppl = calculate_ppl(model, tokenizer, text, device=DEVICE)
    print(f"KnormPress Wikitext PPL: {ppl:.2f}")

Evaluating with KnormPress (Compression Ratio: 0.2)...


Calculating PPL:  99%|█████████▉| 560/564 [00:23<00:00, 24.13it/s]

KnormPress Wikitext PPL: 39.43





## PG19 Evaluation

In [20]:
pg19_path = os.path.join(DATASET_DIR, "pg19")
if os.path.exists(pg19_path):
    print("Loading PG19 from local disk...")
    txt_files = glob.glob(os.path.join(pg19_path, "*.txt"))
    if len(txt_files) > 0:
        data = []
        for f in txt_files:
            try:
                with open(f, 'r', encoding='utf-8') as file:
                    data.append({"text": file.read()})
            except Exception as e:
                print(f"Error reading file {f}: {e}")

        if len(data) > 0:
            # Evaluate on a subset of PG19 to save time, e.g., first 5 books
            num_samples = 5
            print(f"Evaluating on first {num_samples} samples of PG19...")
            
            # Baseline
            total_ppl_baseline = 0
            count = 0
            print("--- Baseline ---")
            for i, sample in enumerate(data):
                if i >= num_samples:
                    break
                print(f"Processing sample {i+1}/{num_samples}...")
                try:
                    ppl = calculate_ppl(model, tokenizer, sample["text"], device=DEVICE)
                    print(f"Sample {i+1} PPL: {ppl:.2f}")
                    total_ppl_baseline += ppl
                    count += 1
                except Exception as e:
                    print(f"Error processing sample {i+1}: {e}")
            
            if count > 0:
                print(f"Average PG19 Baseline PPL: {total_ppl_baseline/count:.2f}")
            
            # Compressed
            total_ppl_compressed = 0
            count = 0
            print(f"--- KnormPress (Ratio: {compression_ratio}) ---")
            with press(model):
                for i, sample in enumerate(data):
                    if i >= num_samples:
                        break
                    print(f"Processing sample {i+1}/{num_samples}...")
                    try:
                        ppl = calculate_ppl(model, tokenizer, sample["text"], device=DEVICE)
                        print(f"Sample {i+1} PPL: {ppl:.2f}")
                        total_ppl_compressed += ppl
                        count += 1
                    except Exception as e:
                        print(f"Error processing sample {i+1}: {e}")
            
            if count > 0:
                print(f"Average PG19 KnormPress PPL: {total_ppl_compressed/count:.2f}")

        else:
             print("No valid text files read from PG19.")
    else:
        print("PG19 dataset is empty (no .txt files found).")
else:
    print(f"PG19 dataset not found at {pg19_path}")

Loading PG19 from local disk...
Evaluating on first 5 samples of PG19...
--- Baseline ---
Processing sample 1/5...


Calculating PPL:  97%|█████████▋| 112/116 [00:04<00:00, 23.78it/s]


Sample 1 PPL: 29.36
Processing sample 2/5...


Calculating PPL:  98%|█████████▊| 166/170 [00:06<00:00, 23.99it/s]


Sample 2 PPL: 50.69
Processing sample 3/5...


Calculating PPL:  97%|█████████▋| 151/155 [00:06<00:00, 23.87it/s]


Sample 3 PPL: 40.74
Processing sample 4/5...


Calculating PPL:  97%|█████████▋| 152/156 [00:06<00:00, 24.08it/s]


Sample 4 PPL: 57.71
Processing sample 5/5...


Calculating PPL:  71%|███████▏  | 10/14 [00:00<00:00, 24.18it/s]


Sample 5 PPL: 29.40
Average PG19 Baseline PPL: 41.58
--- KnormPress (Ratio: 0.2) ---
Processing sample 1/5...


Calculating PPL:  97%|█████████▋| 112/116 [00:04<00:00, 23.80it/s]


Sample 1 PPL: 29.36
Processing sample 2/5...


Calculating PPL:  98%|█████████▊| 166/170 [00:07<00:00, 23.71it/s]


Sample 2 PPL: 50.69
Processing sample 3/5...


Calculating PPL:  97%|█████████▋| 151/155 [00:06<00:00, 23.76it/s]


Sample 3 PPL: 40.74
Processing sample 4/5...


Calculating PPL:  97%|█████████▋| 152/156 [00:06<00:00, 23.39it/s]


Sample 4 PPL: 57.71
Processing sample 5/5...


Calculating PPL:  71%|███████▏  | 10/14 [00:00<00:00, 23.48it/s]

Sample 5 PPL: 29.40
Average PG19 KnormPress PPL: 41.58



