# Word-Class Constraint Audit

**A better alternative to POS tagging**: tokens are classified directly by their surface form, so the audit works with BPE tokenization.

## Research Question
Does syntactic structure shift probability mass toward word-initial tokens (vs. fragments/punctuation)?

## Prediction
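After a determiner (this notebook uses "the" as the cue), measure the share of the top-k next-token probability mass that falls on word-start tokens:
- **Sentence/Jabberwocky**: HIGH % word-start (structure intact)
- **Scrambled**: LOWER % word-start (structure disrupted)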

## Step 1: Install Dependencies

In [None]:
!pip install -q transformers torch numpy

## Step 2: Upload Stimuli

Click the **folder icon** on the left and drag `stimuli_with_scrambled.json` into the Files area.

Or run this cell:

In [None]:
# Alternative to dragging the file into the Files area: upload it through
# Colab's `files` helper. Step 4 opens 'stimuli_with_scrambled.json' by that
# exact relative name, so the uploaded file must keep it.
# NOTE(review): the return value `uploaded` is not read by any cell shown here.
from google.colab import files
uploaded = files.upload()  # Upload stimuli_with_scrambled.json

## Step 3: Define Token Classification

In [None]:
def classify_token(token_str):
    """
    Classify a decoded GPT-2 BPE token into one of three classes.

    - 'word_start': a leading space followed by a letter (begins a new word)
    - 'punctuation': after stripping whitespace, the token is non-empty and
      consists only of the characters . , ! ? ; : " ' -
    - 'fragment': everything else (mid-word continuations, digits,
      whitespace-only tokens, the empty string)

    Args:
        token_str: Decoded token text, including any leading space.

    Returns:
        One of 'word_start', 'punctuation' or 'fragment'.
    """
    if not token_str:
        return 'fragment'

    # Word-initial: space + letter
    if token_str[0] == ' ' and len(token_str) > 1 and token_str[1].isalpha():
        return 'word_start'

    # Punctuation: test every character against the set instead of using a
    # substring test. A substring test accepts the empty string (so ' ' or a
    # newline would count as punctuation) and depends on character order
    # (',!' would match but '!,' or '...' would not).
    # The hyphen needs no escaping in a plain string literal.
    stripped = token_str.strip()
    if stripped and all(ch in '.,!?;:"\'-' for ch in stripped):
        return 'punctuation'

    # Everything else is a fragment
    return 'fragment'

## Step 4: Load Model

In [None]:
import json
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the pretrained GPT-2 tokenizer and causal language model, then put the
# model in evaluation mode since it is only used for inference here.
print("Loading GPT-2...")
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()
print("✓ Model loaded\n")

# Read the stimulus sets. The file is decoded explicitly as UTF-8 so the
# result does not depend on the platform's default text encoding.
print("Loading stimuli...")
with open('stimuli_with_scrambled.json', encoding='utf-8') as f:
    stimuli = json.load(f)
print(f"✓ Loaded {len(stimuli)} stimulus sets\n")

## Step 5: Sanity Test - Verify Classification Works

In [None]:
print("SANITY TEST: Token Classification")
print("-" * 60)

# Each case is (token text, expected class, description shown in the report).
test_cases = [
    (' cat', 'word_start', 'Space + letter → word_start'),
    (' dog', 'word_start', 'Space + letter → word_start'),
    (',', 'punctuation', 'Comma → punctuation'),
    ('.', 'punctuation', 'Period → punctuation'),
    ('ing', 'fragment', 'No space → fragment'),
    ('ed', 'fragment', 'No space → fragment'),
]

# Run every case, print one line per case, and remember whether any failed.
all_passed = True
for token, expected, desc in test_cases:
    result = classify_token(token)
    if result == expected:
        status = "✓"
    else:
        status = "✗ FAIL"
        all_passed = False
    print(f"{status}  {repr(token):12s} → {result:12s}  ({desc})")

print()
print(
    "✓ All sanity tests passed!"
    if all_passed
    else "✗ Some tests FAILED - check classification logic"
)
print()

## Step 6: Define Analysis Function

In [None]:
def analyze_position(model, tokenizer, text, target_position, k=100):
    """
    Analyze next-token predictions after a given word position.

    The context is the first `target_position + 1` whitespace-separated words
    of `text`, re-joined with single spaces. The model's next-token
    distribution for that context is truncated to its `k` most probable
    tokens, and each of those is classified with `classify_token`.

    Args:
        model: Causal language model; called as `model(**inputs)` and expected
            to return an object whose `.logits` is indexed as
            [batch, position, vocabulary].
        tokenizer: Called as `tokenizer(context, return_tensors='pt')` and
            must provide `decode(list_of_ids)`.
        text: Full stimulus text.
        target_position: Zero-based index of the last word kept in the context.
        k: Number of top candidates to classify; clamped to the vocabulary
            size.

    Returns:
        dict with
        - 'class_percentages': for each class, its share (0-100) of the
          top-k probability mass. The values are renormalized over the top-k
          mass, not the full distribution (all zero if that mass is zero).
        - 'top_10': the ten most probable candidates, each a dict with
          'token' (repr of the decoded text), 'class' and 'prob'.
    """
    words = text.split()
    context = ' '.join(words[:target_position + 1])
    inputs = tokenizer(context, return_tensors='pt')

    # Get top-k predictions for the token following the context
    with torch.no_grad():
        outputs = model(**inputs)
        next_token_logits = outputs.logits[0, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)
        # torch.topk raises when k exceeds the number of entries, so clamp it.
        top_k = min(k, probs.shape[-1])
        top_k_probs, top_k_ids = torch.topk(probs, top_k)

    # Convert once to plain Python numbers: decode() then receives ints rather
    # than 0-dim tensors, and each probability is read a single time.
    prob_values = top_k_probs.tolist()
    id_values = top_k_ids.tolist()

    # Classify and accumulate
    class_probs = {'word_start': 0.0, 'punctuation': 0.0, 'fragment': 0.0}
    candidates = []

    for prob, token_id in zip(prob_values, id_values):
        token_str = tokenizer.decode([token_id])
        token_class = classify_token(token_str)

        class_probs[token_class] += prob
        candidates.append({
            'token': repr(token_str),
            'class': token_class,
            'prob': prob
        })

    # Normalize to percentage of the top-k mass
    total = sum(class_probs.values())
    class_pcts = {
        name: (mass / total) * 100 if total > 0 else 0.0
        for name, mass in class_probs.items()
    }

    return {
        'class_percentages': class_pcts,
        'top_10': candidates[:10]
    }

def find_cue_position(text, cue_word):
    """
    Return the word index of the first occurrence of `cue_word` in `text`.

    Words are whitespace-separated. Each word is lower-cased and stripped of
    leading/trailing '.,!?;:' before being compared with the lower-cased cue.
    The final word of the text is never considered, so a returned index
    always has at least one word after it. Returns None when there is no
    match.
    """
    candidates = text.split()[:-1]
    return next(
        (
            idx
            for idx, token in enumerate(candidates)
            if token.lower().strip('.,!?;:') == cue_word.lower()
        ),
        None,
    )

## Step 7: Run Analysis on First 5 Stimuli

In [None]:
# Report header.
print("=" * 80)
print("WORD-CLASS CONSTRAINT AUDIT")
print("=" * 80)
print()
print("Analysis: Predictions after 'the' (determiner)")
print()

# Word-start percentages per condition, read later by the summary cell.
results_by_condition = {
    'sentence': [],
    'jabberwocky_matched': [],
    'scrambled_jabberwocky': []
}

# Inspect at most the first five stimulus sets in detail.
for stim_idx, stim in enumerate(stimuli[:5]):
    print(f"\nSTIMULUS {stim_idx + 1}:")
    print("-" * 80)

    for condition in results_by_condition:
        text = stim[condition]
        cue_pos = find_cue_position(text, 'the')

        # Without a usable cue there is nothing to analyze for this text.
        if cue_pos is None:
            print(f"\n{condition:25s}: (No 'the' found)")
            continue

        result = analyze_position(model, tokenizer, text, cue_pos, k=100)
        class_pcts = result['class_percentages']

        # Store for summary
        results_by_condition[condition].append(class_pcts['word_start'])

        print(f"\n{condition.upper()}:")
        print(f"  Text: {text[:60]}...")
        print("  Class distribution:")
        # Labels are left-padded to 15 columns so the values line up.
        for label, class_name in (
            ("Word-start:", 'word_start'),
            ("Punctuation:", 'punctuation'),
            ("Fragments:", 'fragment'),
        ):
            print(f"    {label:<15s}{class_pcts[class_name]:5.1f}%")
        print("  Top-10 predictions:")
        for cand in result['top_10']:
            print(f"    {cand['token']:20s} [{cand['class']:12s}] {cand['prob']*100:5.1f}%")

## Step 8: Summary Statistics

In [None]:
print("\n" + "="*80)
print("SUMMARY: % Probability on Word-Initial Tokens (after 'the')")
print("="*80)
print()

# Mean, standard deviation (np.std default, i.e. population std) and count of
# the word-start percentages collected per condition. Conditions without any
# result are left out of summary_stats.
summary_stats = {}
for condition in ['sentence', 'jabberwocky_matched', 'scrambled_jabberwocky']:
    pcts = results_by_condition[condition]
    if pcts:
        mean_pct = np.mean(pcts)
        std_pct = np.std(pcts)
        summary_stats[condition] = {'mean': mean_pct, 'std': std_pct, 'n': len(pcts)}
        print(f"{condition:25s}: {mean_pct:5.1f}% ± {std_pct:4.1f}%  (n={len(pcts)})")

print()
print("="*80)
print("INTERPRETATION:")
print("="*80)
print()

jab_stats = summary_stats.get('jabberwocky_matched')
scr_stats = summary_stats.get('scrambled_jabberwocky')

if jab_stats is None or scr_stats is None:
    # Without data for both conditions there is nothing to compare.
    # Substituting 0 for a missing mean would yield a large spurious Δ and a
    # false "structure constrains" verdict.
    jab_mean = scr_mean = delta = None
    missing = [
        name
        for name, cond_stats in (
            ('jabberwocky_matched', jab_stats),
            ('scrambled_jabberwocky', scr_stats),
        )
        if cond_stats is None
    ]
    print(f"✗ Cannot compute Δ: no results for {', '.join(missing)}.")
    print("  No comparison is possible for the stimuli analyzed so far.")
else:
    jab_mean = jab_stats['mean']
    scr_mean = scr_stats['mean']
    # Difference of two percentages, i.e. expressed in percentage points.
    delta = jab_mean - scr_mean

    print(f"Δ (Jabberwocky - Scrambled): {delta:+.1f}%")
    print()

    if delta > 5:
        print("✓ STRUCTURE CONSTRAINS DISTRIBUTION")
        print()
        print("  Jabberwocky shows higher word-start probability, demonstrating that")
        print("  syntactic structure (even with nonsense words) narrows the continuation")
        print("  space toward word-initial positions.")
        print()
        print("  This confirms category-level constraint beyond lexical co-occurrence.")
    else:
        print("✗ Weak effect: Similar word-start probabilities across conditions.")
        print("  Structure may not be strongly constraining the distribution.")

print("\n" + "="*80)

## Step 9: Full Analysis (All 30 Stimuli)

In [None]:
print("Running full analysis on all stimuli...")
print("(This may take 2-3 minutes)\n")

# One list of per-stimulus rows for each condition.
full_results = {
    'sentence': [],
    'jabberwocky_matched': [],
    'scrambled_jabberwocky': []
}

for stim_idx, stim in enumerate(stimuli):
    # Progress line after every fifth stimulus.
    if (stim_idx + 1) % 5 == 0:
        print(f"  Processing stimulus {stim_idx + 1}/{len(stimuli)}...")

    for condition in full_results:
        text = stim[condition]
        cue_pos = find_cue_position(text, 'the')

        # Texts without a usable cue contribute no row for this condition.
        if cue_pos is None:
            continue

        result = analyze_position(model, tokenizer, text, cue_pos, k=100)
        class_pcts = result['class_percentages']
        full_results[condition].append({
            'stimulus_idx': stim_idx,
            'word_start_pct': class_pcts['word_start'],
            'punctuation_pct': class_pcts['punctuation'],
            'fragment_pct': class_pcts['fragment']
        })

print("\n✓ Analysis complete!\n")

# Final summary: per-condition mean ± std of the word-start percentage,
# followed by a Jabberwocky-vs-scrambled comparison.
print("="*80)
print("FINAL SUMMARY (All Stimuli)")
print("="*80)
print()

for condition in ['sentence', 'jabberwocky_matched', 'scrambled_jabberwocky']:
    pcts = [r['word_start_pct'] for r in full_results[condition]]
    if pcts:
        print(f"{condition:25s}: {np.mean(pcts):5.1f}% ± {np.std(pcts):4.1f}%  (n={len(pcts)})")

print()
jab_full = [r['word_start_pct'] for r in full_results['jabberwocky_matched']]
scr_full = [r['word_start_pct'] for r in full_results['scrambled_jabberwocky']]

if jab_full and scr_full:
    delta_full = np.mean(jab_full) - np.mean(scr_full)
    print(f"Δ (Jabberwocky - Scrambled): {delta_full:+.1f}%")
    print()

    # Statistical test
    # NOTE(review): this is an independent-samples t-test, but both conditions
    # are read from the same stimulus sets; a paired test may fit the design
    # better - confirm before reporting.
    if len(jab_full) < 2 or len(scr_full) < 2:
        # With a single observation in a group the test statistic is NaN,
        # which must not be reported as "not significant".
        print("t-test: skipped (need at least 2 observations per condition)")
    else:
        from scipy import stats
        t_stat, p_value = stats.ttest_ind(jab_full, scr_full)
        print(f"t-test: t = {t_stat:.3f}, p = {p_value:.4f}")

        if np.isnan(p_value):
            # NaN compares False with every threshold, so handle it first.
            print("\n✗ t-test result is undefined (p is NaN); no significance claim can be made")
        elif p_value < 0.05:
            print("\n✓ Statistically significant difference (p < 0.05)")
        else:
            print("\n✗ Not statistically significant (p >= 0.05)")

print("\n" + "="*80)

## Step 10: Download Results

In [None]:
# Per-condition summary of the word-start percentage. Conditions that
# produced no rows are omitted.
summary = {}
for condition in ['sentence', 'jabberwocky_matched', 'scrambled_jabberwocky']:
    rows = full_results[condition]
    if not rows:
        continue
    word_start_pcts = [row['word_start_pct'] for row in rows]
    summary[condition] = {
        # float() turns NumPy scalars into plain floats that json can encode.
        'mean_word_start_pct': float(np.mean(word_start_pcts)),
        'std_word_start_pct': float(np.std(word_start_pcts)),
        'n': len(rows)
    }

# Save to JSON: analysis settings, raw per-stimulus rows, and the summary.
output = {
    'analysis_type': 'word_class_constraint',
    'model': 'gpt2',
    'cue_word': 'the',
    'top_k': 100,
    'results': full_results,
    'summary': summary
}

with open('word_class_results.json', 'w') as f:
    json.dump(output, f, indent=2)

print("✓ Results saved to word_class_results.json")

# Download the saved file through Colab's `files` helper.
from google.colab import files
files.download('word_class_results.json')
print("✓ Download started!")