# Locked Design Morphosyntax Audit

**Purpose**: Test morphosyntactic constraints across 6 cue families × 6 conditions with integrated context ablation.

**Key Features**:
- 30 dedicated stimuli per cue family (180 total)
- Exactly one cue per sentence, in a controlled position
- Context ablation: k ∈ {1, 2, 4, 8, full}
- FDR-corrected statistical tests
- Publication-ready figures

**Cue Families**:
1. Infinitival TO → VERB
2. Modals → VERB  
3. Determiners → NOUN/ADJ
4. Prepositions → NP_START
5. Auxiliaries → PARTICIPLE
6. Complementizers → CLAUSE_START

## Setup

In [None]:
# Install dependencies
!pip install -q transformers torch pandas matplotlib seaborn scipy tqdm

In [None]:
# Mount Google Drive (optional - for saving results)
# In Colab the results go to Drive so they persist across sessions. Outside
# Colab the google.colab package cannot be imported; instead of failing, the
# cell falls back to a local directory next to the notebook.
import os

try:
    from google.colab import drive
except ImportError:
    drive = None

# Set output directory
if drive is not None:
    drive.mount('/content/drive')
    OUTPUT_DIR = '/content/drive/MyDrive/morphosyntax_locked_results'
else:
    OUTPUT_DIR = os.path.join(os.getcwd(), 'morphosyntax_locked_results')

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Results will be saved to: {OUTPUT_DIR}")

In [None]:
import json
import random
import hashlib
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from tqdm.auto import tqdm
from datetime import datetime
from typing import Dict, List, Set, Tuple, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer

# Check GPU
# `device` is the string later passed to `.to(...)` for the model and its inputs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
if device == 'cuda':
    # Report the name of CUDA device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Stimulus Generation

In [None]:
# Configuration
SEED = 42  # seeds the random.Random instance created in generate_all_stimuli()
N_SENTENCES_PER_FAMILY = 30  # number of stimuli generated per cue family

# Function words (for slot identification)
# identify_slots() treats every word in this set as a "function slot" and
# NonceGenerator refuses to emit any of them as a nonce word. Besides
# closed-class items (determiners, pronouns, prepositions, conjunctions,
# auxiliaries, modals) the set also lists the matrix, attitude and perception
# verbs used in the templates, so those verbs count as function slots too.
# NOTE: 'that' appears twice in the literal; harmless because this is a set.
FUNCTION_WORDS = {
    'the', 'a', 'an', 'this', 'that', 'these', 'those',
    'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them',
    'to', 'at', 'in', 'on', 'with', 'from', 'of', 'for', 'by', 'about',
    'and', 'or', 'but', 'so', 'if', 'because', 'that', 'whether',
    'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'has', 'have', 'had', 'having', 'do', 'does', 'did',
    'will', 'would', 'can', 'could', 'shall', 'should', 'may', 'might', 'must',
    'decided', 'wanted', 'began', 'tried', 'planned', 'continued', 'hoped',
    'said', 'thought', 'believed', 'knew', 'expected',
    'saw', 'heard', 'felt', 'noticed',
}

# Word pools
NOUNS_AGENT = [
    'scientist', 'artist', 'teacher', 'student', 'doctor', 'chef', 'engineer',
    'musician', 'author', 'mechanic', 'programmer', 'architect', 'researcher',
    'photographer', 'detective', 'captain', 'professor', 'coach', 'director',
    'farmer', 'baker', 'pilot', 'surgeon', 'lawyer', 'journalist', 'carpenter',
    'plumber', 'electrician', 'gardener', 'librarian'
]

NOUNS_PATIENT = [
    'artifacts', 'paintings', 'concept', 'assignment', 'landscape', 'patient',
    'symphony', 'novel', 'engine', 'software', 'blueprint', 'data', 'experiment',
    'recipe', 'manuscript', 'equipment', 'specimens', 'documents', 'vehicle',
    'instrument', 'formula', 'strategy', 'findings', 'technique', 'method',
    'discovery', 'solution', 'project', 'design', 'prototype'
]

ADJECTIVES = [
    'ancient', 'beautiful', 'difficult', 'important', 'complex', 'delicious',
    'broken', 'efficient', 'innovative', 'comprehensive', 'valuable', 'intricate',
    'challenging', 'experimental', 'preliminary', 'historical', 'modern', 'rare',
    'elaborate', 'fundamental', 'critical', 'remarkable', 'unusual', 'significant',
    'mysterious', 'fascinating', 'practical', 'theoretical', 'advanced', 'basic'
]

VERBS_BASE = [
    'study', 'examine', 'analyze', 'explore', 'investigate', 'review', 'inspect',
    'evaluate', 'assess', 'test', 'repair', 'build', 'design', 'create', 'develop',
    'improve', 'fix', 'complete', 'finish', 'prepare', 'organize', 'document',
    'present', 'explain', 'discuss', 'publish', 'research', 'demonstrate', 'solve', 'process'
]

VERBS_PAST = [
    'decided', 'wanted', 'began', 'tried', 'planned', 'continued', 'hoped',
    'expected', 'needed', 'agreed', 'promised', 'attempted', 'managed', 'struggled',
    'offered', 'refused', 'learned', 'forgot', 'remembered', 'chose', 'liked',
    'loved', 'preferred', 'wished', 'demanded', 'requested', 'intended', 'meant',
    'prepared', 'started'
]

VERBS_PARTICIPLE_ING = [
    'studying', 'examining', 'analyzing', 'exploring', 'investigating', 'reviewing',
    'inspecting', 'evaluating', 'assessing', 'testing', 'repairing', 'building',
    'designing', 'creating', 'developing', 'improving', 'fixing', 'completing',
    'finishing', 'preparing', 'organizing', 'documenting', 'presenting', 'explaining',
    'discussing', 'publishing', 'researching', 'demonstrating', 'solving', 'processing'
]

VERBS_PARTICIPLE_ED = [
    'studied', 'examined', 'analyzed', 'explored', 'investigated', 'reviewed',
    'inspected', 'evaluated', 'assessed', 'tested', 'repaired', 'built',
    'designed', 'created', 'developed', 'improved', 'fixed', 'completed',
    'finished', 'prepared', 'organized', 'documented', 'presented', 'explained',
    'discussed', 'published', 'researched', 'demonstrated', 'solved', 'processed'
]

PREPOSITIONS_LIST = ['with', 'in', 'on', 'at', 'for', 'about']
MODALS_LIST = ['can', 'will', 'could', 'would', 'should', 'must', 'may', 'might']
AUXILIARIES_BE = ['is', 'was', 'are', 'were']

