In [189]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os
from transformer.transformer import DataLoader    
from utils.parse_data import load_trained_model
# Absolute path of the parent of the current working directory.
# NOTE(review): this is computed after the project imports above and is not
# referenced in the visible cells -- confirm it is still needed (e.g. for sys.path).
project_root = os.path.abspath("..")

In [190]:
# 1. Restore the model that was trained in run_2
RUN = 2
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The loader hands back three objects: the network, a metadata mapping
# (indexed by 'model_name' below) and a config exposing n_embd / block_size.
model, model_info, cfg = load_trained_model(run=RUN, model_name="model_seen9M", device=device)
print(f"Loaded model {model_info['model_name']} on {device}  "
      f"(embedding={cfg.n_embd}, block_size={cfg.block_size})")

Loaded model model_seen9M on cpu  (embedding=4, block_size=6)


  model.load_state_dict(torch.load(model_path, map_location=device, **kwargs))


In [191]:
# 2. Data loader for the run's 'tr' split.
# B and T are meant to mirror the values used during training.
B = 300  # sequences per batch
T = 6    # tokens per sequence
loader = DataLoader(B=B, T=T, process_rank=0, num_processes=1, run_number=RUN, suffix='tr')


read in 100000 tokens from /Users/corwin/Building/Learning/Transformers_for_Modeling_Decision_Sequences/experiments/run_2/seqs/behavior_run_2tr.txt


In [192]:
# 3. Collect residuals, log-odds, and sequences using passive estimator
from synthetic_data_generation.agent import RFLR_mouse

N_BATCHES = 10  # batches drawn from the loader (10 x B=300 -> 3000 sequences)
all_residuals = []
all_passive_logodds = []  # Ground truth log odds from passive estimator
all_sequences = []  # store the actual RrLlR sequences

# Agent hyper-parameters, defined exactly once so that every per-sequence agent
# below is built from the same values (editing them here now takes effect).
# NOTE(review): these should match the agent used to generate the data -- confirm.
AGENT_PARAMS = dict(alpha=0.78, beta=2.05, tau=1.43)

# Token index -> character lookup. Loop-invariant, so it is built once up front.
# NOTE(review): assumes the tokenizer orders the vocabulary as R, r, L, l -- confirm.
vocab = ['R', 'r', 'L', 'l']
itos = dict(enumerate(vocab))

model.eval()
with torch.no_grad():
    for batch_idx in range(N_BATCHES):
        x, _ = loader.next_batch()
        x = x.to(device)

        logits, loss, residual = model(x, return_residual=True)

        # Swap the last two axes so each per-sequence snapshot is [n_embd, T]:
        # [B, T, n_embd] -> [B, n_embd, T]  (B x 4 x 6 for this model)
        residual_reshaped = residual.permute(0, 2, 1)

        for b in range(residual_reshaped.size(0)):
            all_residuals.append(residual_reshaped[b])            # [n_embd, T]

            # Convert tensor indices back to R/L characters
            sequence_chars = tuple(itos[int(idx)] for idx in x[b])
            all_sequences.append(sequence_chars)                  # length-T tuple of chars

            # Decode each token for the passive estimator:
            #   choice: 1 = right ('R'/'r'), 0 = left ('L'/'l')
            #   reward: 1 when the character is uppercase, else 0
            choices = [1 if ch in ('R', 'r') else 0 for ch in sequence_chars]
            rewards = [1 if ch in ('R', 'L') else 0 for ch in sequence_chars]

            # A fresh agent per sequence, so updates made for one window do not
            # carry over into the next one.
            agent = RFLR_mouse(**AGENT_PARAMS)

            for c, r in zip(choices, rewards):
                agent.update_phi(c, r)

            # Log odds reported by the agent after the whole window, queried
            # with the final choice of the window.
            logodds_next = agent.compute_log_odds(choices[-1])
            all_passive_logodds.append(logodds_next)

# ------------------------------------------------------------------
# 4.  Verify every snapshot is n_embd×T format
# ------------------------------------------------------------------
# Tie the expectation to the T the loader was built with instead of a literal 6.
expected_seq_len = T
ok = all(snapshot.shape[1] == expected_seq_len for snapshot in all_residuals)
print(f"\nAll residual snapshots have sequence length {expected_seq_len}:", ok)
print(f"Total residuals collected: {len(all_residuals)}")
print(f"Each residual shape: {all_residuals[0].shape if all_residuals else 'None'}")
assert ok, f"Found a residual whose sequence length ≠ {expected_seq_len}"



All residual snapshots have sequence length 6: True
Total residuals collected: 3000
Each residual shape: torch.Size([4, 6])


In [193]:
# 5. Collapse repeated sequences, remembering where each one first appeared
unique = {}
for position, sequence in enumerate(all_sequences):
    # setdefault leaves an existing entry untouched, so only the first index is kept
    unique.setdefault(sequence, position)

unique_indices = [*unique.values()]
print(f"Total sequences collected: {len(all_sequences)}")
print(f"Number of unique sequences: {len(unique_indices)}")

# Preview the first few unique entries alongside their residual and log odds
print("\nFirst 5 unique sequences:")
for first_idx in unique_indices[:5]:
    print(f"Sequence: {all_sequences[first_idx]}")
    print(f"Residual: {all_residuals[first_idx]}")
    print(f"Log odds: {all_passive_logodds[first_idx]}")
    print()

Total sequences collected: 3000
Number of unique sequences: 597

