In [8]:
import numpy as np
import pandas as pd

In [9]:
# Load the tokenized stories table (tab-separated file).
stories = pd.read_csv(
    '/scratch2/mrenaudin/Hard-CBR-RNN/all_stories.tok',
    sep='\t',
)

In [None]:
# Display the stories table read from all_stories.tok.
stories

## Predict surprisal

### Load model

In [18]:
from original_cbr import CBR_RNN
# Paths to the trained checkpoint and to the hparams.yaml stored in the same
# lightning_logs version directory (job_007 / version_1198526).
checkpoint = '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/checkpoints/epoch=49-step=565950.ckpt'
hparams_path = '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/hparams.yaml'

In [21]:
import yaml
import torch
# Parse the hyper-parameter file into a plain dict.
with open(hparams_path, 'r') as f:
    hparams = yaml.safe_load(f)

In [None]:
# Model dimensions are read from the hparams dict loaded from hparams.yaml.
# NOTE(review): the "+1" presumably reserves one extra embedding index beyond
# the stored vocab_size -- confirm against the training code.
ntoken = hparams['vocab_size']+1  # vocabulary size
ninp = hparams['ninp']      # embedding dimension
nhid = hparams['nhid']      # hidden dimension
nlayers = hparams.get('nlayers', 1)  # number of layers
nheads = hparams.get('nheads', 1)    # number of attention heads
dropout = hparams.get('dropout', 0.5) # dropout rate

# Initialize model
model = CBR_RNN(
    ntoken=ntoken,
    ninp=ninp, 
    nhid=nhid,
    nlayers=nlayers,
    nheads=nheads,
    dropout=dropout
)

# Load the checkpoint into its own name: `checkpoint` keeps holding the path,
# so this cell can be re-run without passing an already-loaded dict to torch.load.
checkpoint_data = torch.load(checkpoint, map_location='cpu')

# Extract state_dict (might be nested in Lightning checkpoint)
if 'state_dict' in checkpoint_data:
    # Strip only a *leading* 'model.' (the Lightning wrapper attribute).
    # str.replace would also delete 'model.' occurring in the middle of a key.
    lightning_prefix = 'model.'
    state_dict = {
        (key[len(lightning_prefix):] if key.startswith(lightning_prefix) else key): value
        for key, value in checkpoint_data['state_dict'].items()
    }
else:
    state_dict = checkpoint_data

model.load_state_dict(state_dict)
model.eval()

## New method: load the model through the Lightning module

In [64]:
from original_cbr_lightning import CBRLanguageModel, WordTokenizer
import torch

In [65]:
# Path of the Lightning checkpoint passed to load_trained_model below.
checkpoint = '/scratch2/mrenaudin/Hard-CBR-RNN/job_007/lightning_logs/version_1198526/checkpoints/epoch=49-step=565950.ckpt'


In [66]:
def load_trained_model(checkpoint_path):
    """Restore a CBRLanguageModel from a checkpoint and put it in eval mode.

    Args:
        checkpoint_path: Path handed to ``CBRLanguageModel.load_from_checkpoint``.

    Returns:
        The restored model, after ``eval()`` has been called on it.
    """
    restored = CBRLanguageModel.load_from_checkpoint(checkpoint_path)
    # Switch to evaluation mode before handing the model back.
    restored.eval()
    return restored

In [67]:
# Restore the trained Lightning model from the checkpoint path defined above.
model = load_trained_model(checkpoint)