In [None]:
class NonceGenerator:
    """Generate pronounceable nonce words of the form onset + nucleus + coda.

    Each instance owns a seeded ``random.Random`` and remembers every word it
    has handed out, so no word is returned twice by the same instance.
    """
    ONSETS = ['b', 'bl', 'br', 'c', 'cl', 'cr', 'd', 'dr', 'f', 'fl', 'fr',
              'g', 'gl', 'gr', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'pl', 'pr',
              'qu', 'r', 's', 'sc', 'sk', 'sl', 'sm', 'sn', 'sp', 'st', 'str',
              'sw', 't', 'tr', 'th', 'v', 'w', 'wh', 'y', 'z']
    NUCLEI = ['a', 'e', 'i', 'o', 'u', 'ee', 'oo', 'ea', 'ey']
    CODAS = ['', 'b', 'ck', 'd', 'f', 'g', 'k', 'l', 'll', 'm', 'n', 'ng',
             'nk', 'p', 'r', 's', 'sk', 'sp', 'ss', 'st', 't', 'x', 'z']

    def __init__(self, seed=None):
        """Create a generator; *seed* is passed straight to ``random.Random``."""
        self.rng = random.Random(seed)
        self.used = set()

    def generate(self) -> str:
        """Return a nonce word this instance has not produced before.

        Up to 100 candidates are drawn. A candidate is rejected when it is a
        function word, was already used, or is shorter than 3 characters.
        If every draw is rejected, a ``zx<number>`` placeholder is returned.
        """
        for _ in range(100):
            word = self.rng.choice(self.ONSETS) + self.rng.choice(self.NUCLEI) + self.rng.choice(self.CODAS)
            if word not in FUNCTION_WORDS and word not in self.used and len(word) >= 3:
                self.used.add(word)
                return word

        # Fallback: the first draw matches the previous behaviour; a numeric
        # suffix is appended only if that placeholder was already handed out,
        # so the uniqueness guarantee also holds on this path.
        base = f"zx{self.rng.randint(100, 999)}"
        fallback = base
        suffix = 0
        while fallback in self.used:
            suffix += 1
            fallback = f"{base}{suffix}"
        self.used.add(fallback)
        return fallback

def identify_slots(words):
    """Split a token list into function-word slots and content-word slots.

    A token counts as a function word when its lower-cased form, with
    surrounding ``.,!?;:`` removed, is in FUNCTION_WORDS.

    Returns:
        ``(function_slots, content_slots)``; each is a list of
        ``(index, original_token)`` pairs in sentence order.
    """
    function_slots = []
    content_slots = []
    for position, token in enumerate(words):
        bare = token.lower().strip('.,!?;:')
        bucket = function_slots if bare in FUNCTION_WORDS else content_slots
        bucket.append((position, token))
    return function_slots, content_slots

def full_scramble(sentence, seed):
    """Return *sentence* with all of its whitespace-separated words shuffled.

    The shuffle is driven by ``random.Random(seed)``. It is repeated (at most
    10 times, each pass reshuffling the previous result) until the order
    differs from the input; if it never differs the last order is returned.
    """
    original = sentence.split()
    permuted = list(original)
    shuffler = random.Random(seed)
    attempts = 0
    while attempts < 10:
        shuffler.shuffle(permuted)
        attempts += 1
        if permuted != original:
            break
    return ' '.join(permuted)

def content_scramble(sentence, seed):
    """Permute the content words of *sentence* among the content positions.

    Function-word positions (as classified by ``identify_slots``) are left
    untouched. The permutation comes from ``random.Random(seed)`` and is
    redrawn up to 10 times until it differs from the original order.
    """
    tokens = sentence.split()
    _, content = identify_slots(tokens)
    originals = [tok for _, tok in content]

    shuffler = random.Random(seed)
    permuted = list(originals)
    attempts = 0
    while attempts < 10:
        shuffler.shuffle(permuted)
        attempts += 1
        if permuted != originals:
            break

    # Write the permuted content words back into the content positions.
    scrambled = list(tokens)
    for (slot, _), replacement in zip(content, permuted):
        scrambled[slot] = replacement
    return ' '.join(scrambled)

def function_scramble(sentence, seed):
    """Permute the function words of *sentence* among the function positions.

    Content-word positions (as classified by ``identify_slots``) are left
    untouched. The permutation comes from ``random.Random(seed)`` and is
    redrawn up to 10 times until it differs from the original order.
    """
    tokens = sentence.split()
    function, _ = identify_slots(tokens)
    originals = [tok for _, tok in function]

    shuffler = random.Random(seed)
    permuted = list(originals)
    attempts = 0
    while attempts < 10:
        shuffler.shuffle(permuted)
        attempts += 1
        if permuted != originals:
            break

    # Write the permuted function words back into the function positions.
    scrambled = list(tokens)
    for (slot, _), replacement in zip(function, permuted):
        scrambled[slot] = replacement
    return ' '.join(scrambled)

def cue_deletion(sentence, cue_word, replacement='ke'):
    """Replace the first occurrence of *cue_word* in *sentence*.

    Matching is case-insensitive and ignores ``.,!?;:`` around a token. The
    whole matching token (including any such punctuation) is swapped for
    *replacement*; later occurrences are kept. Words are re-joined with
    single spaces.
    """
    tokens = sentence.split()
    for position, token in enumerate(tokens):
        if token.lower().strip('.,!?;:') == cue_word.lower():
            tokens[position] = replacement
            break  # only the first match is replaced
    return ' '.join(tokens)

In [None]:
def _stable_seed(label: str) -> int:
    """Map *label* to a 31-bit seed that is the same in every Python process.

    The built-in ``hash()`` of a string is salted per interpreter run, so it
    cannot be used to derive reproducible seeds; a SHA-256 digest can.
    """
    digest = hashlib.sha256(label.encode('utf-8')).digest()
    return int.from_bytes(digest[:4], 'big') % (2**31)


