In [4]:
# Cell 0: Setup Python Path
import sys
import os

# Add project root to path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, project_root)

print(f"Project root: {project_root}")
print(f"Python path: {sys.path[:2]}")

Project root: /home/krrish/Desktop/Programming/slm-distill
Python path: ['/home/krrish/Desktop/Programming/slm-distill', '/usr/lib/python311.zip']


In [6]:
# Cell: Comprehensive Perplexity Evaluation

import torch
from transformers import XLMRobertaForMaskedLM, AutoTokenizer
from src.data.data import prepare_datasets
import numpy as np
from tqdm import tqdm
import os

# ============================================================
# CONFIGURATION
# ============================================================
checkpoint = "/home/krrish/Desktop/Programming/slm-distill/outputs/final_model"
data_path = "/home/krrish/Desktop/Programming/slm-distill/data/hin/data-0.parquet"
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
max_length = 128
train_split = 0.95

print("="*70)
print("PERPLEXITY EVALUATION")
print("="*70)
print(f"Checkpoint:  {checkpoint}")
print(f"Data:        {data_path}")
print(f"Device:      {device}")
print("="*70)

# ============================================================
# 1. VERIFY FILES EXIST
# ============================================================
print("\n[1] Checking files...")

if not os.path.exists(checkpoint):
    print(f"❌ Error: Checkpoint not found at {checkpoint}")
    print("Available checkpoints:")
    if os.path.exists("outputs"):
        for item in os.listdir("outputs"):
            print(f"  - {item}")
else:
    print(f"✓ Checkpoint found")

if not os.path.exists(data_path):
    print(f"❌ Error: Data not found at {data_path}")
else:
    print(f"✓ Data file found")

# ============================================================
# 2. LOAD MODEL & TOKENIZER
# ============================================================
print("\n[2] Loading model and tokenizer...")

try:
    tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
    model = XLMRobertaForMaskedLM.from_pretrained(checkpoint)
    model.to(device)
    model.eval()
    
    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"✓ Model loaded: {num_params:.1f}M parameters")
    print(f"✓ Tokenizer loaded: vocab size = {len(tokenizer)}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise

# ============================================================
# 3. LOAD EVALUATION DATA
# ============================================================
print("\n[3] Loading evaluation data...")

try:
    _, eval_loader = prepare_datasets(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        batch_size=batch_size,
        train_split=train_split,
        num_workers=2
    )
    print(f"✓ Loaded {len(eval_loader)} batches")
except Exception as e:
    print(f"❌ Error loading data: {e}")
    raise

# ============================================================
# 4. COMPUTE PERPLEXITY
# ============================================================
print("\n[4] Computing perplexity...")

total_loss = 0.0
total_tokens = 0
batch_losses = []
batch_token_counts = []

try:
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(eval_loader, desc="Evaluating")):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Count masked tokens
            num_tokens = (labels != -100).sum().item()
            batch_loss = outputs.loss.item()
            
            # Accumulate
            total_loss += batch_loss * num_tokens
            total_tokens += num_tokens
            batch_losses.append(batch_loss)
            batch_token_counts.append(num_tokens)
            
            # Memory cleanup
            del outputs, input_ids, attention_mask, labels
            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()
    
    print(f"✓ Processed {len(eval_loader)} batches")
    
except Exception as e:
    print(f"❌ Error during evaluation: {e}")
    raise

# ============================================================
# 5. CALCULATE METRICS
# ============================================================
print("\n[5] Calculating metrics...")

avg_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(avg_loss)).item()

print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"Total tokens:           {total_tokens:,}")
print(f"Total batches:          {len(eval_loader)}")
print(f"Average loss:           {avg_loss:.4f}")
print(f"Perplexity:             {perplexity:.2f}")
print("="*70)

# ============================================================
# 6. DETAILED STATISTICS
# ============================================================
print("\n[6] Batch Loss Statistics:")

print(f"  Mean:                 {np.mean(batch_losses):.4f}")
print(f"  Median:               {np.median(batch_losses):.4f}")
print(f"  Std Dev:              {np.std(batch_losses):.4f}")
print(f"  Min:                  {np.min(batch_losses):.4f}")
print(f"  Max:                  {np.max(batch_losses):.4f}")
print(f"  Q1 (25%):             {np.percentile(batch_losses, 25):.4f}")
print(f"  Q3 (75%):             {np.percentile(batch_losses, 75):.4f}")

print("\n[7] Token Count Statistics:")

print(f"  Mean tokens/batch:    {np.mean(batch_token_counts):,.0f}")
print(f"  Total tokens:         {total_tokens:,}")
print(f"  Min batch size:       {np.min(batch_token_counts):,}")
print(f"  Max batch size:       {np.max(batch_token_counts):,}")