First 5 unique sequences:
Sequence: ('R', 'r', 'r', 'L', 'r', 'r')
Residual: tensor([[-0.3820, -0.0530, -0.0497,  0.4367, -0.1079, -0.0750],
        [ 0.3729, -0.2439, -0.4258, -0.0364, -0.2951, -0.2367],
        [ 1.1575,  0.8633,  0.6045, -0.2724,  0.2323,  0.2203],
        [-1.0677, -0.1992,  0.0508, -0.1686,  0.2850,  0.2911]])
Log odds: 0.3434194814002114

Sequence: ('r', 'r', 'L', 'r', 'r', 'L')
Residual: tensor([[ 0.0822,  0.0459,  0.5809, -0.0447, -0.1062,  0.5447],
        [-0.3669, -0.4167, -0.2044, -0.2781, -0.2463, -0.1163],
        [ 0.6393,  0.5874, -0.3218,  0.2317,  0.1786, -0.5007],
        [-0.0878,  0.1738, -0.0481,  0.2683,  0.3029,  0.0870]])
Log odds: -3.0740335652556006

Sequence: ('r', 'L', 'r', 'r', 'L', 'l')
Residual: tensor([[ 0.0822,  0.6799,  0.0083, -0.0493,  0.5220,  0.4793],
        [-0.3669, -0.1644, -0.3208, -0.2920, -0.1858,  0.2457],
        [ 0.6393, -0.2106,  0.2960,  0.2452, -0.4906,

In [194]:
# 6. Locate the sequences with the largest and smallest passive log odds
logodds_array = np.array(all_passive_logodds)
# Positions of the extremes; they index all_sequences as well, since both lists
# were appended to in lockstep.
max_idx = logodds_array.argmax()
min_idx = logodds_array.argmin()

print(f"Max log odds: {logodds_array[max_idx]:.4f}")
print(f"Min log odds: {logodds_array[min_idx]:.4f}")
print(f"Sequence for max log odds: {all_sequences[max_idx]}")
print(f"Sequence for min log odds: {all_sequences[min_idx]}")


Max log odds: 4.8012
Min log odds: -4.7861
Sequence for max log odds: ('R', 'R', 'R', 'R', 'R', 'R')
Sequence for min log odds: ('L', 'L', 'L', 'L', 'L', 'L')


In [195]:
# 7. Train linear regression: residual last column -> passive log odds
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# 1. Shuffle the first-occurrence indices reproducibly, then hold out a quarter for testing
unique_indices = [*unique.values()]  # fresh list, so re-running this cell starts from the same order
np.random.seed(42)                   # fixed seed -> identical split on every run
np.random.shuffle(unique_indices)    # shuffles the list in place
n_test = len(unique_indices) // 4
test_indices, train_indices = unique_indices[:n_test], unique_indices[n_test:]

# 2. Prepare X (features) and y (targets)
def get_features_targets(indices, residuals=None, logodds=None):
    """Build the regression design matrix and target vector for the given samples.

    Args:
        indices: Iterable of integer positions selecting which samples to use.
        residuals: Optional sequence of per-sample residual tensors, each of
            shape [n_embd, seq_len]. Defaults to the notebook-level
            ``all_residuals``.
        logodds: Optional sequence of per-sample scalar log odds. Defaults to
            the notebook-level ``all_passive_logodds``.

    Returns:
        Tuple ``(X, y)`` where ``X`` is an array of shape
        ``[len(indices), n_embd]`` holding the last column (final sequence
        position) of each selected residual, and ``y`` is an array of shape
        ``[len(indices)]`` with the matching log odds.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    # Fall back to the notebook globals only when the caller did not pass data,
    # which keeps existing one-argument calls working unchanged.
    if residuals is None:
        residuals = all_residuals
    if logodds is None:
        logodds = all_passive_logodds

    indices = list(indices)
    if not indices:
        # np.stack would fail on an empty list with a much less helpful message.
        raise ValueError("get_features_targets needs at least one index; got an empty selection")

    # Residual: shape [n_embd, seq_len]; keep only the last column.
    X = np.stack([residuals[idx][:, -1].detach().cpu().numpy() for idx in indices])
    # Passive log odds: one scalar per sample.
    y = np.array([logodds[idx] for idx in indices])
    return X, y

# Features are the last-position residuals; targets are the passive log odds.
X_train, y_train = get_features_targets(train_indices)
X_test, y_test = get_features_targets(test_indices)

# 3. Train linear regression
reg = LinearRegression()
reg.fit(X_train, y_train)

# 4. Predict and evaluate on the held-out unique sequences
y_pred = reg.predict(X_test)
r2 = r2_score(y_test, y_pred)
print(f"Test R^2: {r2:.4f}")

# 5. Print a few predictions vs. true values with sequences.
# Cap the preview at the size of the test split so a small split cannot
# raise IndexError; y_pred / y_test rows are in the same order as test_indices.
n_show = min(5, len(test_indices))
print(f"\nFirst {n_show} test predictions:")
for i in range(n_show):
    idx = test_indices[i]
    print(f"Sequence: {all_sequences[idx]}")
    print(f"Predicted: {y_pred[i]:.4f}, True: {y_test[i]:.4f}")
    print()


Test R^2: 0.9701

First 5 test predictions:
Sequence: ('r', 'r', 'r', 'r', 'r', 'L')
Predicted: -3.2829, True: -2.8225

Sequence: ('r', 'r', 'R', 'r', 'R', 'R')
Predicted: 3.6319, True: 4.1078

Sequence: ('l', 'L', 'l', 'L', 'l', 'l')
Predicted: -1.4295, True: -1.4037

Sequence: ('l', 'l', 'r', 'l', 'l', 'l')
Predicted: -0.6930, True: -0.7725

Sequence: ('L', 'r', 'R', 'R', 'r', 'R')
Predicted: 2.8983, True: 3.5332