def generate_all_stimuli():
    """Generate N_SENTENCES_PER_FAMILY stimuli for each of the 6 cue families.

    With the default configuration this yields 180 stimuli (30 × 6).

    Every stimulus is a dict holding the real ``sentence``, its
    ``jabberwocky`` counterpart (content words replaced by nonce words) and
    four manipulations of the jabberwocky string: ``full_scrambled``,
    ``content_scrambled``, ``function_scrambled`` and ``cue_deleted``.

    All randomness comes from ``random.Random(SEED)`` plus seeds derived from
    the family name and item index via ``_stable_seed``, so repeated runs
    produce the same stimuli.

    Returns:
        list[dict]: one dict per stimulus, grouped by family in the order of
        the ``families`` table below.

    Raises:
        ValueError: if the ``families`` table names a family without a template.
    """
    rng = random.Random(SEED)
    all_stimuli = []

    # (family, default cue or None when the cue varies per item,
    #  0-based word index of the cue in the template, target class label)
    families = [
        ('infinitival_to', 'to', 3, 'VERB'),
        ('modals', None, 2, 'VERB'),  # cue varies
        ('determiners', 'a', 5, 'NOUN_OR_ADJ'),
        ('prepositions', None, 3, 'NP_START'),  # cue varies
        ('auxiliaries', None, 2, 'PARTICIPLE'),  # cue varies
        ('complementizers', 'that', 3, 'CLAUSE_START'),
    ]

    for family, _default_cue, cue_pos, target in families:
        # One nonce generator per family: nonce words are unique within a family.
        nonce_gen = NonceGenerator(seed=_stable_seed(family))

        # Sampling without replacement: each lexical item is used at most once
        # per family.
        agents = rng.sample(NOUNS_AGENT, N_SENTENCES_PER_FAMILY)
        patients = rng.sample(NOUNS_PATIENT, N_SENTENCES_PER_FAMILY)
        adjs = rng.sample(ADJECTIVES, N_SENTENCES_PER_FAMILY)

        for i in range(N_SENTENCES_PER_FAMILY):
            agent = agents[i]
            patient = patients[i]
            adj = adjs[i]

            # Generate template based on family
            if family == 'infinitival_to':
                v_past = rng.choice(VERBS_PAST[:15])
                v_base = rng.choice(VERBS_BASE)
                sentence = f"the {agent} {v_past} to {v_base} the {adj} {patient}"
                n_agent, n_vbase, n_adj, n_patient = [nonce_gen.generate() for _ in range(4)]
                # The matrix verb is kept; only the other content words become nonce.
                jabberwocky = f"the {n_agent} {v_past} to {n_vbase} the {n_adj} {n_patient}"
                cue = 'to'

            elif family == 'modals':
                modal = rng.choice(MODALS_LIST)
                v_base = rng.choice(VERBS_BASE)
                sentence = f"the {agent} {modal} {v_base} the {adj} {patient}"
                n_agent, n_vbase, n_adj, n_patient = [nonce_gen.generate() for _ in range(4)]
                jabberwocky = f"the {n_agent} {modal} {n_vbase} the {n_adj} {n_patient}"
                cue = modal

            elif family == 'determiners':
                v_past = rng.choice(VERBS_PAST)
                sentence = f"the {agent} {v_past} and saw a {adj} {patient}"
                n_agent, n_vpast, n_adj, n_patient = [nonce_gen.generate() for _ in range(4)]
                jabberwocky = f"the {n_agent} {n_vpast} and saw a {n_adj} {n_patient}"
                cue = 'a'

            elif family == 'prepositions':
                v_past = rng.choice(VERBS_PAST)
                prep = rng.choice(PREPOSITIONS_LIST)
                sentence = f"the {agent} {v_past} {prep} the {adj} {patient}"
                n_agent, n_vpast, n_adj, n_patient = [nonce_gen.generate() for _ in range(4)]
                jabberwocky = f"the {n_agent} {n_vpast} {prep} the {n_adj} {n_patient}"
                cue = prep

            elif family == 'auxiliaries':
                aux = rng.choice(AUXILIARIES_BE)
                # Alternate -ing and -ed participles across items.
                if i % 2 == 0:
                    v_part = rng.choice(VERBS_PARTICIPLE_ING)
                else:
                    v_part = rng.choice(VERBS_PARTICIPLE_ED)
                sentence = f"the {agent} {aux} {v_part} the {adj} {patient}"
                n_agent, n_vpart, n_adj, n_patient = [nonce_gen.generate() for _ in range(4)]
                jabberwocky = f"the {n_agent} {aux} {n_vpart} the {n_adj} {n_patient}"
                cue = aux

            elif family == 'complementizers':
                v1 = rng.choice(['said', 'thought', 'believed', 'knew', 'expected', 'hoped'])
                v2 = rng.choice(VERBS_PAST)
                agent2 = rng.choice(NOUNS_AGENT)
                sentence = f"the {agent} {v1} that the {agent2} {v2}"
                n_agent, n_v1, n_agent2, n_v2 = [nonce_gen.generate() for _ in range(4)]
                jabberwocky = f"the {n_agent} {n_v1} that the {n_agent2} {n_v2}"
                cue = 'that'

            else:
                # Previously an unknown family surfaced as a NameError further down.
                raise ValueError(f"No template defined for cue family: {family}")

            # Generate seeds for scrambles (process-independent; the three
            # scramble types get distinct offsets from the same base).
            seed_base = _stable_seed(f"{family}_{i}")

            # Generate all conditions
            stimulus = {
                'set_id': i + 1,
                'cue_family': family,
                'cue_word': cue,
                'cue_position': cue_pos,
                'target_class': target,
                'sentence': sentence,
                'jabberwocky': jabberwocky,
                'full_scrambled': full_scramble(jabberwocky, seed_base + 1),
                'content_scrambled': content_scramble(jabberwocky, seed_base + 2),
                'function_scrambled': function_scramble(jabberwocky, seed_base + 3),
                'cue_deleted': cue_deletion(jabberwocky, cue),
            }
            all_stimuli.append(stimulus)

    return all_stimuli

# Generate stimuli
# `stimuli` is the list of per-item dicts consumed by the audit loop below.
print("Generating stimuli...")
stimuli = generate_all_stimuli()
print(f"Generated {len(stimuli)} stimuli")

# Show examples
# Print the first stimulus of each family as a quick sanity check of the
# templates and of the function-word scramble.
print("\nExamples from each family:")
for family in ['infinitival_to', 'modals', 'determiners', 'prepositions', 'auxiliaries', 'complementizers']:
    s = [x for x in stimuli if x['cue_family'] == family][0]
    print(f"\n{family} (cue='{s['cue_word']}'):")
    print(f"  SENTENCE:     {s['sentence']}")
    print(f"  JABBERWOCKY:  {s['jabberwocky']}")
    print(f"  FUNC_SCRAMB:  {s['function_scrambled']}")