# ============================================================
# 7. SAMPLE SENTENCE EVALUATION
# ============================================================
print("\n[8] Evaluating sample sentences...")

hindi_samples = [
    "मैं स्कूल जाता हूँ।",
    "आज मौसम बहुत अच्छा है।",
    "शिक्षा जीवन का आधार है।",
    "भारतीय संविधान दुनिया का सबसे बड़ा लिखित संविधान है।",
    "जलवायु परिवर्तन एक गंभीर समस्या है।"
]

sentence_ppls = []

with torch.no_grad():
    for sent in hindi_samples:
        inputs = tokenizer(sent, return_tensors="pt").to(device)
        outputs = model(**inputs, labels=inputs["input_ids"])
        ppl = torch.exp(outputs.loss).item()
        sentence_ppls.append(ppl)
        print(f"  PPL: {ppl:7.2f} | {sent[:50]}")

avg_sent_ppl = np.mean(sentence_ppls)
print(f"\n✓ Average sample perplexity: {avg_sent_ppl:.2f}")

# ============================================================
# 8. FINAL SUMMARY
# ============================================================
print("\n" + "="*70)
print("EVALUATION SUMMARY")
print("="*70)
print(f"✓ Checkpoint:           {checkpoint}")
print(f"✓ Model size:           {num_params:.1f}M parameters")
print(f"✓ Data file:            {data_path}")
print(f"✓ Tokens evaluated:     {total_tokens:,}")
print(f"✓ Batches:              {len(eval_loader)}")
print(f"✓ Dataset perplexity:   {perplexity:.2f}")
print(f"✓ Sample avg perplexity: {avg_sent_ppl:.2f}")
print("="*70)

# Return results for further use
results = {
    "perplexity": perplexity,
    "avg_loss": avg_loss,
    "total_tokens": total_tokens,
    "batch_losses": batch_losses,
    "sample_ppls": sentence_ppls
}

print("\n✅ Evaluation complete!")

PERPLEXITY EVALUATION
Checkpoint:  /home/krrish/Desktop/Programming/slm-distill/outputs/final_model
Data:        /home/krrish/Desktop/Programming/slm-distill/data/hin/data-0.parquet
Device:      cuda

[1] Checking files...
✓ Checkpoint found
✓ Data file found

[2] Loading model and tokenizer...
✓ Model loaded: 33.1M parameters
✓ Tokenizer loaded: vocab size = 250002

[3] Loading evaluation data...

Loading data from parquet...
Data path: /home/krrish/Desktop/Programming/slm-distill/data/hin/data-0.parquet
✓ Loaded 174,763 examples
Columns: ['doc_id', 'text', 'type']
✓ Train: 166,024 examples (95.0%)
✓ Eval: 8,739 examples (5.0%)
Tokenizing...


Tokenizing train:   0%|          | 0/166024 [00:00<?, ? examples/s]

Tokenizing eval:   0%|          | 0/8739 [00:00<?, ? examples/s]

✓ Train batches: 5,189
✓ Eval batches: 274
✓ Approx train tokens: 21.3M
✓ Loaded 274 batches

[4] Computing perplexity...


Evaluating:   0%|          | 0/274 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Evaluating: 100%|██████████| 274/274 [00:10<00:00, 25.37it/s]


✓ Processed 274 batches

[5] Calculating metrics...

RESULTS
Total tokens:           163,553
Total batches:          274
Average loss:           3.0060
Perplexity:             20.21

[6] Batch Loss Statistics:
  Mean:                 3.0041
  Median:               2.9983
  Std Dev:              0.1870
  Min:                  2.5068
  Max:                  3.6110
  Q1 (25%):             2.8684
  Q3 (75%):             3.1294

[7] Token Count Statistics:
  Mean tokens/batch:    597
  Total tokens:         163,553
  Min batch size:       55
  Max batch size:       661

[8] Evaluating sample sentences...
  PPL:   45.00 | मैं स्कूल जाता हूँ।
  PPL:   25.60 | आज मौसम बहुत अच्छा है।
  PPL:   24.33 | शिक्षा जीवन का आधार है।
  PPL:    8.76 | भारतीय संविधान दुनिया का सबसे बड़ा लिखित संविधान ह
  PPL:   13.63 | जलवायु परिवर्तन एक गंभीर समस्या है।

✓ Average sample perplexity: 23.46

EVALUATION SUMMARY
✓ Checkpoint:           /home/krrish/Desktop/Programming/slm-distill/outputs/final_model
✓ Model s