## 1. Import Libraries and Load Data

In [1]:
import pandas as pd
import spacy
from spacy.tokens import DocBin
from tqdm.auto import tqdm
import numpy as np
from collections import defaultdict, Counter
import math

# For MTLD
from lexicalrichness import LexicalRichness

In [2]:
# Load the spaCy model
nlp = spacy.load("en_core_web_lg")

# Load the CLC-FCE docbins
docbin_path_original = "../data/clc-fce-docbins/original.docbin"
docbin_path_corrected = "../data/clc-fce-docbins/corrected.docbin"

print("Loading original docbin...")
docbin_original = DocBin().from_disk(docbin_path_original)
docs_original = list(docbin_original.get_docs(nlp.vocab))
print(f"Loaded {len(docs_original)} original documents")

print("Loading corrected docbin...")
docbin_corrected = DocBin().from_disk(docbin_path_corrected)
docs_corrected = list(docbin_corrected.get_docs(nlp.vocab))
print(f"Loaded {len(docs_corrected)} corrected documents")

assert len(docs_original) == len(docs_corrected), "Docbins must have same number of docs"

Loading original docbin...
Loaded 2482 original documents
Loading corrected docbin...
Loaded 2482 corrected documents


In [3]:
# Load reference data
print("Loading Slim Pajama token frequencies...")
token_freq_df = pd.read_parquet("../data/slim_pajama_lists/3grams.parquet")

# Need to sum counts over token_2 to get unigram frequencies
token_freq_df = token_freq_df.groupby('token_2', as_index=False)['count'].sum()
token_freq = dict(zip(token_freq_df['token_2'], token_freq_df['count']))
total_tokens = sum(token_freq.values())
print(f"Loaded {len(token_freq)} unique tokens, total {total_tokens} tokens")

print("Loading dependency bigrams...")
dep_df = pd.read_parquet("../data/slim_pajama_lists/depgrams.parquet")
print(f"Loaded {len(dep_df)} dependency bigrams")

Loading Slim Pajama token frequencies...
Loaded 2024311 unique tokens, total 311758665 tokens
Loading dependency bigrams...
Loaded 46189680 dependency bigrams, total 306325736 dependencies


In [6]:
class MiCalculator:
    def __init__(self, reference_grams: pd.DataFrame):
        # Build dep_counts directly using zip - much faster than set_index
        self.dep_counts = dict(zip(
            zip(reference_grams['head_lemma'], 
                reference_grams['dependent_lemma'], 
                reference_grams['relation']),
            reference_grams['count']
        ))
        self.head_marginals = reference_grams.groupby('head_lemma')['count'].sum().to_dict()
        self.dep_marginals = reference_grams.groupby('dependent_lemma')['count'].sum().to_dict()
        self.total_deps = reference_grams['count'].sum()

    def __call__(self, doc) -> dict:
        # Calculate Mutual Information (MI) for dependency relations
        # MI measures how predictable a dependent word is given its head word
        # MI = log2(P(head,dep) / (P(head) * P(dep)))
        # Values > 0 indicate the words co-occur more than expected by chance
        rel_mis = defaultdict(list)
        for token in doc:
            if token.dep_ in {'amod', 'advmod', 'dobj'}:
                head_lemma = token.head.lemma_.lower()
                dep_lemma = token.lemma_.lower()
                relation = token.dep_
                pair = (head_lemma, dep_lemma, relation)
                
                # Get joint count P(head, dep) from reference corpus
                joint_count = self.dep_counts.get(pair, 0)
                if joint_count == 0:
                    continue
                
                # Calculate probabilities from reference corpus
                p_xy = joint_count / self.total_deps  # Joint probability
                p_x = self.head_marginals.get(head_lemma, 0) / self.total_deps  # P(head)
                p_y = self.dep_marginals.get(dep_lemma, 0) / self.total_deps  # P(dependent)
                
                # MI = log2(P(x,y) / P(x)*P(y))
                # log2 gives results in bits; positive values = words co-occur more than expected
                mi = math.log2(p_xy / (p_x * p_y))
                rel_mis[relation].append(mi)
        
        # Average MI per relation
        avg_mis = {f'{rel}': np.mean(mis) if mis else np.nan for rel, mis in rel_mis.items()}
        return avg_mis