## 2. Target Class Definitions

In [None]:
# Word sets for each target class
VERB_SET = {
    'be', 'have', 'do', 'say', 'go', 'get', 'make', 'know', 'think', 'take',
    'see', 'come', 'want', 'use', 'find', 'give', 'tell', 'work', 'call', 'try',
    'ask', 'need', 'feel', 'become', 'leave', 'put', 'mean', 'keep', 'let', 'begin',
    'seem', 'help', 'show', 'hear', 'play', 'run', 'move', 'live', 'believe',
    'bring', 'happen', 'write', 'sit', 'stand', 'lose', 'pay', 'meet', 'continue',
    'set', 'learn', 'change', 'lead', 'understand', 'watch', 'follow', 'stop', 'create', 'speak',
    'read', 'allow', 'add', 'spend', 'grow', 'open', 'walk', 'win', 'teach', 'offer',
    'remember', 'love', 'consider', 'appear', 'buy', 'serve', 'die', 'send', 'build', 'stay',
    'fall', 'cut', 'reach', 'kill', 'raise', 'pass', 'sell', 'decide', 'return', 'explain',
    'hope', 'develop', 'carry', 'break', 'receive', 'agree', 'support', 'hit', 'produce', 'eat',
    'study', 'research', 'investigate', 'examine', 'analyze', 'explore',
    'paint', 'draw', 'design', 'construct', 'perform', 'practice',
    'publish', 'edit', 'revise', 'prepare', 'cook', 'repair', 'fix',
    'solve', 'calculate', 'improve', 'enhance', 'test', 'validate',
    'organize', 'arrange', 'defend', 'protect', 'film', 'record',
    'sail', 'navigate', 'discuss', 'debate', 'assemble', 'combine',
    'plan', 'schedule', 'finish', 'complete', 'start', 'end',
}

NOUN_SET = {
    'time', 'person', 'year', 'way', 'day', 'thing', 'man', 'world', 'life', 'hand',
    'part', 'child', 'eye', 'woman', 'place', 'work', 'week', 'case', 'point',
    'company', 'number', 'group', 'problem', 'fact', 'house', 'area', 'money', 'story', 'student',
    'word', 'family', 'head', 'water', 'room', 'mother', 'night', 'home', 'side',
    'power', 'hour', 'game', 'line', 'end', 'member', 'law', 'car', 'city',
    'name', 'team', 'minute', 'idea', 'body', 'back', 'parent', 'face',
    'level', 'office', 'door', 'health', 'art', 'war', 'history', 'result', 'change',
    'morning', 'reason', 'research', 'girl', 'guy', 'moment', 'air', 'teacher', 'force',
    'scientist', 'artist', 'musician', 'author', 'chef', 'mechanic', 'programmer',
    'engineer', 'architect', 'researcher', 'doctor', 'patient', 'professor',
    'artifacts', 'paintings', 'symphony', 'novel', 'software', 'blueprint', 'data',
    'recipe', 'manuscript', 'equipment', 'documents', 'vehicle', 'experiment',
}

ADJECTIVE_SET = {
    'new', 'good', 'first', 'last', 'long', 'great', 'little', 'own', 'other', 'old',
    'right', 'big', 'high', 'different', 'small', 'large', 'next', 'early', 'young', 'important',
    'few', 'public', 'bad', 'same', 'able', 'best', 'full', 'simple', 'left', 'late',
    'hard', 'real', 'top', 'whole', 'sure', 'better', 'free', 'special', 'clear', 'recent',
    'beautiful', 'strong', 'certain', 'open', 'red', 'difficult', 'available', 'likely',
    'short', 'single', 'current', 'wrong', 'past', 'fine', 'common', 'poor', 'natural',
    'significant', 'similar', 'hot', 'dead', 'happy', 'serious', 'ready', 'easy', 'effective',
    'ancient', 'complex', 'delicious', 'broken', 'efficient', 'comprehensive', 'historical',
    'rare', 'valuable', 'intricate', 'challenging', 'innovative', 'experimental',
}

PARTICIPLE_SET = {
    'being', 'having', 'doing', 'saying', 'going', 'getting', 'making', 'knowing',
    'studying', 'examining', 'analyzing', 'exploring', 'investigating', 'reviewing',
    'evaluating', 'testing', 'repairing', 'building', 'designing', 'creating',
    'developing', 'improving', 'fixing', 'completing', 'preparing', 'organizing',
    'presenting', 'explaining', 'discussing', 'publishing', 'researching', 'solving',
    'working', 'running', 'walking', 'talking', 'reading', 'writing', 'playing',
    'watching', 'listening', 'thinking', 'looking', 'waiting', 'trying', 'helping',
    'been', 'had', 'done', 'said', 'gone', 'made', 'known', 'thought', 'taken',
    'seen', 'come', 'found', 'given', 'told', 'worked', 'called', 'tried',
    'studied', 'examined', 'analyzed', 'explored', 'investigated', 'reviewed',
    'evaluated', 'tested', 'repaired', 'built', 'designed', 'created',
    'developed', 'improved', 'fixed', 'completed', 'prepared', 'organized',
    'presented', 'explained', 'discussed', 'published', 'researched', 'solved',
    'broken', 'written', 'shown', 'chosen', 'spoken', 'frozen',
}

# Words that can open a noun phrase: every noun and adjective listed above
# plus determiners, possessives, quantifiers and a few numerals.
NP_START_SET = NOUN_SET | ADJECTIVE_SET | {
    'the', 'a', 'an', 'this', 'that', 'these', 'those', 'my', 'your', 'his', 'her',
    'its', 'our', 'their', 'some', 'any', 'each', 'every', 'all', 'both', 'many',
    'few', 'several', 'other', 'another', 'no', 'one', 'two', 'three',
}

# Words that can open an embedded clause: pronouns, demonstratives, wh-words,
# determiners/quantifiers, plus every noun in NOUN_SET.
# ('her' is listed twice in the literal; harmless because this is a set.)
CLAUSE_START_SET = {
    'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them',
    'this', 'that', 'these', 'those', 'who', 'what', 'which', 'where', 'when', 'why', 'how',
    'the', 'a', 'an', 'my', 'your', 'his', 'her', 'its', 'our', 'their',
    'some', 'any', 'all', 'both', 'each', 'every', 'no', 'everyone', 'someone',
} | NOUN_SET