In [68]:
def compute_surprisal_with_chunking(lightning_model, stories_df, tokenizer, chunk_size=35):
    """
    Compute per-word surprisal for every story, feeding the model fixed-size chunks.

    Each story is cut into consecutive, non-overlapping input chunks of
    ``chunk_size`` tokens, and every chunk is run with a freshly initialised
    cache. The logits at input position ``i`` score the token that follows it
    in the story. This includes the last position of a chunk, which scores
    the first token of the next chunk, so chunk boundaries do not leave gaps.
    Only the first word of a story (no preceding context) keeps a NaN surprisal.

    Args:
        lightning_model: Loaded Lightning CBRLanguageModel; its ``.model``
            attribute is the network that is actually called.
        stories_df: DataFrame with columns [word, zone, item]
        tokenizer: Mapping with a ``'word2idx'`` dict (word -> token id);
            words missing from it are mapped to id 0.
        chunk_size: Number of input tokens per forward pass (35 in training)

    Returns:
        DataFrame with columns [word, zone, item, surprisal], one row per input
        word, ordered by story and then by zone. Surprisal is the negative
        natural-log probability of the word.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lightning_model.to(device)

    # Extract the actual CBR_RNN model
    model = lightning_model.model
    word2idx = tokenizer['word2idx']

    # Models trained with Gumbel-softmax take two extra forward arguments;
    # at inference the temperature is fixed and Gumbel sampling is disabled.
    forward_extras = {}
    if getattr(lightning_model, 'use_gumbel_softmax', False):
        forward_extras = {'temperature': 1.0, 'use_gumbel': False}

    results = []

    # Process each story
    for story_id in sorted(stories_df['item'].unique()):
        story_data = stories_df[stories_df['item'] == story_id].sort_values('zone')
        words = story_data['word'].tolist()

        print(f"Processing story {story_id} ({len(words)} words)...")

        # Convert words to token IDs; 0 is used for out-of-vocabulary words
        token_ids = [word2idx.get(word, 0) for word in words]

        # NaN marks positions that never receive a prediction
        story_surprisals = [float('nan')] * len(words)

        for start_idx in range(0, len(token_ids), chunk_size):
            chunk_ids = token_ids[start_idx:start_idx + chunk_size]
            # Targets are the inputs shifted by one position in the story, so
            # the last input of a chunk predicts the first token of the next.
            first_target = start_idx + 1
            target_ids = token_ids[first_target:first_target + len(chunk_ids)]

            if not target_ids:  # trailing token with nothing left to predict
                continue

            # Convert to tensor [chunk_len, batch_size=1]
            input_tensor = torch.tensor(chunk_ids).unsqueeze(1).to(device)

            # Initialize cache for this chunk
            initial_cache = model.init_cache(input_tensor)

            with torch.no_grad():
                logits, _ = model(
                    observation=input_tensor,
                    initial_cache=initial_cache,
                    **forward_extras,
                )
                # logits are indexed as [chunk_len, 1, vocab_size]
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

                # Pick, for every scored position, the log-probability of its target
                n_targets = len(target_ids)
                target_tensor = torch.tensor(target_ids, device=device).unsqueeze(1)
                target_log_probs = log_probs[:n_targets, 0, :].gather(1, target_tensor).squeeze(1)

            # Map back to story positions (same-length slice assignment)
            story_surprisals[first_target:first_target + n_targets] = (-target_log_probs).tolist()

        # Collect results for this story without a row-wise iterrows() pass
        zones = story_data['zone'].tolist()
        items = story_data['item'].tolist()
        results.extend(
            {'word': word, 'zone': zone, 'item': item, 'surprisal': value}
            for word, zone, item, value in zip(words, zones, items, story_surprisals)
        )

    # Explicit columns keep the frame well-formed even when there are no rows
    return pd.DataFrame(results, columns=['word', 'zone', 'item', 'surprisal'])

In [69]:
import pickle
# Load the pickled tokenizer object from the working directory.
# NOTE: pickle.load executes arbitrary code -- only open files you trust.
with open('tokenizer.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

In [None]:
# Inspect which entries the unpickled tokenizer provides.
tokenizer.keys()

In [None]:
# Score every story word with the loaded model, using 35-token chunks.
surprisal = compute_surprisal_with_chunking(model, stories, tokenizer, chunk_size=35)

In [None]:
# Display the per-word surprisal table.
surprisal


## Prepare data for reading time prediction

In [73]:
# Load the processed reading-time table (tab-separated file).
rt_data = pd.read_csv(
    '/scratch2/mrenaudin/Hard-CBR-RNN/processed_RTs.tsv',
    sep='\t',
)

In [None]:
# Display the reading-time table read from processed_RTs.tsv.
rt_data

In [75]:
# Attach surprisal to each reading-time row by (item, zone); the inner join
# drops rows without a match on either side. Columns present in both frames
# other than the keys get pandas' _x/_y suffixes (the next cell reads `word_x`).
data = rt_data.merge(surprisal, on=['item', 'zone'], how='inner')

In [76]:
# Response variable: log reading time. Non-positive RTs are masked to NaN
# first, because np.log would otherwise turn them into -inf (or NaN plus a
# warning), and -inf is not removed by a plain NaN filter.
data['log_RT'] = np.log(data['RT'].where(data['RT'] > 0))   # y variable
data['word_length'] = data['word_x'].str.len()           # create length from word

# Clean grouping variable - convert to sequential integers
# (pd.Categorical codes; a missing WorkerId is coded as -1)
data['subject'] = pd.Categorical(data['WorkerId']).codes

# Standardize predictors (mean/std skip NaN entries)
data['surprisal_z'] = (data['surprisal'] - data['surprisal'].mean()) / data['surprisal'].std()
data['length_z'] = (data['word_length'] - data['word_length'].mean()) / data['word_length'].std()

# Reset index to avoid indexing issues
data = data.reset_index(drop=True)

In [None]:
# Quick sanity summary of the merged regression table.
print(f"Data prepared: {len(data)} observations")
print(f"Number of subjects: {data['subject'].nunique()}")
print(f"Mean word length: {data['word_length'].mean():.1f} characters")
    

## Fit model

In [None]:
from statsmodels.regression.mixed_linear_model import MixedLM

# Response, fixed-effects design matrix (intercept, surprisal, length) and
# the per-observation subject codes used as grouping variable.
y = data['log_RT'].values
X = data[['surprisal_z', 'length_z']].values
X = np.column_stack([np.ones(len(X)), X])  # Add intercept
groups = data['subject'].values

# Keep only usable rows: finite response and predictors (drops NaN and ±inf)
# and a valid subject. Subject codes are integers, so np.isnan can never flag
# them; a missing subject is coded as -1 by pd.Categorical instead.
mask = np.isfinite(y) & np.isfinite(X).all(axis=1) & (groups >= 0)
y = y[mask]
X = X[mask]
groups = groups[mask]

print(f"Clean data: {len(y)} observations")

# Use a dedicated name so the language model stored in `model` is not
# overwritten and the surprisal cell stays re-runnable.
mixed_model = MixedLM(y, X, groups=groups)
result = mixed_model.fit()

In [80]:
def print_results(fitted_model):
    """Print a short report for a fitted reading-time regression.

    Args:
        fitted_model: Object exposing ``summary()`` (whose result has a
            ``tables`` sequence), ``params`` and ``pvalues``. Index 1 of
            ``params``/``pvalues`` is read as the surprisal term and index 2
            as the word-length term. An object without ``summary`` is
            reported as not fitted.

    Returns:
        None. Everything is written to stdout.
    """
    banner = "="*50
    print(banner)
    print("READING TIME PREDICTION RESULTS")
    print(banner)

    # Guard clause: anything without a summary() is not a fitted result.
    if not hasattr(fitted_model, 'summary'):
        print("Model not properly fitted!")
        print(f"Model type: {type(fitted_model)}")
        return

    # Coefficients table only
    print(fitted_model.summary().tables[1])

    # Positional lookup: 1 = surprisal (x1), 2 = word length (x2)
    coefficients = fitted_model.params
    p_values = fitted_model.pvalues
    surprisal_coef, surprisal_p = coefficients[1], p_values[1]
    length_coef, length_p = coefficients[2], p_values[2]

    verdict = 'Yes' if surprisal_p < 0.05 else 'No'
    print(f"\nSurprisal coefficient (x1): {surprisal_coef:.4f}")
    print(f"P-value: {surprisal_p:.4f}")
    print(f"Significant: {verdict}")

    print(f"\nWord length coefficient (x2): {length_coef:.4f}")
    print(f"P-value: {length_p:.4f}")

    # Interpretation of the sign of the surprisal effect
    if surprisal_coef > 0:
        print("✓ Higher surprisal → Longer reading times (expected)")
        return
    print("⚠ Higher surprisal → Shorter reading times (unexpected)")
    print("  This suggests an issue with surprisal computation or model")



In [None]:
# Report the coefficients of the fitted mixed model.
print_results(result)