mi_calculator = MiCalculator(dep_df)

## 2. Define Helper Functions

In [7]:
# Error type mapping based on FLAN tagset
# Grammar: Agreement (G*), Argument Structure (S), Derivation (D*), Form (F*), 
#          Inflection (I* except ID), Missing (M*), Question (QL), Replacement (R*),
#          Tense (TV), Unnecessary (U*), Word Order/Negation (W, X), Countability (C[DN])
# Vocab: Collocation (CL), Compound (CE), Register (L), Idiom (ID)
# Spelling: Spelling errors (S, SX)
error_mapping = {
    'grammar': [
        # Agreement
        'G', 'GD', 'GN', 'GQ', 'GV',
        # Argument Structure
        'S',
        # Derivation
        'D', 'DC', 'DD', 'DI', 'DJ', 'DN', 'DQ', 'DT', 'DV', 'DY',
        # Form
        'F', 'FD', 'FJ', 'FN', 'FQ', 'FV', 'FY',
        # Inflection (excluding ID - Idiom)
        'I', 'IJ', 'IN', 'IQ', 'IV', 'IY',
        # Missing Elements
        'M', 'MC', 'MD', 'MJ', 'MN', 'MP', 'MQ', 'MT', 'MV', 'MY',
        # Question Errors
        'QL',
        # Replacement Errors
        'R', 'RC', 'RD', 'RJ', 'RN', 'RP', 'RQ', 'RT', 'RV', 'RY',
        # Tense
        'TV',
        # Unnecessary Elements
        'U', 'UC', 'UD', 'UJ', 'UN', 'UP', 'UQ', 'UT', 'UV', 'UY',
        # Word Order and Negation
        'W', 'X',
        # Countability (grammatical aspect)
        'CD', 'CN', 'CQ',
    ],
    'vocab': [
        # Collocation
        'CL',
        # Compound error
        'CE',
        # Register
        'L',
        # Idiom
        'ID',
    ],
    'spelling': [
        # Spelling errors
        'S', 'SX'
    ]
}
# Flatten for lookup
error_type_to_cat = {}
for cat, types in error_mapping.items():
    for t in types:
        error_type_to_cat[t] = cat

def count_errors(doc) -> dict:
    error_counts = {'error_grammar': 0, 'error_vocab': 0, 'error_spelling': 0}
    if 'errors' in doc.spans:
        for span in doc.spans['errors']:
            cat = error_type_to_cat.get(span.label_, 'other')  # Default to other if not mapped
            if not cat == 'other':
                error_counts[f'error_{cat}'] += 1
    return error_counts

def calculate_mtld(tokens) -> float:
    # Simple MTLD implementation (standard parameters: factor=0.72)
    if len(tokens) < 10:
        return np.nan
    token_str = ' '.join(tokens)
    lex = LexicalRichness(token_str)
    return lex.mtld(threshold=0.72)

def count_tunits(doc) -> int:
    # T-units: each main verb (ROOT) and attached clauses
    # For simplicity, count number of ROOT tokens (one per independent clause)
    return sum(1 for token in doc if token.dep_ == 'ROOT')

def lexical_density(doc) -> float:
    content_pos = {'NOUN', 'VERB', 'ADJ', 'ADV'}
    words = [t for t in doc if t.is_alpha]
    if not words:
        return 0
    content_words = [t for t in words if t.pos_ in content_pos]
    return len(content_words) / len(words)

def avg_token_freq(doc, token_freq):
    freqs = []
    for token in doc:
        if token.is_alpha:
            freq = token_freq.get(token.lemma_.lower(), 1)  # Default to 1 if not found
            freqs.append(freq)
    return np.mean(freqs) if freqs else np.nan