# Per-family scoring configuration.
#   'primary'   - label of the target class for the family
#   'word_sets' - class name -> set of words; the audit loop passes this dict
#                 to WordLevelAnalyzer.compute_class_mass and sums the masses.
TARGET_CLASSES = {
    'infinitival_to': {'primary': 'VERB', 'word_sets': {'VERB': VERB_SET}},
    'modals': {'primary': 'VERB', 'word_sets': {'VERB': VERB_SET}},
    'determiners': {'primary': 'NOUN_OR_ADJ', 'word_sets': {'NOUN': NOUN_SET, 'ADJ': ADJECTIVE_SET}},
    'prepositions': {'primary': 'NP_START', 'word_sets': {'NP_START': NP_START_SET}},
    'auxiliaries': {'primary': 'PARTICIPLE', 'word_sets': {'PARTICIPLE': PARTICIPLE_SET}},
    'complementizers': {'primary': 'CLAUSE_START', 'word_sets': {'CLAUSE_START': CLAUSE_START_SET}},
}

print("Target classes defined:")
for family, config in TARGET_CLASSES.items():
    # Sum of the set sizes; a word present in two sets of a family counts twice.
    n_words = sum(len(ws) for ws in config['word_sets'].values())
    print(f"  {family}: {config['primary']} ({n_words} words)")

## 3. Word-Level Analyzer

