## 1. Import Libraries and Load Data

In [None]:
import pandas as pd
import spacy
from spacy.tokens import DocBin
from tqdm.auto import tqdm
import numpy as np
from collections import defaultdict, Counter
import math

# For MTLD
from lexicalrichness import LexicalRichness

In [5]:
# Load the spaCy model
nlp = spacy.load("en_core_web_lg")

# Load the CLC-FCE docbins
docbin_path_original = "../data/clc-fce-docbins/original.docbin"
docbin_path_corrected = "../data/clc-fce-docbins/corrected.docbin"

print("Loading original docbin...")
docbin_original = DocBin().from_disk(docbin_path_original)
docs_original = list(docbin_original.get_docs(nlp.vocab))
print(f"Loaded {len(docs_original)} original documents")

print("Loading corrected docbin...")
docbin_corrected = DocBin().from_disk(docbin_path_corrected)
docs_corrected = list(docbin_corrected.get_docs(nlp.vocab))
print(f"Loaded {len(docs_corrected)} corrected documents")

assert len(docs_original) == len(docs_corrected), "Docbins must have same number of docs"

Loading original docbin...
Loaded 2482 original documents
Loading corrected docbin...
Loaded 2482 corrected documents


In [7]:
# Load reference data
print("Loading Slim Pajama token frequencies...")
token_freq_df = pd.read_parquet("../data/slim_pajama_lists/3grams.parquet")

# Need to sum counts over token_2 to get unigram frequencies
token_freq_df = token_freq_df.groupby('token_2', as_index=False)['count'].sum()
token_freq = dict(zip(token_freq_df['token_2'], token_freq_df['count']))
total_tokens = sum(token_freq.values())
print(f"Loaded {len(token_freq)} unique tokens, total {total_tokens} tokens")

print("Loading dependency bigrams...")
dep_df = pd.read_parquet("../data/slim_pajama_lists/depgrams.parquet")
dep_counts = dep_df.set_index(['head_lemma', 'dependent_lemma', 'relation'])['count'].to_dict()
head_marginals = dep_df.groupby('head_lemma')['count'].sum().to_dict()
dep_marginals = dep_df.groupby('dependent_lemma')['count'].sum().to_dict()
total_deps = dep_df['count'].sum()
print(f"Loaded {len(dep_counts)} dependency bigrams, total {total_deps} dependencies")

Loading Slim Pajama token frequencies...
Loaded 2024311 unique tokens, total 311758665 tokens
Loading dependency bigrams...
Loaded 46189680 dependency bigrams, total 306325736 dependencies


## 2. Define Helper Functions

In [18]:
# Error type mapping (reasonable decisions based on description)
# Spelling: S, SA, SX
# Grammar: Morphological (AG, C, D, F, I), POS edits (M/R/U + POS), AS, CE, W, X, TV
# Vocab: CL, ID, L, QL
error_mapping = {
    'grammar': ['AG', 'C', 'D', 'F', 'I', 'MA', 'MR', 'MU', 'MC', 'MD', 'MF', 'MI', 'MJ', 'MN', 'MP', 'MQ', 'MT', 'MV', 'MY', 'RA', 'RR', 'RU', 'RC', 'RD', 'RF', 'RI', 'RJ', 'RN', 'RP', 'RQ', 'RT', 'RV', 'RY', 'UA', 'UR', 'UU', 'UC', 'UD', 'UF', 'UI', 'UJ', 'UN', 'UP', 'UQ', 'UT', 'UV', 'UY', 'AS', 'CE', 'W', 'X', 'TV'],
    'vocab': ['CL', 'ID', 'L', 'QL'],
    'spelling': ['S', 'SA', 'SX']
}
# Flatten for lookup
error_type_to_cat = {}
for cat, types in error_mapping.items():
    for t in types:
        error_type_to_cat[t] = cat

def count_errors(doc):
    error_counts = {'error_grammar': 0, 'error_vocab': 0, 'error_spelling': 0}
    if 'errors' in doc.spans:
        for span in doc.spans['errors']:
            cat = error_type_to_cat.get(span.label_, 'other')  # Default to other if not mapped
            if not cat == 'other':
                error_counts[f'error_{cat}'] += 1
    return error_counts

def calculate_mtld(tokens):
    # Simple MTLD implementation (standard parameters: factor=0.72)
    if len(tokens) < 10:
        return np.nan
    token_str = ' '.join(tokens)
    lex = LexicalRichness(token_str)
    return lex.mtld(threshold=0.72)

def count_tunits(doc):
    # T-units: each main verb (ROOT) and attached clauses
    # For simplicity, count number of ROOT tokens (one per independent clause)
    return sum(1 for token in doc if token.dep_ == 'ROOT')

def lexical_density(doc):
    content_pos = {'NOUN', 'VERB', 'ADJ', 'ADV'}
    words = [t for t in doc if t.is_alpha]
    if not words:
        return 0
    content_words = [t for t in words if t.pos_ in content_pos]
    return len(content_words) / len(words)

def avg_token_freq(doc, token_freq):
    freqs = []
    for token in doc:
        if token.is_alpha:
            freq = token_freq.get(token.lemma_.lower(), 1)  # Default to 1 if not found
            freqs.append(freq)
    return np.mean(freqs) if freqs else np.nan

