In [85]:
import pandas as pd
import numpy as np

In [86]:
# Tab-separated story tokens; compute_surprisal_with_chunking reads their word/zone/item columns.
stories = pd.read_csv(filepath_or_buffer='/scratch2/mrenaudin/Hard-CBR-RNN/all_stories.tok', sep='\t')

In [None]:
# Display the loaded story table.
stories

## Predict surprisal

### Load model

In [88]:
from original_cbr import CBR_RNN
# Final (epoch 49) checkpoint of the job_007 run and the hyperparameter file stored next to it.
checkpoint, hparams_path = (
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/hparams.yaml',
)

In [89]:
import yaml
import torch
with open(hparams_path) as hparams_file:
    # safe_load refuses arbitrary Python object tags; a hyperparameter file only needs plain data.
    hparams = yaml.safe_load(hparams_file)

In [None]:
# Architecture hyperparameters saved with the run; optional ones fall back to these defaults.
ntoken = hparams['vocab_size'] + 1     # vocabulary size (+1 as in training; presumably a reserved id — TODO confirm)
ninp = hparams['ninp']                 # embedding dimension
nhid = hparams['nhid']                 # hidden dimension
nlayers = hparams.get('nlayers', 1)    # number of layers
nheads = hparams.get('nheads', 1)      # number of attention heads
dropout = hparams.get('dropout', 0.5)  # dropout rate

# Initialize model
model = CBR_RNN(
    ntoken=ntoken,
    ninp=ninp,
    nhid=nhid,
    nlayers=nlayers,
    nheads=nheads,
    dropout=dropout,
)

# Load into a separate variable so `checkpoint` keeps holding the path and the
# cell can be re-run without torch.load receiving a dict.
checkpoint_data = torch.load(checkpoint, map_location='cpu')

# Lightning checkpoints nest the weights under 'state_dict' with keys prefixed 'model.'.
if 'state_dict' in checkpoint_data:
    prefix = 'model.'
    # Strip the prefix only at the start of each key; str.replace would also
    # rewrite any 'model.' that occurs deeper inside a parameter name.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint_data['state_dict'].items()
    }
else:
    state_dict = checkpoint_data

model.load_state_dict(state_dict)
model.eval()

## New method: load the model from its Lightning checkpoint

In [130]:
from original_cbr_lightning import CBRLanguageModel
from grid_search import WordTokenizer
import torch

In [159]:
# Final (epoch 49) checkpoint of the job_004 run, loaded below through the Lightning wrapper.
checkpoint = '/scratch2/mrenaudin/Hard-CBR-RNN/job_004/lightning_logs/version_1198535/checkpoints/epoch=49-step=565950.ckpt'


In [160]:
def load_trained_model(checkpoint_path):
    """Restore a CBRLanguageModel from a Lightning checkpoint for inference.

    Args:
        checkpoint_path: path to the ``.ckpt`` file written during training.

    Returns:
        The restored LightningModule, switched to evaluation mode.
    """
    lightning_module = CBRLanguageModel.load_from_checkpoint(checkpoint_path)
    lightning_module.eval()  # inference mode: disables dropout and other train-only behavior
    return lightning_module

In [161]:
# Load the job_004 checkpoint selected above; this rebinds `model` to the Lightning wrapper.
model = load_trained_model(checkpoint)

