In [24]:
import pandas as pd
import numpy as np
import json
import torch.nn as nn
import torch
# if not hasattr(torch._dynamo, 'external_utils'):
#     import types
#     torch._dynamo.external_utils = types.ModuleType('external_utils')
#     torch._dynamo.external_utils.is_compiling = lambda: True

import pytorch_lightning as pl
import torch.nn.functional as F


In [5]:
from lit_cbr import CBRLanguageModel
from grid_search import WordTokenizer
import torch
from clean_RTs import prepare_data_correct, fit_simple_lmer, compute_delta_loglik_correct, run_paper_analysis_correct
from grid_search import WordTokenizer  # assuming your tokenizer class is here
import pandas as pd


In [6]:
# Tab-separated story corpus. The surprisal functions below treat each row as one
# word and group/sort by the `item` and `zone` columns, so fail early if any is missing.
STORIES_PATH = '/scratch2/mrenaudin/Hard-CBR-RNN/all_stories.tok'
stories = pd.read_csv(STORIES_PATH, sep = '\t')
missing_story_columns = {'word', 'zone', 'item'} - set(stories.columns)
assert not missing_story_columns, f"stories file is missing columns: {sorted(missing_story_columns)}"

In [None]:
# Inspect the loaded story table (Jupyter renders the last expression of the cell).
stories

In [18]:
def load_trained_cbr(checkpoint_path):
    """Restore a CBRLanguageModel from a Lightning checkpoint, ready for inference.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file written during training.

    Returns:
        The restored Lightning module, switched to evaluation mode.
    """
    restored = CBRLanguageModel.load_from_checkpoint(checkpoint_path)
    restored.eval()
    return restored

In [15]:
from entire_transformer import SimpleTransformerLM
def load_trained_transformer(checkpoint_path):
    """Restore a SimpleTransformerLM from a Lightning checkpoint, ready for inference.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file written during training.

    Returns:
        The restored Lightning module, switched to evaluation mode.
    """
    restored = SimpleTransformerLM.load_from_checkpoint(checkpoint_path)
    restored.eval()
    return restored