def mod_per_nom(doc):
    nominals = [t for t in doc if t.pos_ == 'NOUN']
    if not nominals:
        return 0
    total_mods = 0
    for nom in nominals:
        # Modifiers: adjectives, determiners, etc. (children with dep amod, det, etc.)
        mods = [c for c in nom.children if c.dep_ in {'amod', 'det', 'nummod', 'compound'}]
        total_mods += len(mods)
    return total_mods / len(nominals)

def dep_per_nom(doc):
    nominals = [t for t in doc if t.pos_ == 'NOUN']
    if not nominals:
        return 0
    total_deps = sum(len(list(nom.children)) for nom in nominals)
    return total_deps / len(nominals)

def calculate_mi(doc, dep_counts, head_marginals, dep_marginals, total_deps):
    rel_mis = defaultdict(list)
    for token in doc:
        if token.dep_ in {'amod', 'advmod', 'dobj'}:
            head = token.head.lemma_.lower()
            dep = token.lemma_.lower()
            rel = token.dep_
            key = (head, dep, rel)
            joint = dep_counts.get(key, 0)
            if joint > 0:
                p_joint = joint / total_deps
                p_head = head_marginals.get(head, 0) / total_deps
                p_dep = dep_marginals.get(dep, 0) / total_deps
                if p_head > 0 and p_dep > 0:
                    mi = math.log(p_joint / (p_head * p_dep))
                    rel_mis[rel].append(mi)
    # Average MI per relation
    return {f'{rel}': np.mean(mis) if mis else np.nan for rel, mis in rel_mis.items()}

## 3. Process Documents and Calculate Metrics

In [21]:
def process_docs(docs, label):
    results = []
    
    print(f"Processing {len(docs)} {label} documents...\n")
    
    for idx, doc in tqdm(enumerate(docs), total=len(docs), desc=f"Calculating metrics ({label})"):
        metrics = {'doc_id': idx}
        
        # Basic counts
        words = [t for t in doc if not t.is_punct]
        metrics['word_count'] = len(words)
        metrics['clause_count'] = len(list(doc.sents))
        metrics['tunit_count'] = count_tunits(doc)
        
        # Lexical
        lemmas = [t.lemma_.lower() for t in words if t.is_alpha]
        metrics['MTLD'] = calculate_mtld(lemmas)
        metrics['lexical_density'] = lexical_density(doc)
        metrics['token_freq'] = avg_token_freq(doc, token_freq)
        
        # Syntactic
        metrics['clauses_per_tunit'] = metrics['clause_count'] / metrics['tunit_count'] if metrics['tunit_count'] > 0 else np.nan
        metrics['mod_per_nom'] = mod_per_nom(doc)
        metrics['dep_per_nom'] = dep_per_nom(doc)
        
        # MI for relations
        mi_dict = calculate_mi(doc, dep_counts, head_marginals, dep_marginals, total_deps)
        metrics.update({k: v for k, v in mi_dict.items()})
        
        # Errors
        error_dict = count_errors(doc)
        metrics.update(error_dict)
        
        results.append(metrics)
    
    return pd.DataFrame(results)

# Process original docs
df_original = process_docs(docs_original, "original")

# Process corrected docs
df_corrected = process_docs(docs_corrected, "corrected")

print("\n✓ Processing complete!")

Processing 2482 original documents...



Calculating metrics (original): 100%|██████████| 2482/2482 [00:02<00:00, 964.84it/s] 


Processing 2482 corrected documents...



Calculating metrics (corrected): 100%|██████████| 2482/2482 [00:02<00:00, 1069.02it/s]


✓ Processing complete!





## 4. Save Results

In [22]:
# Display summary statistics
print("Summary statistics for original metrics:\n")
print(df_original.describe())

print("\nSummary statistics for corrected metrics:\n")
print(df_corrected.describe())

Summary statistics for original metrics:

            doc_id   word_count  clause_count  tunit_count         MTLD  \
count  2482.000000  2482.000000   2482.000000  2482.000000  2482.000000   
mean   1240.500000   198.409750     15.213135    12.655520    57.868256   
std     716.636007    41.823723      5.049463     4.313794    13.753058   
min       0.000000    41.000000      2.000000     1.000000    26.237550   
25%     620.250000   172.000000     12.000000    10.000000    47.915458   
50%    1240.500000   192.000000     15.000000    12.000000    56.671566   
75%    1860.750000   219.000000     18.000000    15.000000    66.063699   
max    2481.000000   532.000000     43.000000    41.000000   132.865385   

       lexical_density    token_freq  clauses_per_tunit  mod_per_nom  \
count      2482.000000  2.482000e+03        2482.000000  2482.000000   
mean          0.434486  2.309365e+06           1.209515     0.753516   
std           0.045443  4.115570e+05           0.121447     0.1741

In [23]:
# Save to files
output_path_original = "../data/clc_fce_metrics_original.csv"
output_path_corrected = "../data/clc_fce_metrics_corrected.csv"

df_original.to_csv(output_path_original, index=False)
df_corrected.to_csv(output_path_corrected, index=False)

print(f"✓ Original results saved to: {output_path_original}")
print(f"✓ Corrected results saved to: {output_path_corrected}")
print(f"  Total docs: {len(df_original)}")

✓ Original results saved to: ../data/clc_fce_metrics_original.csv
✓ Corrected results saved to: ../data/clc_fce_metrics_corrected.csv
  Total docs: 2482