In [153]:
def compute_surprisal_with_chunking(lightning_model, stories_df, tokenizer, chunk_size=35, unk_id=0):
    """
    Compute per-word surprisal using the same fixed-length chunking as training.

    Each story is split into consecutive, non-overlapping chunks of
    ``chunk_size`` tokens. Every chunk is scored with a freshly initialised
    cache, so context never crosses a chunk boundary. As a result, the first
    word of every chunk (story positions 0, chunk_size, 2*chunk_size, ...) has
    no prediction and keeps NaN surprisal. The same holds for a trailing
    single-token chunk.

    Args:
        lightning_model: Loaded Lightning CBRLanguageModel. Its ``.model``
            attribute is the network that is run.
        stories_df: DataFrame with columns [word, zone, item].
        tokenizer: WordTokenizer object. Its ``stoi`` dict maps words to ids.
        chunk_size: Sequence length used during training (35).
        unk_id: Id used for words missing from ``tokenizer.stoi``. The default
            0 is the id the original code used for <unk>.

    Returns:
        DataFrame with columns [word, zone, item, surprisal]. Stories appear in
        ascending ``item`` order and words in ascending ``zone`` order.
        Surprisal is -log p in nats, from a natural-log softmax.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lightning_model.to(device)

    # Extract the actual CBR_RNN model
    model = lightning_model.model

    # The Gumbel settings are identical for every chunk, so resolve them once.
    gumbel_kwargs = {}
    if getattr(lightning_model, 'use_gumbel_softmax', False):
        # NOTE(review): the original comment said "Don't use Gumbel for inference"
        # while passing use_gumbel=True. The value is kept so results are unchanged.
        # Confirm which is intended: sampling Gumbel noise at inference would make
        # surprisal non-deterministic.
        gumbel_kwargs = {'temperature': 0.1, 'use_gumbel': True}

    story_frames = []
    for story_id in sorted(stories_df['item'].unique()):
        story_data = stories_df[stories_df['item'] == story_id].sort_values('zone')
        words = story_data['word'].tolist()

        print(f"Processing story {story_id} ({len(words)} words)...")

        token_ids = [tokenizer.stoi.get(word, unk_id) for word in words]
        story_surprisals = [float('nan')] * len(words)

        for start_idx in range(0, len(token_ids), chunk_size):
            chunk_ids = token_ids[start_idx:start_idx + chunk_size]
            if len(chunk_ids) < 2:  # need at least one (context, target) pair
                continue

            # Shape [chunk_len, batch_size=1]
            input_tensor = torch.tensor(chunk_ids).unsqueeze(1).to(device)
            initial_cache = model.init_cache(input_tensor)

            with torch.no_grad():
                logits, _ = model(observation=input_tensor, initial_cache=initial_cache, **gumbel_kwargs)
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            # The output at step i-1 predicts the token at step i, so gather the
            # log-probability of tokens 1..n-1 from outputs 0..n-2.
            positions = torch.arange(len(chunk_ids) - 1, device=log_probs.device)
            targets = torch.tensor(chunk_ids[1:], device=log_probs.device)
            chunk_surprisals = (-log_probs[positions, 0, targets]).tolist()
            story_surprisals[start_idx + 1:start_idx + len(chunk_ids)] = chunk_surprisals

        story_frames.append(
            story_data[['word', 'zone', 'item']].assign(surprisal=story_surprisals)
        )

    if not story_frames:
        return pd.DataFrame(columns=['word', 'zone', 'item', 'surprisal'])
    return pd.concat(story_frames, ignore_index=True)

In [100]:
# Word-to-id vocabulary; compute_surprisal_with_chunking looks words up in its `stoi` mapping.
tokenizer = WordTokenizer.load('tokenizer.json')


In [None]:
# One row per story word with its model surprisal (NaN where the chunking leaves no prediction).
surprisal = compute_surprisal_with_chunking(model, stories, tokenizer, chunk_size=35)

In [None]:
# Display the computed surprisal table.
surprisal


## Prepare data for reading time prediction

In [103]:
# Reading-time data; later cells use its item, zone, word, RT, correct and WorkerId columns.
rt_data = pd.read_csv('/scratch2/mrenaudin/Hard-CBR-RNN/processed_RTs.tsv', sep='\t')

In [None]:
# Display the reading-time table.
rt_data

In [105]:
# Attach model surprisal to every reading-time observation. validate='many_to_one'
# raises if the surprisal table has duplicate (item, zone) keys, which would
# otherwise silently duplicate RT rows. Both inputs carry a `word` column, so the
# result has word_x (RT file) and word_y (surprisal table).
data = rt_data.merge(surprisal, on=['item', 'zone'], how='inner', validate='many_to_one')

In [106]:
# Response variable: log reading time.
data['log_RT'] = np.log(data['RT'])
# Predictor: length in characters of the word as it appears in the RT file.
data['word_length'] = data['word_x'].str.len()

# Integer subject codes to use as the mixed-model grouping variable.
data['subject'] = pd.Categorical(data['WorkerId']).codes

# z-score both predictors so their coefficients are on a comparable scale.
data['surprisal_z'] = data['surprisal'].sub(data['surprisal'].mean()).div(data['surprisal'].std())
data['length_z'] = data['word_length'].sub(data['word_length'].mean()).div(data['word_length'].std())

# Fresh 0..n-1 index so later positional array work lines up with rows.
data = data.reset_index(drop=True)

In [None]:
# Quick sanity summary of the prepared table.
print(
    f"Data prepared: {len(data)} observations\n"
    f"Number of subjects: {data['subject'].nunique()}\n"
    f"Mean word length: {data['word_length'].mean():.1f} characters"
)
    

## Fit model

In [None]:
from statsmodels.regression.mixed_linear_model import MixedLM

# Random-intercept model of log RT on standardized surprisal and word length,
# grouped by subject. Column order of X: intercept, surprisal_z, length_z.
y = data['log_RT'].values
X = data[['surprisal_z', 'length_z']].values
X = np.column_stack([np.ones(len(X)), X])  # Add intercept
groups = data['subject'].values

# Keep rows whose response and predictors are finite and whose subject code is real.
# np.log(RT) gives -inf (not NaN) for RT == 0, and pd.Categorical codes a
# missing WorkerId as -1; neither is caught by an isnan test.
mask = np.isfinite(y) & np.isfinite(X).all(axis=1) & (groups >= 0)
y = y[mask]
X = X[mask]
groups = groups[mask]

print(f"Clean data: {len(y)} observations")

# Use a dedicated name: `model` still holds the Lightning language model, which a
# later cell passes to compute_surprisal_with_chunking.
rt_lme = MixedLM(y, X, groups=groups)
result = rt_lme.fit()

In [109]:
def print_results(fitted_model):
    """Print the coefficient table of a fitted reading-time model and interpret it.

    Positional parameters are read as [1] = surprisal (x1) and [2] = word
    length (x2), matching the [intercept, surprisal_z, length_z] design
    matrix of the "Fit model" cell. Objects without a ``summary`` attribute
    are reported as not fitted.
    """
    rule = "=" * 50
    print(rule)
    print("READING TIME PREDICTION RESULTS")
    print(rule)

    if not hasattr(fitted_model, 'summary'):
        print("Model not properly fitted!")
        print(f"Model type: {type(fitted_model)}")
        return

    print(fitted_model.summary().tables[1])  # coefficient table only

    surprisal_coef = fitted_model.params[1]
    surprisal_p = fitted_model.pvalues[1]
    length_coef = fitted_model.params[2]
    length_p = fitted_model.pvalues[2]

    print(f"\nSurprisal coefficient (x1): {surprisal_coef:.4f}")
    print(f"P-value: {surprisal_p:.4f}")
    is_significant = 'Yes' if surprisal_p < 0.05 else 'No'
    print(f"Significant: {is_significant}")

    print(f"\nWord length coefficient (x2): {length_coef:.4f}")
    print(f"P-value: {length_p:.4f}")

    if surprisal_coef > 0:
        print("✓ Higher surprisal → Longer reading times (expected)")
    else:
        print("⚠ Higher surprisal → Shorter reading times (unexpected)")
        print("  This suggests an issue with surprisal computation or model")


In [None]:
# Report the log-RT mixed model fitted in the previous section.
print_results(result)

## Following the methodology of Clark

In [None]:
# Spillover (previous-word surprisal) must be computed on the story's token
# sequence, one row per (item, zone). In the merged RT table every (item, zone)
# appears once per subject, so shifting within `item` there would mostly return
# another subject's row for the *same* word.
surprisal_tokens = surprisal.sort_values(['item', 'zone']).reset_index(drop=True)
surprisal_tokens['prev_surprisal'] = surprisal_tokens.groupby('item')['surprisal'].shift(1)

# Merge RT and surprisal data (word_x: RT-file word, word_y: surprisal-table word)
data = rt_data.merge(surprisal_tokens, on=['item', 'zone'], how='inner')

# Unigram-surprisal proxy (the paper uses KenLM on OpenWebText): relative
# frequency of each word among the story tokens. Each (item, zone) is counted
# once, not once per subject who read it.
token_words = data.drop_duplicates(subset=['item', 'zone'])['word_x']
word_counts = token_words.value_counts()
total_words = len(token_words)
data['word_freq'] = data['word_x'].map(word_counts)
data['unigram_surprisal'] = -np.log(data['word_freq'] / total_words)

# Filter for correct responses and reasonable RTs
data = data[(data['correct'] >= 5) & (data['RT'] >= 100) & (data['RT'] <= 3000)]

# Deterministic row order by story and word position
data = data.sort_values(['item', 'zone']).reset_index(drop=True)

print(f"Data after filtering: {len(data)} observations")

# Baseline predictors
data['word_length'] = data['word_x'].str.len()  # word length in characters
# `zone` orders words within a story (item). The original comment called this the
# position within the sentence — TODO confirm which the paper requires.
data['word_position'] = data['zone']

# Current word surprisal (the model's output)
data['current_surprisal'] = data['surprisal']

# Drop rows lacking current surprisal (chunk-initial words) or previous surprisal
# (story-initial words and words after a chunk start). This keeps the baseline and
# full models on the same rows.
data = data.dropna(subset=['current_surprisal', 'prev_surprisal']).reset_index(drop=True)

print(f"Data after spillover calculation: {len(data)} observations")

# Create subject grouping variable
data['subject'] = pd.Categorical(data['WorkerId']).codes

# Raw reading times as specified in the paper (no log transform)
data['RT_raw'] = data['RT']

print(f"Final dataset: {len(data)} observations, {data['subject'].nunique()} subjects")

In [123]:
def fit_baseline_model(data):
    """
    Fit the baseline model, which has no model-surprisal predictors.

    Fixed effects, in design-matrix order: intercept, word_length,
    word_position, unigram_surprisal. The model has random intercepts by
    ``subject`` and response ``RT_raw``.

    The fit uses maximum likelihood (``reml=False``). REML log-likelihoods of
    models with different fixed effects are not comparable, so they cannot feed
    a ΔLogLik / likelihood-ratio comparison.

    A row is kept only when every predictor of the *full* model is present,
    including ``current_surprisal``/``prev_surprisal`` when those columns
    exist. This way the baseline and full models are fitted to the same
    observations and their log-likelihoods are comparable.

    Returns:
        The fitted MixedLM result, or an OLS result if the mixed model fails.
        ``params_names`` is attached in either case.
    """
    print("Attempting baseline model with manual arrays...")
    names = ['intercept', 'word_length', 'word_position', 'unigram_surprisal']

    X_baseline = np.column_stack([
        np.ones(len(data)),                    # intercept
        data['word_length'].values,            # word length
        data['word_position'].values,          # word position
        data['unigram_surprisal'].values,      # unigram surprisal
    ])
    y = data['RT_raw'].values
    groups = data['subject'].values

    # Also require the columns the full model adds, so both models see the same rows.
    extra_cols = [c for c in ('current_surprisal', 'prev_surprisal') if c in data.columns]
    extra_missing = data[extra_cols].isna().any(axis=1).to_numpy()

    mask = ~(np.isnan(y) | np.isnan(X_baseline).any(axis=1) | np.isnan(groups) | extra_missing)
    y = y[mask]
    X_baseline = X_baseline[mask]
    groups = groups[mask]

    print(f"Clean data for baseline: {len(y)} observations")
    print(f"Baseline design matrix shape: {X_baseline.shape}")

    try:
        # Random intercepts by subject, ML estimation
        result = MixedLM(y, X_baseline, groups=groups).fit(reml=False, method='lbfgs')
        print("Baseline model fitted successfully")
    except Exception as e:
        print(f"Baseline model failed: {e}")
        print("Using OLS for baseline model...")
        import statsmodels.api as sm
        result = sm.OLS(y, X_baseline).fit()

    result.params_names = names
    return result

In [124]:
def fit_full_model(data):
    """
    Fit the full model: baseline predictors plus current and previous surprisal.

    Fixed effects, in design-matrix order: intercept, word_length,
    word_position, unigram_surprisal, current_surprisal, prev_surprisal. The
    model has random intercepts by ``subject`` and response ``RT_raw``. Rows
    with a NaN in any of these are dropped.

    The fit uses maximum likelihood (``reml=False``) so that ``.llf`` can be
    compared against a baseline model with different fixed effects.

    Returns:
        The MixedLM result, or an OLS result if the mixed model fails, with
        ``params_names`` attached.

    Raises:
        Any error raised while building the design matrix or fitting the OLS
        fallback. It is printed and re-raised.
    """
    print("Attempting full model with manual arrays...")
    names = ['intercept', 'word_length', 'word_position',
             'unigram_surprisal', 'current_surprisal', 'prev_surprisal']

    try:
        X = np.column_stack([
            np.ones(len(data)),                    # intercept
            data['word_length'].values,            # word length
            data['word_position'].values,          # word position
            data['unigram_surprisal'].values,      # unigram surprisal
            data['current_surprisal'].values,      # current surprisal
            data['prev_surprisal'].values,         # previous surprisal
        ])
        y = data['RT_raw'].values
        groups = data['subject'].values

        # Remove any NaN values
        mask = ~(np.isnan(y) | np.isnan(X).any(axis=1) | np.isnan(groups))
        y = y[mask]
        X = X[mask]
        groups = groups[mask]

        print(f"Clean data for full model: {len(y)} observations")
        print(f"Design matrix shape: {X.shape}")

        try:
            # Random intercepts by subject, ML estimation
            result = MixedLM(y, X, groups=groups).fit(reml=False, method='lbfgs')
            print("Full model fitted with random intercepts only")
        except Exception as e:
            print(f"Even simple random effects failed: {e}")
            print("Falling back to OLS for full model...")
            import statsmodels.api as sm
            result = sm.OLS(y, X).fit()

        result.params_names = names
        return result

    except Exception as e:
        print(f"Manual array approach failed: {e}")
        raise

In [125]:
def calculate_delta_loglik(baseline_result, full_result):
    """
    Compare two fitted models by log-likelihood (ΔLogLik) and a likelihood-ratio test.

    Args:
        baseline_result: fitted result of the nested (smaller) model.
        full_result: fitted result of the larger model.

    Returns:
        dict with 'delta_loglik', 'baseline_ll' and 'full_ll'. When ΔLogLik is
        finite it also holds 'lr_stat', 'p_value' and 'df'.

    The LR test is only meaningful when both models were fitted by maximum
    likelihood, on the same rows, with the same estimator. A warning is printed
    when the two results are of different types, e.g. when one fit fell back
    to OLS.
    """
    from scipy.stats import chi2

    def _loglik(result):
        # Fitted statsmodels results expose the log-likelihood as `.llf`.
        if hasattr(result, 'llf'):
            return result.llf
        return getattr(result, 'loglike', np.nan)

    baseline_ll = _loglik(baseline_result)
    full_ll = _loglik(full_result)
    delta_loglik = full_ll - baseline_ll

    print("\n" + "="*60)
    print("MODEL COMPARISON RESULTS (Following Paper)")
    print("="*60)

    print(f"Baseline Log-Likelihood: {baseline_ll:.2f}")
    print(f"Full Model Log-Likelihood: {full_ll:.2f}")
    print(f"∆LogLik (improvement): {delta_loglik:.2f}")

    if type(baseline_result) is not type(full_result):
        print("⚠ Baseline and full models were fitted with different estimators "
              f"({type(baseline_result).__name__} vs {type(full_result).__name__}); "
              "their log-likelihoods are not directly comparable")

    if not np.isfinite(delta_loglik):
        print("Could not compute likelihood ratio test")
        return {
            'delta_loglik': delta_loglik,
            'baseline_ll': baseline_ll,
            'full_ll': full_ll
        }

    # Likelihood ratio test
    lr_stat = 2 * delta_loglik
    df_diff = len(full_result.params) - len(baseline_result.params)
    # The survival function is exact in the tail; 1 - cdf rounds small p-values to 0.
    p_value = chi2.sf(lr_stat, df_diff)

    print(f"Likelihood Ratio Test: χ² = {lr_stat:.2f}, df = {df_diff}, p = {p_value:.4f}")
    print(f"Significance: {'***' if p_value < 0.001 else '**' if p_value < 0.01 else '*' if p_value < 0.05 else 'n.s.'}")

    return {
        'delta_loglik': delta_loglik,
        'lr_stat': lr_stat,
        'p_value': p_value,
        'baseline_ll': baseline_ll,
        'full_ll': full_ll,
        'df': df_diff
    }

In [126]:
def print_full_model_results(full_result):
    """Print the full model's coefficient table and its surprisal effects.

    Uses ``full_result.summary()`` when available. If building the summary
    raises an ordinary exception, it falls back to printing each coefficient
    named in ``full_result.params_names``. The current- and previous-word
    surprisal coefficients are then looked up by name in ``params_names``.
    """
    def _stars(p):
        return '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'n.s.'

    print("\n" + "="*60)
    print("FULL MODEL COEFFICIENTS")
    print("="*60)

    has_names = hasattr(full_result, 'params_names')

    if hasattr(full_result, 'summary'):
        try:
            print(full_result.summary().tables[1])
        except Exception:
            # summary() can fail on unusual fits; print the named coefficients instead.
            # (A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.)
            print("Coefficients:")
            if has_names:
                for i, name in enumerate(full_result.params_names):
                    coef = full_result.params[i]
                    pval = full_result.pvalues[i] if hasattr(full_result, 'pvalues') else 'N/A'
                    print(f"  {name}: {coef:.4f} (p = {pval})")

    if not has_names:
        print("\nCannot extract specific surprisal coefficients - check model structure")
        return

    names = full_result.params_names
    params = full_result.params
    pvalues = full_result.pvalues if hasattr(full_result, 'pvalues') else [np.nan] * len(params)

    print("\n" + "="*40)
    print("SURPRISAL EFFECTS")
    print("="*40)

    for label, key in (("Current word surprisal", 'current_surprisal'),
                       ("Previous word surprisal (spillover)", 'prev_surprisal')):
        if key in names:
            idx = names.index(key)
            coef, p = params[idx], pvalues[idx]
            print(f"{label}: {coef:.4f} (p = {p:.4f})")
            print(f"  Significance: {_stars(p)}")


In [127]:
def run_paper_analysis(data):
    """
    Run the paper-style model comparison on already-prepared data.

    Fits the baseline and full models, compares them by ΔLogLik, and prints the
    full-model coefficients. Data preparation (step 1) is done outside this
    function, so ``data`` must already contain the predictor columns.

    Returns:
        dict with keys 'data', 'baseline_model', 'full_model' and 'comparison'.
    """
    rule = "=" * 60
    print(rule)
    print("RUNNING ANALYSIS FOLLOWING PAPER SPECIFICATION")
    print(rule)

    print("\nStep 2: Fitting baseline model...")
    baseline = fit_baseline_model(data)

    print("\nStep 3: Fitting full model...")
    full = fit_full_model(data)

    print("\nStep 4: Calculating model comparison...")
    comparison = calculate_delta_loglik(baseline, full)

    print_full_model_results(full)

    return dict(
        data=data,
        baseline_model=baseline,
        full_model=full,
        comparison=comparison,
    )


In [None]:
# Fit baseline and full models on the Clark-style `data` prepared above and compare them.
results = run_paper_analysis(data)

## Final methodology (paper-compliant ΔLogLik analysis)

In [140]:
import pandas as pd
import numpy as np
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

def prepare_data_correct(rt_data, surprisal_data):
    """
    Build the analysis table for the paper-style ΔLogLik comparison.

    Steps:
      1. On the surprisal table (one row per story token), add
         ``prev_surprisal``: the surprisal of the preceding token in the same
         story (``item``), ordered by ``zone``.
      2. Merge with the RT data on (item, zone) and keep rows with
         ``correct >= 5`` and 100 <= RT <= 3000.
      3. Add predictors: word_length, word_position (= zone),
         unigram_surprisal, current_surprisal, and integer ``subject`` codes.
      4. Drop rows lacking current or previous surprisal, so the baseline and
         full models are fitted to identical observations.

    Args:
        rt_data: reading-time table with item, zone, RT, correct and WorkerId
            columns, and usually a word column.
        surprisal_data: table with item, zone, word and surprisal columns,
            e.g. the output of compute_surprisal_with_chunking.

    Returns:
        The prepared DataFrame with a fresh 0..n-1 index.

    Raises:
        ValueError: if no word column can be found after the merge.
    """
    # Spillover must be computed on the token sequence. After the merge each
    # (item, zone) appears once per subject, so a shift within `item` there
    # would mostly pick up another subject's row for the same word.
    tokens = surprisal_data.sort_values(['item', 'zone']).reset_index(drop=True)
    tokens['prev_surprisal'] = tokens.groupby('item')['surprisal'].shift(1)

    # Merge and filter - handle duplicate column names
    data = rt_data.merge(tokens, on=['item', 'zone'], how='inner', suffixes=('_rt', '_surprisal'))
    data = data[(data['correct'] >= 5) & (data['RT'] >= 100) & (data['RT'] <= 3000)]
    data = data.sort_values(['item', 'zone']).reset_index(drop=True)

    # Prefer the word column from the surprisal data (should match the RT file)
    if 'word_surprisal' in data.columns:
        data['word'] = data['word_surprisal']
    elif 'word_rt' in data.columns:
        data['word'] = data['word_rt']
    elif 'word' in data.columns:
        pass  # Already have word column
    else:
        print("Available columns:", data.columns.tolist())
        raise ValueError("Cannot find word column after merge")

    data['word_length'] = data['word'].str.len()
    data['word_position'] = data['zone']

    # Unigram-surprisal proxy: relative frequency of each word type among the
    # retained story tokens. Each (item, zone) is counted once, not once per
    # subject who read it.
    token_words = data.drop_duplicates(subset=['item', 'zone'])['word']
    word_counts = token_words.value_counts()
    total_tokens = len(token_words)
    data['unigram_surprisal'] = data['word'].map(lambda w: -np.log(word_counts.get(w, 1) / total_tokens))

    data['current_surprisal'] = data['surprisal']

    # Drop rows without current surprisal (chunk-initial words) or previous
    # surprisal (story-initial words) so both models use the same observations.
    data = data.dropna(subset=['current_surprisal', 'prev_surprisal']).reset_index(drop=True)

    data['subject'] = pd.Categorical(data['WorkerId']).codes

    print(f"Final data: {len(data)} observations, {data['subject'].nunique()} subjects")
    return data


def fit_simple_lmer(y, X, groups, max_attempts=3, reml=False):
    """
    Fit a random-intercept linear mixed model, falling back on failure.

    Tries, in order: MixedLM with Nelder-Mead, MixedLM with BFGS, then OLS,
    limited to the first ``max_attempts`` strategies. Rows with a NaN in
    ``y``, ``X`` or ``groups`` are dropped first. The first fit with a finite
    log-likelihood is returned.

    Args:
        y: response vector.
        X: fixed-effects design matrix. It must already include an intercept column.
        groups: grouping labels for the random intercepts.
        max_attempts: how many of the three strategies to try.
        reml: fit MixedLM by REML when True. The default is False (maximum
            likelihood) because REML log-likelihoods of models with different
            fixed effects are not comparable. That would invalidate any
            ΔLogLik / likelihood-ratio comparison built on ``.llf``.

    Raises:
        Exception: if no strategy yields a finite log-likelihood.

    Note: different strategies may win for different models (e.g. baseline via
    MixedLM, full via OLS). Their log-likelihoods are then not comparable.
    """
    from statsmodels.regression.mixed_linear_model import MixedLM
    import statsmodels.api as sm

    mask = ~(np.isnan(y) | np.isnan(X).any(axis=1) | np.isnan(groups))
    y_clean, X_clean, groups_clean = y[mask], X[mask], groups[mask]

    attempts = [
        ("MixedLM with random intercepts",
         lambda: MixedLM(y_clean, X_clean, groups=groups_clean).fit(reml=reml, method='nm', maxiter=100)),
        ("MixedLM with BFGS",
         lambda: MixedLM(y_clean, X_clean, groups=groups_clean).fit(reml=reml, method='bfgs', maxiter=50)),
        ("OLS fallback", lambda: sm.OLS(y_clean, X_clean).fit()),
    ]

    for name, fit_func in attempts[:max_attempts]:
        try:
            result = fit_func()
            if hasattr(result, 'llf') and np.isfinite(result.llf):
                print(f"  Fitted with {name}, LogLik = {result.llf:.2f}")
                return result
            elif hasattr(result, 'llf'):
                print(f"  {name} converged but LogLik infinite")
        except Exception as e:
            print(f"  {name} failed: {str(e)[:50]}...")
            continue

    raise Exception("All fitting methods failed")

def compute_delta_loglik_correct(data):
    """
    Compute ΔLogLik between the baseline and full models, as in the paper.

    Baseline: RT ~ 1 + word_length + word_position + unigram_surprisal
    Full:     baseline + current_surprisal + prev_surprisal
    Both models have random intercepts by ``subject`` and are fitted via
    fit_simple_lmer.

    Rows missing any value used by either model are dropped up front, so both
    models are fitted to exactly the same observations. Otherwise each
    log-likelihood is a sum over different data and their difference is
    meaningless.

    Returns:
        dict with 'delta_loglik', 'baseline_model', 'full_model' and
        'lr_test' {'statistic', 'p_value', 'df'}. The statistic and p-value are
        NaN when ΔLogLik is not finite.
    """
    print("\n" + "="*60)
    print("FITTING MODELS FOR ΔLOGLIK CALCULATION")
    print("="*60)

    y_all = data['RT'].values  # Raw RT as specified

    # Baseline: intercept + word_length + word_position + unigram_surprisal
    X_baseline_all = np.column_stack([
        np.ones(len(data)),
        data['word_length'].values,
        data['word_position'].values,
        data['unigram_surprisal'].values
    ])

    # Full: baseline + current_surprisal + prev_surprisal
    X_full_all = np.column_stack([
        X_baseline_all,
        data['current_surprisal'].values,
        data['prev_surprisal'].values
    ])

    groups_all = data['subject'].values

    # One shared row mask for both models (X_full contains every baseline column).
    keep = ~(np.isnan(y_all) | np.isnan(X_full_all).any(axis=1) | np.isnan(groups_all))
    y = y_all[keep]
    X_baseline = X_baseline_all[keep]
    X_full = X_full_all[keep]
    groups = groups_all[keep]

    n_dropped = int((~keep).sum())
    if n_dropped:
        print(f"Dropped {n_dropped} rows with missing values so both models use the same observations")

    print(f"Data shape: {len(y)} observations")
    print(f"Baseline predictors: {X_baseline.shape[1]} (intercept + word_length + word_position + unigram_surprisal)")
    print(f"Full predictors: {X_full.shape[1]} (baseline + current_surprisal + prev_surprisal)")

    print("\nFitting baseline model...")
    baseline_result = fit_simple_lmer(y, X_baseline, groups)

    print("\nFitting full model...")
    full_result = fit_simple_lmer(y, X_full, groups)

    baseline_ll = baseline_result.llf
    full_ll = full_result.llf
    delta_loglik = full_ll - baseline_ll

    print("\n" + "="*60)
    print("ΔLOGLIK RESULTS (PAPER SPECIFICATION)")
    print("="*60)
    print(f"Baseline Log-Likelihood: {baseline_ll:.2f}")
    print(f"Full Model Log-Likelihood: {full_ll:.2f}")
    print(f"ΔLogLik: {delta_loglik:.2f}")

    if type(baseline_result) is not type(full_result):
        print("⚠ Baseline and full models were fitted with different methods; ΔLogLik is not directly comparable")

    df_diff = X_full.shape[1] - X_baseline.shape[1]  # 2: current + previous surprisal
    lr_stat = np.nan
    p_value = np.nan

    if np.isfinite(delta_loglik):
        lr_stat = 2 * delta_loglik
        # Survival function: exact in the tail, unlike 1 - cdf which rounds tiny p to 0.
        p_value = stats.chi2.sf(lr_stat, df_diff)

        print(f"Likelihood Ratio Test: χ² = {lr_stat:.2f}, df = {df_diff}, p = {p_value:.6f}")
        significance = '***' if p_value < 0.001 else '**' if p_value < 0.01 else '*' if p_value < 0.05 else 'n.s.'
        print(f"Significance: {significance}")

        if delta_loglik > 0:
            print(f"✓ Your CBR-RNN surprisal improves model fit by {delta_loglik:.2f} log-likelihood units")
        else:
            print(f"✗ Surprisal does not improve model fit (ΔLogLik = {delta_loglik:.2f})")
    else:
        print("⚠ ΔLogLik calculation failed due to numerical issues")
        print("Models fitted but log-likelihood values are problematic")

    print("\n" + "="*40)
    print("SURPRISAL COEFFICIENTS")
    print("="*40)

    # Fixed effects 4 and 5 (0-based) are current and previous surprisal.
    if hasattr(full_result, 'params') and len(full_result.params) >= 6:
        params = np.asarray(full_result.params)
        current_coef, prev_coef = params[4], params[5]

        if hasattr(full_result, 'pvalues'):
            pvalues = np.asarray(full_result.pvalues)
            current_p, prev_p = pvalues[4], pvalues[5]
        else:
            current_p = prev_p = np.nan

        print(f"Current word surprisal: {current_coef:.4f} (p = {current_p:.4f})")
        print(f"Previous word surprisal: {prev_coef:.4f} (p = {prev_p:.4f})")

        if current_coef > 0:
            print("✓ Current surprisal has expected positive effect")
        else:
            print("⚠ Current surprisal has unexpected negative effect")

    return {
        'delta_loglik': delta_loglik,
        'baseline_model': baseline_result,
        'full_model': full_result,
        'lr_test': {
            'statistic': lr_stat,
            'p_value': p_value,
            'df': df_diff
        }
    }

def run_paper_analysis_correct(rt_data, surprisal_data):
    """
    End-to-end ΔLogLik analysis.

    Prepares the merged table, fits the baseline and full models, and prints a
    short verdict.

    Returns:
        The dict produced by compute_delta_loglik_correct.
    """
    rule = "=" * 60
    print(rule)
    print("PAPER-COMPLIANT ΔLOGLIK ANALYSIS")
    print(rule)

    print("Preparing data...")
    prepared = prepare_data_correct(rt_data, surprisal_data)

    results = compute_delta_loglik_correct(prepared)
    delta = results['delta_loglik']

    print("\n" + rule)
    print("FINAL RESULTS")
    print(rule)
    print(f"ΔLogLik for your CBR-RNN: {delta:.2f}")

    if not np.isfinite(delta):
        print("⚠ Numerical issues prevented ΔLogLik calculation")
    elif delta > 0:
        print("🎉 Your model improves prediction of human reading times!")
        print(f"Improvement: {delta:.2f} log-likelihood units")
    else:
        print("❌ Your model does not improve reading time prediction")

    return results



In [None]:
# Your data
rt_data = pd.read_csv('processed_RTs.tsv', sep='\t')

# Reload the Lightning model explicitly. The global `model` may have been rebound by
# an earlier cell (the "Fit model" section assigns a MixedLM to it), and that would
# break compute_surprisal_with_chunking.
lightning_model = load_trained_model(checkpoint)
surprisal_data = compute_surprisal_with_chunking(lightning_model, stories, tokenizer=tokenizer)

# Get the ΔLogLik metric from the paper
results = run_paper_analysis_correct(rt_data, surprisal_data)
delta_loglik = results['delta_loglik']

print(f"Your CBR-RNN ΔLogLik: {delta_loglik:.2f}")

In [None]:
from clean_RTs import prepare_data_correct, fit_simple_lmer, compute_delta_loglik_correct, run_paper_analysis_correct

# Final (epoch 49) checkpoint of each configuration, named
# check_<hidden size>_<heads>_<gumbel flag>.
(
    check_128_1_false,
    check_512_1_false,
    check_128_8_false,
    check_512_8_false,
    check_128_1_true,
    check_512_1_true,
    check_128_8_true,
    check_512_8_true,
) = (
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_000/lightning_logs/version_1198531/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_001/lightning_logs/version_1198532/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_002/lightning_logs/version_1198533/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_003/lightning_logs/version_1198534/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_004/lightning_logs/version_1198535/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_005/lightning_logs/version_1198536/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_006/lightning_logs/version_1198537/checkpoints/epoch=49-step=565950.ckpt',
    '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/checkpoints/epoch=49-step=565950.ckpt',
)

# Reload the tokenizer from the same file used earlier in the notebook.
tokenizer = WordTokenizer.load('tokenizer.json')

In [None]:
from clean_RTs import prepare_data_correct, fit_simple_lmer, compute_delta_loglik_correct, run_paper_analysis_correct
# NOTE(review): this import replaces the notebook's own definitions of these four
# functions with the versions in clean_RTs.py. Fixes made to the notebook copies
# only take effect here if clean_RTs.py is updated as well.
from grid_search import WordTokenizer  # assuming your tokenizer class is here
import pandas as pd

# Load tokenizer
tokenizer = WordTokenizer.load('tokenizer.json')

# Checkpoint name -> path. Names follow "<hidden size>_<heads>_<gumbel flag>",
# the format plot_factorial_delta_loglik splits on "_".
checkpoints = {
    "128_1_false": '/scratch2/mrenaudin/Hard-CBR-RNN/job_000/lightning_logs/version_1198531/checkpoints/epoch=49-step=565950.ckpt',
    "512_1_false": '/scratch2/mrenaudin/Hard-CBR-RNN/job_001/lightning_logs/version_1198532/checkpoints/epoch=49-step=565950.ckpt',
    "128_8_false": '/scratch2/mrenaudin/Hard-CBR-RNN/job_002/lightning_logs/version_1198533/checkpoints/epoch=49-step=565950.ckpt',
    "512_8_false": '/scratch2/mrenaudin/Hard-CBR-RNN/job_003/lightning_logs/version_1198534/checkpoints/epoch=49-step=565950.ckpt',
    "128_1_true": '/scratch2/mrenaudin/Hard-CBR-RNN/job_004/lightning_logs/version_1198535/checkpoints/epoch=49-step=565950.ckpt',
    "512_1_true": '/scratch2/mrenaudin/Hard-CBR-RNN/job_005/lightning_logs/version_1198536/checkpoints/epoch=49-step=565950.ckpt',
    "128_8_true": '/scratch2/mrenaudin/Hard-CBR-RNN/job_006/lightning_logs/version_1198537/checkpoints/epoch=49-step=565950.ckpt',
    "512_8_true": '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/checkpoints/epoch=49-step=565950.ckpt'
}

# Reading-time data shared by every checkpoint
rt_data = pd.read_csv('processed_RTs.tsv', sep='\t')

delta_loglik_dict = {}       # checkpoint name -> ΔLogLik
results_by_checkpoint = {}   # checkpoint name -> full analysis output, for later inspection

for name, ckpt_path in checkpoints.items():
    print(f"\nRunning analysis for checkpoint: {name}")
    # Loop-specific name so the notebook-level `model` is not silently replaced
    # by whichever checkpoint happens to be processed last.
    lightning_model = load_trained_model(ckpt_path)

    surprisal_data = compute_surprisal_with_chunking(lightning_model, stories, tokenizer=tokenizer)

    # Run ΔLogLik analysis
    results = run_paper_analysis_correct(rt_data, surprisal_data)

    delta_loglik_dict[name] = results['delta_loglik']
    results_by_checkpoint[name] = results

# Print results
print("\n=== ΔLogLik for all checkpoints ===")
for name, delta in delta_loglik_dict.items():
    print(f"{name}: {delta:.2f}")


In [165]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_factorial_delta_loglik(delta_loglik_dict):
    """
    Plot ΔLogLik for the 2x2x2 factorial design
    (hidden dimension × number of heads × softmax type).

    Draws, in order: main-effect bar plots, two-way interaction point plots,
    a three-way interaction bar grid, and one heatmap per softmax type.

    :param delta_loglik_dict: dict mapping run names of the form
        '<hidden_dim>_<heads>_<true|false>' (e.g. '128_1_true') to ΔLogLik
        values; the last part is 'true' for Gumbel-softmax runs.
    :raises ValueError: if the dict is empty, or if a name does not split into
        exactly three '_'-separated parts.
    """
    import re

    if not delta_loglik_dict:
        raise ValueError("delta_loglik_dict is empty; nothing to plot")

    # Display order for the softmax factor (matches False -> True).
    softmax_order = ["Classic", "Gumbel"]

    # --- Prepare DataFrame ---
    rows = []
    for name, delta in delta_loglik_dict.items():
        hidden_dim, heads, gumbel = name.split("_")
        is_gumbel = gumbel == 'true'
        rows.append({
            'hidden_dim': int(hidden_dim),
            'heads': int(heads),
            'gumbel_softmax': is_gumbel,
            # Plotting on a labelled column lets seaborn create correct tick
            # labels and legend entries itself; relabelling artists afterwards
            # with ax.legend(labels=...) can attach labels to error-bar lines.
            'softmax': "Gumbel" if is_gumbel else "Classic",
            'delta_loglik': delta,
        })
    df = pd.DataFrame(rows)

    # seaborn >= 0.12 deprecates `ci` in favour of `errorbar`.
    major, minor = (int(v) for v in re.match(r"(\d+)\.(\d+)", sns.__version__).groups())
    no_errorbar = {"errorbar": None} if (major, minor) >= (0, 12) else {"ci": None}

    sns.set(style="whitegrid")

    # --- 1. Main Effects ---
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    sns.barplot(data=df, x='hidden_dim', y='delta_loglik', ax=axes[0])
    axes[0].set_title("Main Effect: Hidden Dimension")

    sns.barplot(data=df, x='heads', y='delta_loglik', ax=axes[1])
    axes[1].set_title("Main Effect: Number of Heads")

    sns.barplot(data=df, x='softmax', y='delta_loglik', order=softmax_order, ax=axes[2])
    axes[2].set_title("Main Effect: Gumbel Softmax")
    axes[2].set_xlabel("Softmax")

    plt.tight_layout()
    plt.show()

    # --- 2. Two-way Interactions ---
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    sns.pointplot(data=df, x='hidden_dim', y='delta_loglik', hue='heads', ax=axes[0],
                  dodge=True, markers=["o", "s"], linestyles=["-", "--"])
    axes[0].set_title("Interaction: Hidden Dimension × Heads")

    sns.pointplot(data=df, x='hidden_dim', y='delta_loglik', hue='softmax',
                  hue_order=softmax_order, ax=axes[1],
                  dodge=True, markers=["o", "s"], linestyles=["-", "--"])
    axes[1].set_title("Interaction: Hidden Dimension × Softmax")

    sns.pointplot(data=df, x='heads', y='delta_loglik', hue='softmax',
                  hue_order=softmax_order, ax=axes[2],
                  dodge=True, markers=["o", "s"], linestyles=["-", "--"])
    axes[2].set_title("Interaction: Heads × Softmax")

    # seaborn titles the legend with the hue column name; only retitle it,
    # keeping the handles/labels seaborn paired correctly.
    for ax in axes[1:]:
        legend = ax.get_legend()
        if legend is not None:
            legend.set_title("Gumbel Softmax")

    plt.tight_layout()
    plt.show()

    # --- 3. Three-way Interaction ---
    # One observation per bar, so error bars are disabled.
    g = sns.catplot(
        data=df, kind="bar",
        x="hidden_dim", y="delta_loglik", hue="heads",
        col="softmax", col_order=softmax_order, palette="muted",
        height=5, aspect=0.8, **no_errorbar
    )
    g.set_axis_labels("Hidden Dimension", "ΔLogLik")
    g.set_titles("{col_name} Softmax")
    plt.show()

    # --- 4. Optional Heatmaps ---
    for softmax_type, group_df in df.groupby('softmax'):
        heatmap_df = group_df.pivot(index='hidden_dim', columns='heads', values='delta_loglik')
        fig, ax = plt.subplots(figsize=(6, 5))
        sns.heatmap(heatmap_df, annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
        ax.set_title(f"ΔLogLik Heatmap - {softmax_type} Softmax")
        ax.set_xlabel("Number of Heads")
        ax.set_ylabel("Hidden Dimension")
        plt.show()


In [None]:
# Render every factorial-design figure from the per-checkpoint ΔLogLik values.
plot_factorial_delta_loglik(delta_loglik_dict=delta_loglik_dict)


In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# One row per checkpoint, ordered from lowest to highest ΔLogLik, so the
# 1-based position on the x-axis is the model's rank.
df = (
    pd.DataFrame(
        [{'model': model_name, 'delta_loglik': value}
         for model_name, value in delta_loglik_dict.items()]
    )
    .sort_values('delta_loglik')
    .reset_index(drop=True)
)
df['rank'] = range(1, len(df) + 1)

# Scatter of ΔLogLik by rank, labelled with the checkpoint names.
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(df['rank'], df['delta_loglik'], color='dodgerblue', s=100)
ax.set_xticks(df['rank'])
ax.set_xticklabels(df['model'], rotation=45, ha='right')
ax.set_xlabel("Model (sorted by ΔLogLik)")
ax.set_ylabel("ΔLogLik")
ax.set_title("Model improvements in Reading Time Prediction")
ax.grid(axis='y', linestyle='--', alpha=0.5)
fig.tight_layout()
plt.show()


In [None]:
import pandas as pd

# One row per checkpoint, with the three design factors decoded from the
# run name '<hidden_dim>_<heads>_<true|false>'.
df = pd.DataFrame(
    [{'model': model_name, 'delta_loglik': value}
     for model_name, value in delta_loglik_dict.items()]
)
df = df.join(
    df['model']
    .str.split('_', expand=True)
    .set_axis(['hidden_dim', 'heads', 'gumbel'], axis=1)
    .astype({'hidden_dim': int, 'heads': int})
)
df['gumbel'] = df['gumbel'].eq('true')

# Mean ΔLogLik per level of each factor, best level first.
hidden_effect, heads_effect, gumbel_effect = (
    df.groupby(factor)['delta_loglik'].mean().sort_values(ascending=False)
    for factor in ('hidden_dim', 'heads', 'gumbel')
)

for heading, effect in (
    ("Hidden dimension effect (higher ΔLogLik = better):", hidden_effect),
    ("\nNumber of heads effect:", heads_effect),
    ("\nGumbel softmax effect:", gumbel_effect),
):
    print(heading)
    print(effect)


In [None]:
import matplotlib.pyplot as plt

# Bar chart of mean ΔLogLik per level, one panel per design factor.
factors = ['hidden_dim', 'heads', 'gumbel']
effects = [hidden_effect, heads_effect, gumbel_effect]

fig, axes = plt.subplots(1, len(factors), figsize=(10, 6))
for factor, effect, ax in zip(factors, effects, axes):
    effect.plot(kind='bar', color='skyblue', ax=ax)
    ax.set_title(f'{factor} effect')
    ax.set_ylabel('Mean ΔLogLik')
fig.tight_layout()
plt.show()


In [None]:
import pandas as pd

# One row per checkpoint, with the three design factors decoded from the
# run name '<hidden_dim>_<heads>_<true|false>'.
df = pd.DataFrame([
    {'model': name, 'delta_loglik': delta}
    for name, delta in delta_loglik_dict.items()
])
df[['hidden_dim', 'heads', 'gumbel']] = df['model'].str.split('_', expand=True)
df['hidden_dim'] = df['hidden_dim'].astype(int)
df['heads'] = df['heads'].astype(int)
df['gumbel'] = df['gumbel'] == 'true'

# --- Step 1: Compute mean ΔLogLik for each factor level ---
# Each frame has columns (level, mean_delta, factor). They are built by
# chaining rather than by `inplace=True` mutations.
hidden_effect = (
    df.groupby('hidden_dim')['delta_loglik'].mean().reset_index()
    .rename(columns={'hidden_dim': 'level', 'delta_loglik': 'mean_delta'})
    .assign(factor='hidden_dim')
)
heads_effect = (
    df.groupby('heads')['delta_loglik'].mean().reset_index()
    .rename(columns={'heads': 'level', 'delta_loglik': 'mean_delta'})
    .assign(factor='heads')
)
gumbel_effect = (
    df.groupby('gumbel')['delta_loglik'].mean().reset_index()
    .rename(columns={'gumbel': 'level', 'delta_loglik': 'mean_delta'})
    .assign(factor='gumbel_softmax')
)

# Combine all factors into a single DataFrame.
effects_df = pd.concat([hidden_effect, heads_effect, gumbel_effect], ignore_index=True)

# --- Step 2: Rank all factor levels by mean ΔLogLik ---
effects_df = effects_df.sort_values('mean_delta', ascending=False).reset_index(drop=True)

# --- Step 3: Rank factors by effect size ---
# A single level's mean is not a factor's contribution: in a balanced design
# the level means of every factor average to the same grand mean. What
# separates factors is the gap between their best and worst level.
factor_effect_sizes = (
    effects_df.groupby('factor')['mean_delta']
    .agg(lambda means: means.max() - means.min())
    .sort_values(ascending=False)
    .rename('effect_size')
)

print("Factor levels ranked by mean ΔLogLik:")
print(effects_df)
print("\nHierarchy of constraints by contribution to ΔLogLik (best − worst level mean):")
print(factor_effect_sizes)