def mod_per_nom(doc) -> float:
    nominals = [t for t in doc if t.pos_ == 'NOUN']
    if not nominals:
        return 0
    total_mods = 0
    for nom in nominals:
        # Modifiers: adjectives, determiners, etc. (children with dep amod, det, etc.)
        mods = [c for c in nom.children if c.dep_ in {'amod', 'det', 'nummod', 'compound'}]
        total_mods += len(mods)
    return total_mods / len(nominals)

def dep_per_nom(doc) -> float:
    nominals = [t for t in doc if t.pos_ == 'NOUN']
    if not nominals:
        return 0
    total_deps = sum(len(list(nom.children)) for nom in nominals)
    return total_deps / len(nominals)

## 3. Process Documents and Calculate Metrics

In [10]:
def process_docs(docs, label):
    results = []
    
    print(f"Processing {len(docs)} {label} documents...\n")
    
    for idx, doc in tqdm(enumerate(docs), total=len(docs), desc=f"Calculating metrics ({label})"):
        metrics = {'doc_id': idx}
        
        # Basic counts
        words = [t for t in doc if not t.is_punct]
        metrics['word_count'] = len(words)
        metrics['clause_count'] = len(list(doc.sents))
        metrics['tunit_count'] = count_tunits(doc)
        
        # Lexical
        lemmas = [t.lemma_.lower() for t in words if t.is_alpha]
        metrics['MTLD'] = calculate_mtld(lemmas)
        metrics['lexical_density'] = lexical_density(doc)
        metrics['token_freq'] = avg_token_freq(doc, token_freq)
        
        # Syntactic
        metrics['clauses_per_tunit'] = metrics['clause_count'] / metrics['tunit_count'] if metrics['tunit_count'] > 0 else np.nan
        metrics['mod_per_nom'] = mod_per_nom(doc)
        metrics['dep_per_nom'] = dep_per_nom(doc)
        
        # MI for relations
        mi_dict = mi_calculator(doc)
        metrics.update({k: v for k, v in mi_dict.items()})
        
        # Errors
        error_dict = count_errors(doc)
        metrics.update(error_dict)
        
        results.append(metrics)
    
    return pd.DataFrame(results)

# Process original docs
df_original = process_docs(docs_original, "original")

# Process corrected docs
df_corrected = process_docs(docs_corrected, "corrected")

print("\n✓ Processing complete!")

Processing 2482 original documents...



Calculating metrics (original):   0%|          | 0/2482 [00:00<?, ?it/s]

Processing 2482 corrected documents...



Calculating metrics (corrected):   0%|          | 0/2482 [00:00<?, ?it/s]


✓ Processing complete!


## 4. Save Results

In [12]:
# Display summary statistics
print("Summary statistics for original metrics:\n")
display(df_original.describe())

print("\nSummary statistics for corrected metrics:\n")
display(df_corrected.describe())

Summary statistics for original metrics:



Unnamed: 0,doc_id,word_count,clause_count,tunit_count,MTLD,lexical_density,token_freq,clauses_per_tunit,mod_per_nom,dep_per_nom,amod,dobj,advmod,error_grammar,error_vocab,error_spelling
count,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2461.0,2477.0,2480.0,2482.0,2482.0,2482.0
mean,1240.5,198.40975,15.213135,12.65552,57.868256,0.434486,2309365.0,1.209515,0.753516,1.339965,-2.428705,-6.683092,-5.394641,16.870669,0.210314,2.191378
std,716.636007,41.823723,5.049463,4.313794,13.753058,0.045443,411557.0,0.121447,0.174122,0.237893,2.02775,1.30197,1.940378,9.050966,0.490217,2.899228
min,0.0,41.0,2.0,1.0,26.23755,0.317647,1162169.0,1.0,0.0,0.576923,-10.063218,-12.842543,-14.831642,0.0,0.0,0.0
25%,620.25,172.0,12.0,10.0,47.915458,0.402608,2022137.0,1.125,0.631579,1.179615,-3.76245,-7.5041,-6.686443,10.0,0.0,0.0
50%,1240.5,192.0,15.0,12.0,56.671566,0.429509,2281655.0,1.2,0.75,1.333333,-2.46044,-6.673534,-5.441833,15.0,0.0,1.0
75%,1860.75,219.0,18.0,15.0,66.063699,0.463158,2575411.0,1.285714,0.87234,1.5,-1.144897,-5.861087,-4.240184,22.0,0.0,3.0
max,2481.0,532.0,43.0,41.0,132.865385,0.621429,3711263.0,2.0,1.333333,2.304348,6.735806,0.102862,6.33222,87.0,3.0,54.0



