## 1. Import Libraries and Load Data

In [2]:
import pandas as pd
import spacy
from transformers import AutoTokenizer, AutoModelForMaskedLM
from predictor import Predictor
from tqdm.auto import tqdm
import torch
from spacy.tokens import DocBin

# Report whether torch can see a CUDA device and, if so, which one.
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")

if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is in bytes; dividing by 1e9 reports decimal gigabytes.
    gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_total_gb:.2f} GB")
    # Only applied when a GPU is present. Per the PyTorch docs, 'high' lets
    # float32 matmuls use faster reduced-precision kernels where available.
    torch.set_float32_matmul_precision('high')

CUDA available: True
GPU: NVIDIA RTX A6000
GPU Memory: 50.93 GB


In [6]:
# spaCy pipeline; its vocab is what the docbins are deserialized against below.
nlp = spacy.load("en_core_web_lg")

# Locations of the two serialized CLC-FCE corpora (original and corrected).
docbin_path_original = "../data/clc-fce-docbins/original.docbin"
docbin_path_corrected = "../data/clc-fce-docbins/corrected.docbin"

# --- original documents ---
print("Loading original docbin...")
docbin_original = DocBin().from_disk(docbin_path_original)
docs_original = [*docbin_original.get_docs(nlp.vocab)]
print(f"Loaded {len(docs_original)} original documents")

# --- corrected documents ---
print("Loading corrected docbin...")
docbin_corrected = DocBin().from_disk(docbin_path_corrected)
docs_corrected = [*docbin_corrected.get_docs(nlp.vocab)]
print(f"Loaded {len(docs_corrected)} corrected documents")

# Sanity check: later cells process the two lists separately, so at minimum
# they must contain the same number of documents. Only the count is verified.
assert len(docs_corrected) == len(docs_original), "Docbins must have same number of docs"

Loading original docbin...
Loaded 2482 original documents
Loading corrected docbin...
Loaded 2482 corrected documents


## 2. Load Models

In [7]:
# Checkpoint identifier shared by the tokenizer and the masked-LM weights.
model_name = "answerdotai/ModernBERT-base"
print(f"Loading {model_name}...")

# Tokenizer first, then the model, both from the same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForMaskedLM.from_pretrained(model_name),
)

print("✓ ModernBERT loaded")

Loading answerdotai/ModernBERT-base...
✓ ModernBERT loaded


In [8]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# batch_size goes to the Predictor constructor; window_size does not. It is
# kept as a notebook-level setting and supplied when the predictor is called.
window_size, batch_size = 64, 32

predictor_settings = dict(
    tokenizer=tokenizer,
    model=model,
    model_type="masked",
    batch_size=batch_size,
    device=device,
)
predictor = Predictor(**predictor_settings)

# Echo the effective configuration so it is recorded with the cell output.
print("Predictor initialized:")
print("  Model type: masked")
print(f"  Window size: {window_size}")
print(f"  Batch size: {batch_size}")
print(f"  Device: {device}")

Predictor initialized:
  Model type: masked
  Window size: 64
  Batch size: 32
  Device: cuda


## 3. Process Documents and Calculate Predictability

In [9]:
def process_docs(docs, label):
    """Score every document with the predictor and collect per-document aggregates.

    Relies on the notebook-level names ``predictor``, ``window_size``, ``tqdm``
    and ``pd`` being defined before the call.

    Args:
        docs: Sequence of documents. Each item must expose ``.text`` and is
            passed unchanged to ``predictor(doc, window_size=window_size)``.
        label: Short name used only in the progress and log messages.

    Returns:
        pandas.DataFrame with exactly one row per input document, in input
        order, with columns ``doc_id`` (position in ``docs``), ``mean_loss``,
        ``mean_prob`` and ``mean_entropy``. Documents whose text is empty or
        whitespace-only, and documents for which scoring raised, have NaN in
        the three metric columns.
    """
    columns = ["doc_id", "mean_loss", "mean_prob", "mean_entropy"]
    missing = float("nan")  # keeps metric columns numeric even if every row fails
    rows = []

    print(f"Processing {len(docs)} {label} documents...\n")

    for idx, doc in tqdm(enumerate(docs), total=len(docs), desc=f"Calculating predictability ({label})"):
        # Start from an all-missing row; it is only overwritten when every
        # metric has been computed successfully.
        mean_loss = mean_prob = mean_entropy = missing

        # Empty / whitespace-only documents are recorded as missing, not scored.
        if doc.text.strip():
            try:
                doc_pred = predictor(doc, window_size=window_size)
                # Compute all metrics BEFORE storing any of them. Previously
                # values were appended one by one, so an error part-way through
                # (e.g. division by zero when doc_pred is empty) left a
                # half-written row plus a second failure row, and the column
                # lengths no longer matched when building the DataFrame.
                loss = doc_pred.mean_loss
                prob = sum(t.mean_prob for t in doc_pred) / len(doc_pred)
                entropy = doc_pred.mean_entropy
            except Exception as e:
                # Deliberate best-effort: log and keep going so one bad
                # document does not abort a long run.
                print(f"\nError processing doc {idx}: {e}")
            else:
                mean_loss, mean_prob, mean_entropy = loss, prob, entropy

        rows.append(
            {
                "doc_id": idx,
                "mean_loss": mean_loss,
                "mean_prob": mean_prob,
                "mean_entropy": mean_entropy,
            }
        )

    # Passing `columns` keeps the schema stable even when `docs` is empty.
    return pd.DataFrame(rows, columns=columns)