In [9]:
def compute_surprisal_with_chunking_cbr(lightning_model, stories_df, tokenizer, chunk_size=64):
    """
    Compute per-word surprisal for every story with a trained CBR model.

    Each story (rows sharing an ``item`` value, ordered by ``zone``) is split into
    consecutive, non-overlapping chunks of ``chunk_size`` tokens. Every chunk is run
    from a fresh ``model.init_cache(...)``, so no context is carried across chunk
    boundaries. Within a chunk, the model output at step ``i-1`` is scored against
    the token at step ``i``; surprisal is the negative natural-log probability (nats).

    Args:
        lightning_model: Loaded Lightning CBRLanguageModel. Its ``.model`` attribute
            (the underlying CBR network) is what gets run. The Lightning module is
            moved to GPU if one is available and switched to eval mode.
        stories_df: DataFrame with columns [word, zone, item].
        tokenizer: WordTokenizer object; only its ``stoi`` dict is used. Words not in
            ``stoi`` map to id 0 (assumed to be <unk> - TODO confirm).
        chunk_size: Number of tokens per chunk (default 64).

    Returns:
        New DataFrame with columns [word, zone, item, surprisal], one row per input
        word, stories in sorted ``item`` order. ``surprisal`` is NaN for the first
        word of every chunk and for a trailing single-token chunk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lightning_model.to(device)
    # Make sure dropout etc. is off even if the caller did not use load_trained_cbr.
    lightning_model.eval()

    # The Lightning wrapper holds the actual CBR network in `.model`.
    model = lightning_model.model

    # Gumbel-trained models: disable Gumbel sampling at inference so surprisal is
    # not perturbed by random noise (matches the transformer pipeline). The fixed
    # temperature is still passed through - TODO confirm how CBR_RNN.forward uses it.
    gumbel_kwargs = {}
    if getattr(lightning_model, 'use_gumbel_softmax', False):
        gumbel_kwargs = {
            'temperature': 0.1,
            'use_gumbel': False,
        }

    results = []

    for story_id in sorted(stories_df['item'].unique()):
        story_data = stories_df[stories_df['item'] == story_id].sort_values('zone')
        words = story_data['word'].tolist()

        print(f"Processing story {story_id} ({len(words)} words)...")

        token_ids = [tokenizer.stoi.get(word, 0) for word in words]  # 0 = <unk> (assumed)

        story_surprisals = [float('nan')] * len(words)

        for start_idx in range(0, len(token_ids), chunk_size):
            chunk_ids = token_ids[start_idx:start_idx + chunk_size]
            n_tokens = len(chunk_ids)
            if n_tokens < 2:  # need at least one (context, target) pair
                continue

            # Tensor of shape [chunk_len, batch_size=1].
            input_tensor = torch.tensor(chunk_ids).unsqueeze(1).to(device)

            with torch.no_grad():
                initial_cache = model.init_cache(input_tensor)
                logits, _states = model(
                    observation=input_tensor,
                    initial_cache=initial_cache,
                    **gumbel_kwargs,
                )
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

                # Output at step i-1 scores token i: gather all next-token
                # log-probs at once instead of one .item() sync per token.
                targets = input_tensor[1:, 0]
                next_token_log_probs = (
                    log_probs[:n_tokens - 1, 0, :]
                    .gather(1, targets.unsqueeze(1))
                    .squeeze(1)
                )

            story_surprisals[start_idx + 1:start_idx + n_tokens] = (-next_token_log_probs).tolist()

        for word, zone, item, surprisal in zip(
            story_data['word'], story_data['zone'], story_data['item'], story_surprisals
        ):
            results.append({
                'word': word,
                'zone': zone,
                'item': item,
                'surprisal': surprisal,
            })

    return pd.DataFrame(results)

In [22]:
def compute_surprisal_with_chunking_transformer(lightning_model, stories_df, tokenizer, chunk_size=64):
    """
    Per-word surprisal for every story using a trained SimpleTransformerLM.

    Words of each story (grouped by ``item``, ordered by ``zone``) are mapped to ids
    via ``tokenizer.stoi`` (unknown words -> id 0, presumably <unk>) and scored in
    consecutive, non-overlapping windows of ``chunk_size`` tokens. Inside a window,
    the output at position ``i-1`` gives the log-probability of the token at ``i``;
    surprisal is its negation (natural log, i.e. nats).

    NOTE(review): these are only true next-word surprisals if SimpleTransformer
    applies a causal attention mask internally - confirm.

    Args:
        lightning_model: Loaded Lightning SimpleTransformerLM; its ``.model`` is run.
        stories_df: DataFrame with columns [word, zone, item].
        tokenizer: WordTokenizer object (only ``stoi`` is used).
        chunk_size: Window length in tokens (default 64).

    Returns:
        New DataFrame with columns [word, zone, item, surprisal]. The first word of
        each window, and any window holding a single token, gets NaN.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lightning_model.to(run_device)
    lightning_model.eval()
    transformer = lightning_model.model

    # Gumbel-trained checkpoints get explicit, non-sampling inference settings.
    inference_kwargs = {}
    if getattr(lightning_model, 'use_gumbel_softmax', False):
        inference_kwargs = {'temperature': 0.5, 'use_gumbel': False, 'hard': False}

    records = []
    for story_id in sorted(stories_df['item'].unique()):
        story_rows = stories_df[stories_df['item'] == story_id].sort_values('zone')
        story_words = story_rows['word'].tolist()
        print(f"Processing story {story_id} ({len(story_words)} words)...")

        ids = [tokenizer.stoi.get(w, 0) for w in story_words]
        scores = [float('nan')] * len(story_words)

        for window_start in range(0, len(ids), chunk_size):
            window = ids[window_start:window_start + chunk_size]
            if len(window) < 2:
                continue  # nothing to predict inside a single-token window

            # [seq_len, batch=1] layout fed to SimpleTransformer.
            src = torch.tensor(window).unsqueeze(1).to(run_device)
            with torch.no_grad():
                log_probs = F.log_softmax(transformer(src=src, **inference_kwargs), dim=-1)
                for pos, target in enumerate(window[1:], start=1):
                    scores[window_start + pos] = -log_probs[pos - 1, 0, target].item()

        records.extend(
            {'word': row['word'], 'zone': row['zone'], 'item': row['item'], 'surprisal': score}
            for (_, row), score in zip(story_rows.iterrows(), scores)
        )

    return pd.DataFrame(records)


In [None]:
tokenizer = WordTokenizer.load('tokenizer.json')