In [None]:
class WordLevelAnalyzer:
    """Score next-token distributions at the word level.

    Only tokens that begin a new word (decoded text starting with a space or
    a newline) are considered, which keeps word-internal BPE pieces out of
    the class-mass totals. Decoding results are memoised per token id.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # token_id -> (is_word_start, normalised word or None)
        self._cache = {}

    def _describe(self, token_id: int):
        """Return the cached ``(is_word_start, word)`` pair, decoding on first use."""
        entry = self._cache.get(token_id)
        if entry is not None:
            return entry

        text = self.tokenizer.decode([token_id])
        if text in ('<|endoftext|>', '<unk>', '<pad>', ''):
            # Special or empty tokens never count as a word start.
            entry = (False, None)
        elif text.startswith((' ', '\n')):
            # Normalise: trim whitespace, lower-case, drop surrounding punctuation.
            entry = (True, text.strip().lower().strip('.,!?;:"\''))
        else:
            entry = (False, None)

        self._cache[token_id] = entry
        return entry

    def is_word_start_token(self, token_id: int) -> bool:
        """Check if token is word-start (space-prefixed in GPT-2/Pythia)."""
        return self._describe(token_id)[0]

    def get_word_from_token(self, token_id: int) -> Optional[str]:
        """Return the normalised word for a word-start token, else ``None``."""
        return self._describe(token_id)[1]

    def compute_class_mass(self, probs: torch.Tensor, word_sets: Dict[str, Set[str]], top_k: int = 1000) -> Dict[str, float]:
        """Sum the probability of the top-k tokens whose word falls in each class.

        Args:
            probs: 1-D tensor of next-token probabilities.
            word_sets: class name -> set of lower-case words.
            top_k: number of highest-probability tokens inspected (capped at
                ``len(probs)``).

        Returns:
            class name -> accumulated probability mass. A word that belongs
            to several classes contributes to each of them.
        """
        limit = min(top_k, len(probs))
        values, indices = torch.topk(probs, limit)
        totals = {label: 0.0 for label in word_sets}

        for mass, token_id in zip(values.tolist(), indices.tolist()):
            word = self.get_word_from_token(token_id)
            if word is None:
                continue
            for label, vocabulary in word_sets.items():
                if word in vocabulary:
                    totals[label] += mass

        return totals

def truncate_context(text: str, cue_position: int, k: int) -> str:
    """Return the last *k* words of *text* ending at the cue position.

    The window ends at word index ``cue_position`` (inclusive). When *k*
    reaches past the start of the text, the whole prefix up to the cue is
    returned.
    """
    end = cue_position + 1
    begin = max(0, end - k)  # clamps to 0 once k covers the whole prefix
    return ' '.join(text.split()[begin:end])

## 4. Run Audit

In [None]:
# Select model
# MODEL_NAME is also used later for figure titles and output file names.
MODEL_NAME = 'gpt2'  # Options: 'gpt2', 'gpt2-medium', 'gpt2-large', 'EleutherAI/pythia-410m'

print(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()  # inference mode
model = model.to(device)  # `device` was chosen in the setup cell ('cuda' or 'cpu')
print("Model loaded!")

In [None]:
# Run audit with context ablation
# For every stimulus × condition × context length, feed the words up to the
# cue position to the model and record how much next-token probability falls
# on the family's target word sets.
CONDITIONS = ['sentence', 'jabberwocky', 'full_scrambled', 'content_scrambled', 'function_scrambled', 'cue_deleted']
CONTEXT_LENGTHS = [1, 2, 4, 8, -1]  # -1 = full
TOP_K = 1000  # number of highest-probability tokens inspected per prediction

analyzer = WordLevelAnalyzer(tokenizer)
results = []

total = len(stimuli) * len(CONDITIONS) * len(CONTEXT_LENGTHS)
print(f"Running audit ({total} iterations)...")

with tqdm(total=total) as pbar:
    for stim in stimuli:
        family = stim['cue_family']
        cue_pos = stim['cue_position']
        word_sets = TARGET_CLASSES[family]['word_sets']

        for cond in CONDITIONS:
            text = stim[cond]

            for k in CONTEXT_LENGTHS:
                # Get context
                # The context is always cut by word position, so in the
                # scrambled conditions the word at cue_pos need not be the cue.
                # NOTE(review): truncate_context returns the full prefix for
                # any k >= cue_pos + 1; with the template cue positions
                # (2, 3, 5) k=8 therefore gives the same context as 'full'.
                if k == -1:
                    context = ' '.join(text.split()[:cue_pos + 1])
                    k_label = 'full'
                else:
                    context = truncate_context(text, cue_pos, k)
                    k_label = str(k)

                # Get predictions
                inputs = tokenizer(context, return_tensors='pt').to(device)
                with torch.no_grad():
                    outputs = model(**inputs)

                # Distribution over the token following the last context token.
                logits = outputs.logits[0, -1, :]
                probs = torch.softmax(logits, dim=-1).cpu()

                # Compute class mass
                class_mass = analyzer.compute_class_mass(probs, word_sets, top_k=TOP_K)
                # Sum over the family's classes; a word listed in two of the
                # family's sets would be counted once per set.
                target_mass = sum(class_mass.values())

                results.append({
                    'set_id': stim['set_id'],
                    'cue_family': family,
                    'cue_word': stim['cue_word'],
                    'condition': cond.upper(),  # upper-case labels are used by the stats and figure cells
                    'context_k': k_label,  # stored as a string: '1', '2', '4', '8' or 'full'
                    'target_mass': target_mass,
                    'class_mass': class_mass,
                })

                pbar.update(1)

print(f"\nAudit complete! {len(results)} results")

In [None]:
# Convert to DataFrame
# One row per (stimulus, condition, context length) record from the audit loop.
df = pd.DataFrame(results)

# Show summary
# Mean target mass per family (rows) and condition (columns), full context only.
print("Target mass by family × condition (k=full):")
summary = df[df['context_k'] == 'full'].pivot_table(
    values='target_mass', 
    index='cue_family', 
    columns='condition',
    aggfunc='mean'
)
print(summary.round(3))

## 5. Statistical Analysis

In [None]:
def fdr_correction(p_values, alpha=0.05):
    """Benjamini-Hochberg FDR correction.

    Args:
        p_values: 1-D sequence of raw p-values (list, tuple or ndarray).
        alpha: FDR level used for the significance decision.

    Returns:
        ``(adjusted, significant)``: an array of BH-adjusted p-values in the
        input order (capped at 1.0) and a boolean array ``adjusted < alpha``.
        Both are empty when no p-values are given.
    """
    p = np.asarray(p_values, dtype=float)
    n = p.size
    if n == 0:
        return np.zeros(0), np.zeros(0, dtype=bool)

    order = np.argsort(p)
    # Raw BH value for the p-value of rank r (1-based): p * n / r.
    scaled = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity from the largest rank downwards. fmin (not
    # minimum) keeps a NaN p-value, which sorts last, from overwriting the
    # valid entries before it.
    monotone = np.fmin.accumulate(scaled[::-1])[::-1]

    adjusted = np.empty(n)
    adjusted[order] = np.minimum(monotone, 1.0)
    return adjusted, adjusted < alpha

def compute_contrasts(df, context_k='full'):
    """Compute key paired contrasts.

    For each cue family, every contrast pair is tested with a paired t-test
    on ``target_mass``, pairing observations by ``set_id``. The p-values of
    all tests are then FDR-corrected together.

    Args:
        df: audit results with columns ``cue_family``, ``condition``,
            ``context_k``, ``set_id`` and ``target_mass``.
        context_k: context length to analyse. Compared as a string, so both
            ``4`` and ``'4'`` select the rows labelled ``'4'``.

    Returns:
        DataFrame with one row per (family, contrast) and the columns
        ``cue_family, contrast, mean_a, mean_b, diff, t_stat, p_value,
        cohens_d, n, p_adjusted, significant``. The columns are present even
        when no contrast could be computed.
    """
    columns = [
        'cue_family', 'contrast', 'mean_a', 'mean_b', 'diff', 't_stat',
        'p_value', 'cohens_d', 'n', 'p_adjusted', 'significant',
    ]

    # String comparison on both sides so an int argument is not silently
    # matched against string labels (which would select nothing).
    df_k = df[df['context_k'].astype(str) == str(context_k)]

    contrasts = [
        ('SENTENCE', 'JABBERWOCKY'),
        ('JABBERWOCKY', 'FULL_SCRAMBLED'),
        ('JABBERWOCKY', 'CONTENT_SCRAMBLED'),
        ('JABBERWOCKY', 'FUNCTION_SCRAMBLED'),
        ('JABBERWOCKY', 'CUE_DELETED'),
    ]

    results = []
    for family in df_k['cue_family'].unique():
        df_fam = df_k[df_k['cue_family'] == family]

        for cond_a, cond_b in contrasts:
            df_a = df_fam[df_fam['condition'] == cond_a].set_index('set_id')['target_mass']
            df_b = df_fam[df_fam['condition'] == cond_b].set_index('set_id')['target_mass']
            # Only stimuli present in both conditions can be paired.
            common = df_a.index.intersection(df_b.index)

            if len(common) == 0:
                continue

            x, y = df_a.loc[common].values, df_b.loc[common].values
            diff = np.mean(x) - np.mean(y)
            t_stat, p_val = stats.ttest_rel(x, y)

            # Cohen's d for paired data: mean difference over the SD of the
            # differences; defined as 0 when that SD is not positive.
            deltas = x - y
            spread = np.std(deltas, ddof=1)
            d = np.mean(deltas) / spread if spread > 0 else 0

            results.append({
                'cue_family': family,
                'contrast': f"{cond_a} - {cond_b}",
                'mean_a': np.mean(x),
                'mean_b': np.mean(y),
                'diff': diff,
                't_stat': t_stat,
                'p_value': p_val,
                'cohens_d': d,
                'n': len(common),
            })

    if not results:
        # Keep the schema stable so callers can select columns unconditionally.
        return pd.DataFrame(columns=columns)

    df_results = pd.DataFrame(results)
    p_adj, sig = fdr_correction(df_results['p_value'].values)
    df_results['p_adjusted'] = p_adj
    df_results['significant'] = sig

    return df_results

# Compute contrasts
# Uses the default context_k='full'; p-values are FDR-corrected across all
# family × contrast tests returned by compute_contrasts.
contrasts_df = compute_contrasts(df)
print("Key contrasts (FDR-corrected):")
print(contrasts_df[['cue_family', 'contrast', 'diff', 'p_adjusted', 'cohens_d', 'significant']].round(4).to_string(index=False))

## 6. Generate Figures

In [None]:
# Figure settings
plt.rcParams.update({'font.size': 10, 'figure.dpi': 150})

# Colour per condition; keys match the upper-case labels in df['condition'].
CONDITION_COLORS = {
    'SENTENCE': '#2ecc71',
    'JABBERWOCKY': '#3498db',
    'FULL_SCRAMBLED': '#e74c3c',
    'CONTENT_SCRAMBLED': '#f39c12',
    'FUNCTION_SCRAMBLED': '#9b59b6',
    'CUE_DELETED': '#95a5a6',
}

# Plotting order of conditions (bars) and families (subplots).
CONDITION_ORDER = ['SENTENCE', 'JABBERWOCKY', 'FULL_SCRAMBLED', 'CONTENT_SCRAMBLED', 'FUNCTION_SCRAMBLED', 'CUE_DELETED']
FAMILY_ORDER = ['infinitival_to', 'modals', 'determiners', 'prepositions', 'auxiliaries', 'complementizers']
# Display titles for the family subplots.
FAMILY_LABELS = {
    'infinitival_to': 'Infinitival TO',
    'modals': 'Modals',
    'determiners': 'Determiners',
    'prepositions': 'Prepositions',
    'auxiliaries': 'Auxiliaries',
    'complementizers': 'Complementizers',
}

In [None]:
# Figure 1: Slot Constraint by Condition
# One bar chart per family: mean target mass per condition at full context,
# with standard-error bars.
df_full = df[df['context_k'] == 'full']  # also reused by the Figure 2 cell

fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

for idx, family in enumerate(FAMILY_ORDER):
    ax = axes[idx]
    df_fam = df_full[df_full['cue_family'] == family]

    # NOTE: this rebinds `summary`, which the DataFrame cell used for the pivot table.
    summary = df_fam.groupby('condition')['target_mass'].agg(['mean', 'std', 'count'])
    summary['se'] = summary['std'] / np.sqrt(summary['count'])  # standard error of the mean

    x = np.arange(len(CONDITION_ORDER))
    # Conditions missing from the data are drawn as zero-height bars.
    means = [summary.loc[c, 'mean'] if c in summary.index else 0 for c in CONDITION_ORDER]
    ses = [summary.loc[c, 'se'] if c in summary.index else 0 for c in CONDITION_ORDER]
    colors = [CONDITION_COLORS[c] for c in CONDITION_ORDER]

    ax.bar(x, means, yerr=ses, capsize=3, color=colors, alpha=0.8, edgecolor='white')
    ax.set_title(FAMILY_LABELS[family], fontweight='bold')
    ax.set_xticks(x)
    # Short tick labels, in the same order as CONDITION_ORDER.
    ax.set_xticklabels(['Sent', 'Jab', 'Full', 'Cont', 'Func', 'Cue'], rotation=45, ha='right')
    ax.set_ylabel('Target Class Mass')
    ax.set_ylim(0, None)

plt.suptitle(f'Morphosyntactic Slot Constraint ({MODEL_NAME})', fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/figure1_slot_constraint.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Figure 2: Paired Differences
# One violin plot per family showing the per-stimulus difference in target
# mass between JABBERWOCKY and each scrambled condition (full context).
# Requires `df_full` from the Figure 1 cell.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

# (condition A, condition B, tick label); the plotted value is A - B.
contrasts_to_plot = [
    ('JABBERWOCKY', 'FULL_SCRAMBLED', 'JAB-Full'),
    ('JABBERWOCKY', 'CONTENT_SCRAMBLED', 'JAB-Cont'),
    ('JABBERWOCKY', 'FUNCTION_SCRAMBLED', 'JAB-Func'),
]

for idx, family in enumerate(FAMILY_ORDER):
    ax = axes[idx]
    df_fam = df_full[df_full['cue_family'] == family]

    diff_data = []
    labels = []

    for cond_a, cond_b, label in contrasts_to_plot:
        df_a = df_fam[df_fam['condition'] == cond_a].set_index('set_id')['target_mass']
        df_b = df_fam[df_fam['condition'] == cond_b].set_index('set_id')['target_mass']
        # Pair observations by set_id; skip the contrast if nothing pairs up.
        common = df_a.index.intersection(df_b.index)
        if len(common) > 0:
            diffs = (df_a.loc[common] - df_b.loc[common]).values
            diff_data.append(diffs)
            labels.append(label)

    if diff_data:
        parts = ax.violinplot(diff_data, showmeans=True)
        for i, pc in enumerate(parts['bodies']):
            # Colours are picked by position in diff_data, so they shift if
            # an earlier contrast was skipped.
            pc.set_facecolor(['#3498db', '#f39c12', '#9b59b6'][i])
            pc.set_alpha(0.7)
        ax.axhline(y=0, color='red', linestyle='--', linewidth=1.5)  # zero-difference reference line
        ax.set_xticks(range(1, len(labels) + 1))
        ax.set_xticklabels(labels)

    ax.set_title(FAMILY_LABELS[family], fontweight='bold')
    ax.set_ylabel('Difference')

plt.suptitle('Paired Differences (Jabberwocky vs Scrambled)', fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/figure2_paired_differences.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Figure 3: Context Ablation
# One panel per family: mean target mass (± standard error) as a function of
# context length, with one line per condition in `ablation_conds`.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

k_order = ['1', '2', '4', '8', 'full']
# 'full' has no numeric length; 16 is only its x position on the log2 axis.
k_numeric = {'1': 1, '2': 2, '4': 4, '8': 8, 'full': 16}
ablation_conds = ['SENTENCE', 'JABBERWOCKY', 'FULL_SCRAMBLED', 'FUNCTION_SCRAMBLED']

for idx, family in enumerate(FAMILY_ORDER):
    ax = axes[idx]
    df_fam = df[df['cue_family'] == family]

    for cond in ablation_conds:
        df_cond = df_fam[df_fam['condition'] == cond]
        x_vals, y_vals, y_errs = [], [], []

        for k in k_order:
            df_k = df_cond[df_cond['context_k'] == k]
            if len(df_k) > 0:
                x_vals.append(k_numeric[k])
                y_vals.append(df_k['target_mass'].mean())
                y_errs.append(df_k['target_mass'].std() / np.sqrt(len(df_k)))  # standard error

        if x_vals:
            # Only the first panel gets a non-empty label for each line.
            ax.errorbar(x_vals, y_vals, yerr=y_errs, marker='o',
                       label=cond.replace('_', ' ').title() if idx == 0 else '',
                       color=CONDITION_COLORS[cond], capsize=3, linewidth=2)

    ax.set_title(FAMILY_LABELS[family], fontweight='bold')
    ax.set_xlabel('Context Length (k)')
    ax.set_ylabel('Target Class Mass')
    ax.set_xscale('log', base=2)
    ax.set_xticks([1, 2, 4, 8, 16])
    ax.set_xticklabels(['1', '2', '4', '8', 'full'])

fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.02), ncol=4)
plt.suptitle('Context Ablation', fontweight='bold', y=1.08)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/figure3_context_ablation.png', dpi=300, bbox_inches='tight')
plt.show()

## 7. Save Results

In [None]:
# ============================================================
# SAVE ALL RESULTS TO CSV AND JSON
# ============================================================
# Writes nine files into OUTPUT_DIR. The bracketed numbers in the step
# comments below match the indices printed next to each saved path.
import glob

model_slug = MODEL_NAME.replace('/', '_')
# Common stem shared by every model-specific output file.
audit_prefix = f'{OUTPUT_DIR}/locked_audit_{model_slug}'
banner = "=" * 60

print(f"Saving results for model: {MODEL_NAME}")
print(f"Output directory: {OUTPUT_DIR}")
print(banner)

# [1] Raw results as JSON (keeps each record's nested class_mass dict)
results_json_file = f'{audit_prefix}_raw.json'
raw_payload = {
    'metadata': {
        'model': MODEL_NAME,
        'timestamp': datetime.now().isoformat(),
        'n_stimuli': len(stimuli),
        'n_results': len(results),
        'conditions': CONDITIONS,
        'context_lengths': CONTEXT_LENGTHS,
    },
    'results': results,
}
with open(results_json_file, 'w') as fh:
    json.dump(raw_payload, fh, indent=2)
print(f"[1] Raw results (JSON): {results_json_file}")

# [2] Raw results as CSV: one row per observation, with every class_mass
#     entry spread into its own `mass_<class>` column.
id_columns = ('set_id', 'cue_family', 'cue_word', 'condition', 'context_k', 'target_mass')
results_flat = [
    {
        **{column: record[column] for column in id_columns},
        **{f'mass_{class_name}': mass for class_name, mass in record['class_mass'].items()},
    }
    for record in results
]

df_flat = pd.DataFrame(results_flat)
results_csv_file = f'{audit_prefix}_raw.csv'
df_flat.to_csv(results_csv_file, index=False)
n_rows, n_cols = df_flat.shape
print(f"[2] Raw results (CSV):  {results_csv_file}")
print(f"    Shape: {n_rows} rows × {n_cols} columns")

# [3] Summary table: mean target_mass by family × condition at k=full
df_full = df[df['context_k'] == 'full']
summary_pivot = df_full.pivot_table(
    values='target_mass',
    index='cue_family',
    columns='condition',
    aggfunc='mean',
)
summary_csv_file = f'{audit_prefix}_summary.csv'
summary_pivot.to_csv(summary_csv_file)
print(f"[3] Summary table (CSV): {summary_csv_file}")

# [4] Same k=full cells in long form, with mean, std, count and standard error
full_groups = df_full.groupby(['cue_family', 'condition'])['target_mass']
summary_detailed = full_groups.agg(['mean', 'std', 'count']).reset_index()
summary_detailed['se'] = summary_detailed['std'] / np.sqrt(summary_detailed['count'])
summary_detailed_file = f'{audit_prefix}_summary_detailed.csv'
summary_detailed.to_csv(summary_detailed_file, index=False)
print(f"[4] Summary detailed (CSV): {summary_detailed_file}")

# [5] Statistical contrasts
contrasts_csv_file = f'{audit_prefix}_contrasts.csv'
contrasts_df.to_csv(contrasts_csv_file, index=False)
print(f"[5] Statistical contrasts (CSV): {contrasts_csv_file}")

# [6] Context ablation summary: JABBERWOCKY mean target_mass by context_k
col_order = ['1', '2', '4', '8', 'full']
jabberwocky_rows = df[df['condition'] == 'JABBERWOCKY']
ablation_jab = jabberwocky_rows.pivot_table(
    values='target_mass',
    index='cue_family',
    columns='context_k',
    aggfunc='mean',
)
# Put the context lengths that are present into their intended order.
jab_columns = [c for c in col_order if c in ablation_jab.columns]
ablation_jab = ablation_jab[jab_columns]
ablation_csv_file = f'{audit_prefix}_ablation.csv'
ablation_jab.to_csv(ablation_csv_file)
print(f"[6] Context ablation (CSV): {ablation_csv_file}")

# [7] Context ablation for every condition
ablation_all = df.pivot_table(
    values='target_mass',
    index=['cue_family', 'condition'],
    columns='context_k',
    aggfunc='mean',
)
all_columns = [c for c in col_order if c in ablation_all.columns]
ablation_all = ablation_all[all_columns]
ablation_all_file = f'{audit_prefix}_ablation_all_conditions.csv'
ablation_all.to_csv(ablation_all_file)
print(f"[7] Ablation all conditions (CSV): {ablation_all_file}")

# [8] Stimuli as JSON
stimuli_json_file = f'{OUTPUT_DIR}/stimuli_locked.json'
with open(stimuli_json_file, 'w') as fh:
    json.dump(stimuli, fh, indent=2)
print(f"[8] Stimuli (JSON): {stimuli_json_file}")

# [9] Stimuli as CSV
stimuli_df = pd.DataFrame(stimuli)
stimuli_csv_file = f'{OUTPUT_DIR}/stimuli_locked.csv'
stimuli_df.to_csv(stimuli_csv_file, index=False)
print(f"[9] Stimuli (CSV): {stimuli_csv_file}")

print()
print(banner)
print("ALL FILES SAVED SUCCESSFULLY")
print(banner)
print()
print("Files in output directory:")
# List everything now in the output directory, file names only.
for saved_path in sorted(glob.glob(f'{OUTPUT_DIR}/*')):
    print(f"  {saved_path.rsplit('/', 1)[-1]}")

## 8. Interpretation Guide

In [None]:
# Print a quick-reference guide to the four condition contrasts and the
# ordering of conditions expected under "skeleton sufficiency".
rule = "=" * 60
guide_lines = [
    rule,
    "INTERPRETATION GUIDE",
    rule,
    "",
    "Key contrasts to examine:",
    "",
    "1. JABBERWOCKY vs FULL_SCRAMBLED",
    "   If JAB >> FULL: Structure matters for morphosyntax",
    "",
    "2. JABBERWOCKY vs FUNCTION_SCRAMBLED",
    "   If JAB >> FUNC_S: Function-word skeleton is NECESSARY",
    "",
    "3. JABBERWOCKY vs CONTENT_SCRAMBLED",
    "   If JAB ≈ CONT_S (p > 0.05): Content order doesn't matter",
    "   → Skeleton is SUFFICIENT",
    "",
    "4. SENTENCE vs JABBERWOCKY",
    "   If SENT ≈ JAB: Nonce substitution doesn't hurt",
    "",
    rule,
    "Expected pattern for 'skeleton sufficiency':",
    "  SENTENCE ≈ JABBERWOCKY ≈ CONTENT_SCRAMBLED >> FUNCTION_SCRAMBLED ≈ FULL_SCRAMBLED",
    rule,
]
# One print of the joined lines emits the same text as printing each line in turn.
print("\n".join(guide_lines))