# Score both corpora with the same function. The generator is consumed in
# order, so the original documents are processed first, then the corrected ones.
df_original, df_corrected = (
    process_docs(doc_list, tag)
    for doc_list, tag in ((docs_original, "original"), (docs_corrected, "corrected"))
)

print("\n✓ Processing complete!")

Processing 2482 original documents...



Calculating predictability (original): 100%|██████████| 2482/2482 [12:21<00:00,  3.35it/s]


Processing 2482 corrected documents...



Calculating predictability (corrected): 100%|██████████| 2482/2482 [11:33<00:00,  3.58it/s]


✓ Processing complete!





## 4. Summarize and Save Results

In [10]:
# Descriptive statistics of the three per-document metrics, first for the
# original documents and then for the corrected ones.
summary_columns = ["mean_loss", "mean_prob", "mean_entropy"]

stats_original = df_original[summary_columns].describe()
print("Summary statistics for original predictability metrics:\n")
print(stats_original)

stats_corrected = df_corrected[summary_columns].describe()
print("\nSummary statistics for corrected predictability metrics:\n")
print(stats_corrected)

Summary statistics for original predictability metrics:

         mean_loss    mean_prob  mean_entropy
count  2482.000000  2482.000000   2482.000000
mean      1.706858     0.528374      1.646216
std       0.399799     0.068585      0.337173
min       0.700076     0.140778      0.851641
25%       1.423229     0.485645      1.413056
50%       1.665526     0.534930      1.593374
75%       1.936759     0.575824      1.833408
max       4.844044     0.721060      4.581869

Summary statistics for corrected predictability metrics:

         mean_loss    mean_prob  mean_entropy
count  2482.000000  2482.000000   2482.000000
mean      1.247179     0.601502      1.332564
std       0.253283     0.050147      0.222421
min       0.557426     0.338495      0.742476
25%       1.077096     0.572559      1.182228
50%       1.222371     0.605499      1.307154
75%       1.377344     0.634809      1.441542
max       2.777417     0.753488      2.744214


In [11]:
# Destination CSVs, one per corpus version.
output_path_original = "../data/clc_fce_predictability_original.csv"
output_path_corrected = "../data/clc_fce_predictability_corrected.csv"

# Write both frames without the pandas index; doc_id is already a column.
for frame, destination in (
    (df_original, output_path_original),
    (df_corrected, output_path_corrected),
):
    frame.to_csv(destination, index=False)

# Confirmation messages are printed only after both files have been written.
print(f"✓ Original results saved to: {output_path_original}")
print(f"✓ Corrected results saved to: {output_path_corrected}")
print(f"  Total docs: {len(df_original)}")

✓ Original results saved to: ../data/clc_fce_predictability_original.csv
✓ Corrected results saved to: ../data/clc_fce_predictability_corrected.csv
  Total docs: 2482


## 5. Quick Validation

In [12]:
# Count missing entries per metric column; a non-zero count means some
# documents were skipped as empty or failed during scoring.
checked_columns = ["mean_loss", "mean_prob", "mean_entropy"]

missing_original = df_original[checked_columns].isna().sum()
print("Missing values in original:")
print(missing_original)

missing_corrected = df_corrected[checked_columns].isna().sum()
print("\nMissing values in corrected:")
print(missing_corrected)

Missing values in original:
mean_loss       0
mean_prob       0
mean_entropy    0
dtype: int64

Missing values in corrected:
mean_loss       0
mean_prob       0
mean_entropy    0
dtype: int64