# (label, checkpoint path) pairs; labels are only used as keys when reporting ΔLogLik.
# Built from a list instead of a dict literal so a repeated label trips the assertion
# below instead of silently overwriting an earlier run.
# NOTE(review): jobs 005 and 007 were both labelled '512_1_true_15', so the old dict
# literal dropped job_cbr_005. Both are now suffixed with their job id - confirm the
# intended labels for these runs.
cbr_checkpoint_pairs = [
    ('128_1_false_19', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_000/lightning_logs/version_1354621/checkpoints/epoch=19-step=123800.ckpt'),
    ('512_1_false_15', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_001/lightning_logs/version_1354622/checkpoints/epoch=15-step=99040.ckpt'),
    ('128_8_false_20', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_002/lightning_logs/version_1354623/checkpoints/epoch=20-step=129990.ckpt'),
    ('512_8_false_16', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_003/lightning_logs/version_1354624/checkpoints/epoch=16-step=105230.ckpt'),
    ('128_1_true_17', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_004/lightning_logs/version_1354625/checkpoints/epoch=17-step=111420.ckpt'),
    ('512_1_true_15_job005', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_005/lightning_logs/version_1354626/checkpoints/epoch=15-step=99040.ckpt'),
    ('128_1_true_19', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_006/lightning_logs/version_1354627/checkpoints/epoch=19-step=123800.ckpt'),
    ('512_1_true_15_job007', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_cbr_007/lightning_logs/version_1354619/checkpoints/epoch=15-step=99040.ckpt'),
]
checkpoints_cbr_unfinished = dict(cbr_checkpoint_pairs)
assert len(checkpoints_cbr_unfinished) == len(cbr_checkpoint_pairs), "duplicate CBR checkpoint labels"

# ΔLogLik per checkpoint label.
delta_loglik_dict = {}

# Reading-time data passed to run_paper_analysis_correct together with each
# checkpoint's surprisals.
rt_data = pd.read_csv('processed_RTs.tsv', sep='\t')

for name, ckpt_path in checkpoints_cbr_unfinished.items():
    print(f"\nRunning analysis for checkpoint: {name}")
    model = load_trained_cbr(ckpt_path)

    # CBR-specific surprisal function (no generic `compute_surprisal_with_chunking` exists).
    surprisal_data = compute_surprisal_with_chunking_cbr(model, stories, tokenizer=tokenizer, chunk_size=64)

    results = run_paper_analysis_correct(rt_data, surprisal_data)
    delta_loglik_dict[name] = results['delta_loglik']

print("\n=== ΔLogLik for all checkpoints ===")
for name, delta in delta_loglik_dict.items():
    print(f"{name}: {delta:.2f}")


In [None]:
tokenizer = WordTokenizer.load('tokenizer.json')

# (label, checkpoint path) pairs; labels are only used as keys when reporting ΔLogLik.
# Built from a list instead of a dict literal so a repeated label trips the assertion
# below instead of silently overwriting an earlier run.
# NOTE(review): jobs 004/006 were both labelled '128_1_true' and jobs 005/007 both
# '512_1_true', so the old dict literal kept only jobs 006 and 007. These four are now
# suffixed with their job id - confirm the intended labels for these runs.
transformer_checkpoint_pairs = [
    ('128_1_false', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_000/lightning_logs/version_1356201/checkpoints/epoch=49-step=309500.ckpt'),
    ('512_1_false', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_001/lightning_logs/version_1356202/checkpoints/epoch=49-step=309500.ckpt'),
    ('128_8_false', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_002/lightning_logs/version_1356203/checkpoints/epoch=49-step=309500.ckpt'),
    ('512_8_false', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_003/lightning_logs/version_1356204/checkpoints/epoch=49-step=309500.ckpt'),
    ('128_1_true_job004', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_004/lightning_logs/version_1356205/checkpoints/epoch=49-step=309500.ckpt'),
    ('512_1_true_job005', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_005/lightning_logs/version_1356206/checkpoints/epoch=49-step=309500.ckpt'),
    ('128_1_true_job006', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_006/lightning_logs/version_1356207/checkpoints/epoch=49-step=309500.ckpt'),
    ('512_1_true_job007', '/scratch2/mrenaudin/Hard-CBR-RNN/final_models/job_transformer_007/lightning_logs/version_1356197/checkpoints/epoch=49-step=309500.ckpt'),
]
checkpoints_transformer_unfinished = dict(transformer_checkpoint_pairs)
assert len(checkpoints_transformer_unfinished) == len(transformer_checkpoint_pairs), "duplicate transformer checkpoint labels"

# ΔLogLik per checkpoint label (this cell rebinds the name, replacing any earlier results).
delta_loglik_dict = {}

# Reading-time data passed to run_paper_analysis_correct together with each
# checkpoint's surprisals.
rt_data = pd.read_csv('processed_RTs.tsv', sep='\t')

for name, ckpt_path in checkpoints_transformer_unfinished.items():
    print(f"\nRunning analysis for checkpoint: {name}")
    model = load_trained_transformer(ckpt_path)

    surprisal_data = compute_surprisal_with_chunking_transformer(model, stories, tokenizer=tokenizer)

    results = run_paper_analysis_correct(rt_data, surprisal_data)
    delta_loglik_dict[name] = results['delta_loglik']

print("\n=== ΔLogLik for all checkpoints ===")
for name, delta in delta_loglik_dict.items():
    print(f"{name}: {delta:.2f}")