Summary statistics for corrected metrics:



Unnamed: 0,doc_id,word_count,clause_count,tunit_count,MTLD,lexical_density,token_freq,clauses_per_tunit,mod_per_nom,dep_per_nom,dobj,amod,advmod,error_grammar,error_vocab,error_spelling
count,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2482.0,2481.0,2464.0,2482.0,2482.0,2482.0,2482.0
mean,1240.5,198.599114,15.776793,13.079371,57.23472,0.432555,2338102.0,1.212323,0.772203,1.361014,-6.588163,-2.235374,-5.191133,0.0,0.0,0.0
std,716.636007,41.972238,5.01004,4.24375,13.497876,0.044743,397766.1,0.116696,0.16661,0.231511,1.244288,1.999179,1.863136,0.0,0.0,0.0
min,0.0,41.0,2.0,2.0,26.89392,0.306358,1115005.0,1.0,0.0,0.625,-13.99281,-10.063218,-11.421928,0.0,0.0,0.0
25%,620.25,172.25,12.0,10.0,47.426844,0.4,2062163.0,1.133333,0.65625,1.208576,-7.369549,-3.555785,-6.448256,0.0,0.0,0.0
50%,1240.5,192.0,15.0,13.0,55.746972,0.427443,2305995.0,1.2,0.768336,1.357143,-6.599329,-2.249429,-5.199328,0.0,0.0,0.0
75%,1860.75,219.0,19.0,15.0,65.333333,0.461467,2596182.0,1.285714,0.888889,1.5,-5.794678,-0.986569,-4.06228,0.0,0.0,0.0
max,2481.0,540.0,47.0,41.0,118.349191,0.625,3874952.0,2.0,1.36,2.304348,-1.356387,6.393348,5.958829,0.0,0.0,0.0


In [13]:
# Save to files
output_path_original = "../data/clc_fce_metrics_original.csv"
output_path_corrected = "../data/clc_fce_metrics_corrected.csv"

df_original.to_csv(output_path_original, index=False)
df_corrected.to_csv(output_path_corrected, index=False)

print(f"✓ Original results saved to: {output_path_original}")
print(f"✓ Corrected results saved to: {output_path_corrected}")
print(f"  Total docs: {len(df_original)}")

✓ Original results saved to: ../data/clc_fce_metrics_original.csv
✓ Corrected results saved to: ../data/clc_fce_metrics_corrected.csv
  Total docs: 2482


In [14]:
# Save combined file
df_original = pd.read_csv(output_path_original)
df_corrected = pd.read_csv(output_path_corrected)

predictability_original = pd.read_csv("../data/clc_fce_predictability_original.csv")
predictability_corrected = pd.read_csv("../data/clc_fce_predictability_corrected.csv")

df_combined_original = pd.merge(df_original, predictability_original, on='doc_id')
df_combined_corrected = pd.merge(df_corrected, predictability_corrected, on='doc_id')

output_path_combined_original = "../data/clc_fce_metrics_predictability_original.csv"
output_path_combined_corrected = "../data/clc_fce_metrics_predictability_corrected.csv"

df_combined_original.to_csv(output_path_combined_original, index=False)
df_combined_corrected.to_csv(output_path_combined_corrected, index=False)

print(f"✓ Combined original results saved to: {output_path_combined_original}")
print(f"✓ Combined corrected results saved to: {output_path_combined_corrected}")

✓ Combined original results saved to: ../data/clc_fce_metrics_predictability_original.csv
✓ Combined corrected results saved to: ../data/clc_fce_metrics_predictability_corrected.csv
