# Implicit Statistical Reasoning in Transformers

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.stats import pearsonr, spearmanr
import logging 
import pandas as pd
import os
import glob
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import math

# Logging setup: every record at INFO or above is written both to
# logs/experiment.log and to the console (stderr via StreamHandler).
# The logs/ directory is created first so the FileHandler can open its file.
# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers, so re-running this cell in the same kernel does not reconfigure it.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/experiment.log'),
        logging.StreamHandler()
    ]
)


# Random seeds
def set_seed(seed: int = 0):
    """
    Seed the NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and disables its autotuner.

    Arguments:
    seed: integer seed applied to every generator
    """
    # The cuDNN flags are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(0)


def set_plotting_style():
    """Apply the notebook-wide matplotlib defaults (fonts, sizes, grid, line width)."""
    style = {
        'text.usetex': False,
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 11,
        'figure.figsize': (10, 7),
        'axes.grid': True,
        'grid.alpha': 0.3,
        'lines.linewidth': 2.5,
    }
    # Assign key by key; each assignment is validated by rcParams.
    for key, value in style.items():
        plt.rcParams[key] = value


set_plotting_style()


# PyTorch device selection: prefer CUDA, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    _device_name = 'cuda'
elif torch.backends.mps.is_available():
    _device_name = 'mps'
else:
    _device_name = 'cpu'
DEVICE = torch.device(_device_name)
logging.info(f'Using device: {DEVICE}')

In [None]:
@dataclass
class DataConfig:
    """
    Controls the generation of synthetic tasks.

    The fields are read when the episode DataLoaders are built: `d` and
    `n_ctx` size every episode, the `sigma_k*` values are forwarded to the
    Task A parameter sampler and `sigma_min`/`sigma_max` to the Task B one.
    """
    d: int = 16  # Input dimension
    n_ctx: int = 32  # Number of context pairs
    sigma_k: float = 3.0  # Task A: Shift magnitude
    sigma_k_ood: float = 9.0  # Task A: OOD Shift magnitude
    sigma_min: float = 0.5  # Task B: lower bound of the sampled per-class scale
    sigma_max: float = 3.0  # Task B: upper bound of the sampled per-class scale
    train_episodes: int = 50000  # Size of training dataset
    val_episodes: int = 5000  # Size of validation dataset

@dataclass
class TrainConfig:
    """Controls the optimization loop."""
    # default_factory gives every instance its own list (no shared mutable default)
    seeds: list[int] = field(default_factory=lambda: [0, 1, 2])
    batch_size: int = 64  # Episodes per DataLoader batch
    epochs: int = 20  # Passes over the training loader
    lr: float = 3e-4  # Learning rate handed to the training routine
    device: torch.device = DEVICE  # Device selected at import time (cuda / mps / cpu)


# Initialize Global Configs (default values; read by the loader and experiment cells below)
data_cfg = DataConfig()
train_cfg = TrainConfig()

## Sampling Tasks

### Dataset Wrapper

In [None]:
def make_episode(
    sample_task_params_fn: callable,
    sample_data_fn: callable,
    n_ctx: int,
    d: int,
    **task_kwargs
):
    """
    Build one in-context episode: a sampled task, its labelled context set,
    and a single held-out query drawn from the same task.

    Arguments:
    sample_task_params_fn: called as fn(d=d, **task_kwargs); returns the task parameters
    sample_data_fn: called as fn(task_params, n, d); returns an (x, y) pair
    n_ctx: number of context examples to draw
    d: data dimension forwarded to both samplers
    task_kwargs: extra keyword arguments for the task-parameter sampler

    Returns a dict with keys context_x, context_y, query_x, query_y, task_params.
    """
    params = sample_task_params_fn(d=d, **task_kwargs)

    # Context is drawn before the query, so RNG consumption order is fixed.
    ctx_inputs, ctx_labels = sample_data_fn(params, n_ctx, d)
    single_x, single_y = sample_data_fn(params, 1, d)

    episode = {'context_x': ctx_inputs, 'context_y': ctx_labels}
    # The query sampler returns a batch of one; strip the leading axis.
    episode['query_x'] = single_x[0]
    episode['query_y'] = single_y[0]
    episode['task_params'] = params
    return episode


class EpisodeDataset(Dataset):
    """
    PyTorch Dataset for generating episodes on-the-fly.

    Every `__getitem__` call samples a brand-new episode with `make_episode`;
    the index only serves as a bounds check, so the same index returns
    different data on each access.

    Arguments:
    sample_task_params_fn: task-parameter sampler forwarded to make_episode
    sample_data_fn: data sampler forwarded to make_episode
    n_ctx: number of context pairs per episode
    d: data dimension
    num_episodes: nominal dataset length reported by __len__
    device: device on which the returned tensors are created
    task_kwargs: extra keyword arguments for the task-parameter sampler
    """
    def __init__(
        self,
        sample_task_params_fn,
        sample_data_fn,
        n_ctx: int,
        d: int,
        num_episodes: int,
        device: torch.device = torch.device('cpu'),
        **task_kwargs
    ):
        self.sample_task_params_fn = sample_task_params_fn
        self.sample_data_fn = sample_data_fn
        self.n_ctx = n_ctx
        self.d = d
        self.num_episodes = num_episodes
        self.task_kwargs = dict(task_kwargs)
        self.device = device

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, idx):
        """
        Sample one episode. A single item has the fields
        {
        'context_x': (n_ctx, d) float32,
        'context_y': (n_ctx,)   int64,
        'query_x':   (d,)       float32,
        'query_y':   ()         float32,
        'task_params': task parameters as returned by the sampler,
        }
        and a DataLoader batch stacks them along a leading batch axis B.

        Raises IndexError when idx is outside [-len, len). Without this,
        plain iteration (`for ep in dataset`) would never terminate.
        """
        if not -self.num_episodes <= idx < self.num_episodes:
            raise IndexError(
                f'episode index {idx} out of range for dataset of size {self.num_episodes}'
            )

        episode = make_episode(
            sample_task_params_fn=self.sample_task_params_fn,
            sample_data_fn=self.sample_data_fn,
            n_ctx=self.n_ctx,
            d=self.d,
            **self.task_kwargs
        )

        # Convert to torch tensors where appropriate
        return {
            'context_x': torch.as_tensor(episode['context_x'], dtype=torch.float32, device=self.device),
            'context_y': torch.as_tensor(episode['context_y'], dtype=torch.long, device=self.device),
            'query_x': torch.as_tensor(episode['query_x'], dtype=torch.float32, device=self.device),
            'query_y': torch.as_tensor(episode['query_y'], dtype=torch.float32, device=self.device),
            'task_params': episode['task_params'],  # keep as Python object
        }

### Task A: Shifted Mean Discrimination

In [None]:
def sample_task_A_params(d: int, sigma_k: float) -> dict:
    """
    Draw the parameters of one Task A (mean discrimination) instance.

    Arguments:
    d: data dimension
    sigma_k: standard deviation of each coordinate of the shift k

    Returns a dict {'mu': unit-norm direction, 'k': shift vector}.
    """
    # Draw order (direction first, then shift) fixes the RNG stream.
    direction = np.random.randn(d)
    unit_mu = direction / np.linalg.norm(direction)

    shift = sigma_k * np.random.randn(d)
    return {'mu': unit_mu, 'k': shift}


def sample_task_A_data(task_params: dict, n: int, d: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Draw n labelled points for the mean discrimination task.

    Labels are uniform over {0, 1}; class 1 is centred at k + mu and class 0
    at k - mu, each with unit-variance isotropic Gaussian noise.

    Returns (x, y) with x of shape (n, d) and y of shape (n,).
    """
    mu = task_params['mu']
    k = task_params['k']

    labels = np.random.randint(0, 2, size=n)
    # Map labels {0, 1} to signs {-1, +1} as a column for broadcasting over d.
    signs = (2 * labels - 1)[:, None]
    class_means = signs * mu + k

    noise = np.random.randn(n, d)
    return class_means + noise, labels


def task_A_llr(x: np.ndarray, mu: np.ndarray, k: np.ndarray) -> np.ndarray:
    """
    Bayes-optimal log-likelihood ratio for mean discrimination.

    Computes log p(x | y=1) - log p(x | y=0) = 2 * mu . (x - k) for classes
    centred at k + mu and k - mu with identity covariance.

    Arguments:
    x: points of shape (..., d); a single point of shape (d,) is accepted
    mu: class direction of shape (d,)
    k: shift of shape (d,)

    Returns an array with the shape of x minus its last axis.
    """
    centered_x = np.asarray(x) - k
    # Reduce over the feature axis, whatever the number of leading axes.
    dot_prod = (mu * centered_x).sum(axis=-1)
    return 2 * dot_prod

### Task B: Variance Discrimination

In [None]:
def sample_task_B_params(d: int, sigma_min: float = 0.5, sigma_max: float = 3.0) -> dict:
    """
    Draw the two class scales (sigma_0, sigma_1) for Task B, each uniformly
    from [sigma_min, sigma_max).

    `d` is accepted for signature parity with the other samplers and is unused.
    """
    def draw_sigma():
        return np.random.uniform(sigma_min, sigma_max)

    # Dict values are evaluated left to right, so sigma_0 is drawn first.
    return {'sigma_0': draw_sigma(), 'sigma_1': draw_sigma()}


def sample_task_B_data(task_params: dict, n: int, d: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Draw n labelled points for the variance discrimination task.

    Labels are uniform over {0, 1}; a point with label c is zero-mean
    isotropic Gaussian noise scaled by sigma_c.

    Returns (x, y) with x of shape (n, d) and y of shape (n,).
    """
    labels = np.random.randint(0, 2, size=n)

    # Look the per-point scale up by label: index 0 -> sigma_0, 1 -> sigma_1.
    per_class_sigma = np.array([task_params['sigma_0'], task_params['sigma_1']])
    scale = per_class_sigma[labels]

    samples = scale[:, None] * np.random.randn(n, d)
    return samples, labels


def task_B_llr(x: np.ndarray, sigma_0: float, sigma_1: float) -> np.ndarray:
    """
    Bayes-optimal log-likelihood ratio for variance discrimination.

    Computes log p(x | y=1) - log p(x | y=0) for zero-mean isotropic
    Gaussians with standard deviations sigma_1 and sigma_0:

        d * ln(sigma_0 / sigma_1) + 0.5 * (1/sigma_0^2 - 1/sigma_1^2) * ||x||^2

    Arguments:
    x: points of shape (..., d); a single point of shape (d,) is accepted
    sigma_0: standard deviation of class 0
    sigma_1: standard deviation of class 1

    Returns an array with the shape of x minus its last axis.
    """
    x = np.asarray(x)
    d = x.shape[-1]
    # Compute Norm Squared ||x||^2 over the feature axis
    norm_sq = (x**2).sum(axis=-1)
    # Compute Constant Bias term
    # (d/2) * ln(sigma_0^2 / sigma_1^2)
    # simplifies to d * ln(sigma_0 / sigma_1)
    bias = d * np.log(sigma_0 / sigma_1)
    # Compute Quadratic Term
    # coeff = 0.5 * (1/sigma_0^2 - 1/sigma_1^2)
    coeff = 0.5 * ((1.0 / sigma_0**2) - (1.0 / sigma_1**2))
    return bias + (coeff * norm_sq)

### Training and Validation Dataset

In [None]:
set_seed(0)


def _make_episode_loader(params_fn, data_fn, num_episodes, *, is_train, **task_kwargs):
    """
    Wrap an on-the-fly EpisodeDataset in a DataLoader using the global configs.

    Training loaders shuffle and drop the last partial batch; validation
    loaders do neither. `task_kwargs` go to the task-parameter sampler.
    """
    dataset = EpisodeDataset(
        sample_task_params_fn=params_fn,
        sample_data_fn=data_fn,
        n_ctx=data_cfg.n_ctx,
        d=data_cfg.d,
        num_episodes=num_episodes,
        device=train_cfg.device,
        **task_kwargs,
    )
    return DataLoader(
        dataset,
        batch_size=train_cfg.batch_size,
        shuffle=is_train,
        drop_last=is_train,
    )


# Task A without the nuisance shift (sigma_k = 0)
train_loader_A_no_nuisance = _make_episode_loader(
    sample_task_A_params, sample_task_A_data, data_cfg.train_episodes,
    is_train=True, sigma_k=0.0,
)
val_loader_A_no_nuisance = _make_episode_loader(
    sample_task_A_params, sample_task_A_data, data_cfg.val_episodes,
    is_train=False, sigma_k=0.0,
)

# Task A with the in-distribution shift magnitude
train_loader_A = _make_episode_loader(
    sample_task_A_params, sample_task_A_data, data_cfg.train_episodes,
    is_train=True, sigma_k=data_cfg.sigma_k,
)
val_loader_A = _make_episode_loader(
    sample_task_A_params, sample_task_A_data, data_cfg.val_episodes,
    is_train=False, sigma_k=data_cfg.sigma_k,
)

# Task A validation with the OOD nuisance shift magnitude
val_loader_A_OOD = _make_episode_loader(
    sample_task_A_params, sample_task_A_data, data_cfg.val_episodes,
    is_train=False, sigma_k=data_cfg.sigma_k_ood,
)

# Task B
train_loader_B = _make_episode_loader(
    sample_task_B_params, sample_task_B_data, data_cfg.train_episodes,
    is_train=True, sigma_min=data_cfg.sigma_min, sigma_max=data_cfg.sigma_max,
)
val_loader_B = _make_episode_loader(
    sample_task_B_params, sample_task_B_data, data_cfg.val_episodes,
    is_train=False, sigma_min=data_cfg.sigma_min, sigma_max=data_cfg.sigma_max,
)

## Training and Evaluation

In [None]:
def train_step(model: nn.Module, batch: dict, optimizer: optim.Optimizer, loss_fn: nn.Module, scheduler: optim.lr_scheduler._LRScheduler = None):
    """
    Run one optimisation step on a batch of episodes.

    Arguments:
    model: called with context_x, context_y and query_x keyword arguments
    batch: dict holding those three entries plus 'query_y'
    optimizer: optimizer stepped once
    loss_fn: loss applied to (logits, target), target reshaped to (-1, 1)
    scheduler: optional LR scheduler, stepped once after the optimizer

    Returns (loss, accuracy) as Python floats; accuracy is computed from the
    logits produced before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    inputs = {key: batch[key] for key in ('context_x', 'context_y', 'query_x')}
    logits = model(**inputs)
    target = batch['query_y'].float().view(-1, 1)

    loss = loss_fn(logits, target)
    loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    with torch.no_grad():
        hard_preds = (torch.sigmoid(logits) > 0.5).float()
        accuracy = hard_preds.eq(target).float().mean().item()

    return loss.item(), accuracy


@torch.no_grad()
def eval_step(model: nn.Module, batch: dict, loss_fn: nn.Module):
    """
    Evaluate the model on one batch without tracking gradients.

    Arguments:
    model: called with context_x, context_y and query_x keyword arguments
    batch: dict holding those three entries plus 'query_y'
    loss_fn: loss applied to (logits, target), target reshaped to (-1, 1)

    Returns (loss, accuracy, logits, target); the two tensors are moved to CPU.
    """
    model.eval()

    inputs = {key: batch[key] for key in ('context_x', 'context_y', 'query_x')}
    logits = model(**inputs)
    target = batch['query_y'].float().view(-1, 1)

    batch_loss = loss_fn(logits, target).item()

    hard_preds = (torch.sigmoid(logits) > 0.5).float()
    accuracy = hard_preds.eq(target).float().mean().item()

    return batch_loss, accuracy, logits.cpu(), target.cpu()


def run_epoch(
        model: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer = None, scheduler: optim.lr_scheduler._LRScheduler = None
    ) -> dict:
    """
    Run one pass over `dataloader`, training if an optimizer is given and
    evaluating otherwise.

    Arguments:
    model: model passed on to train_step / eval_step
    dataloader: yields batch dicts containing at least 'query_y'
    optimizer: when not None the pass is a training pass
    scheduler: optional LR scheduler forwarded to train_step

    Returns a dict with the sample-weighted mean 'loss' and 'acc'; evaluation
    passes additionally carry the concatenated 'logits' and 'labels'.

    Raises ValueError if the dataloader yields no samples.
    """
    is_train = optimizer is not None
    loss_fn = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    correct = 0
    total = 0
    all_logits, all_labels = [], []

    for batch in dataloader:
        if is_train:
            loss, acc = train_step(model, batch, optimizer, loss_fn, scheduler)
        else:
            loss, acc, logits, labels = eval_step(model, batch, loss_fn)
            all_logits.append(logits)
            all_labels.append(labels)

        B = batch['query_y'].numel()
        total_loss += loss * B
        # acc is a float32 mean, so acc * B can land just below the true
        # integer count (e.g. 6.9999999 for 7 of 10). Round instead of
        # truncating so correct predictions are not under-counted.
        correct += round(acc * B)
        total += B

    if total == 0:
        raise ValueError('run_epoch received a dataloader that yielded no samples')

    metrics = {
        'loss': total_loss / total,
        'acc': correct / total,
    }

    if not is_train:
        metrics['logits'] = torch.cat(all_logits)
        metrics['labels'] = torch.cat(all_labels)

    return metrics


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
    save_best: bool = False,
) -> tuple[list[dict], dict | None]:
    """
    Train `model` with AdamW and a one-cycle learning-rate schedule.

    Arguments:
    model: model to optimise in place
    train_loader: loader used for the training pass of every epoch
    val_loader: loader used for the validation pass of every epoch
    epochs: number of epochs
    lr: peak learning rate of the one-cycle schedule
    save_best: when True, keep a snapshot of the weights with the lowest
        validation loss seen so far

    Returns (history, best_state): history has one dict per epoch with
    train/val loss and accuracy; best_state is a CPU state dict, or None when
    save_best is False or no epoch improved on the initial infinite loss.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.3,  # Warmup for 30% of time
    )

    history = []
    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        train_metrics = run_epoch(model, train_loader, optimizer, scheduler)
        val_metrics = run_epoch(model, val_loader)

        # Store metrics for this epoch
        history.append({
            'epoch': epoch,
            'train_loss': train_metrics['loss'],
            'train_acc': train_metrics['acc'],
            'val_loss': val_metrics['loss'],
            'val_acc': val_metrics['acc']
        })

        if save_best and val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            # clone() is essential: for a model already on the CPU,
            # .detach().cpu() returns tensors that share storage with the live
            # parameters, so later optimizer steps would silently overwrite
            # the "best" snapshot with the final weights.
            best_state = {
                k: v.detach().cpu().clone()
                for k, v in model.state_dict().items()
            }

        logging.info(
            f"[Epoch {epoch:03d}] "
            f"Train loss={train_metrics['loss']:.4f}, acc={train_metrics['acc']:.3f} | "
            f"Val loss={val_metrics['loss']:.4f}, acc={val_metrics['acc']:.3f}"
        )

    return history, best_state


def run_multiseed_experiment(
    model_class: type[nn.Module],
    model_kwargs: dict,
    train_loader: DataLoader,
    val_loader: DataLoader,
    seeds: list[int] | None,
    epochs: int = 10,
    lr: float = 1e-3,
    device: torch.device = torch.device('cpu'),
    experiment_name: str = 'task',
    checkpoint_dir: str | Path = 'checkpoints',
    data_output_dir: str | Path = 'results',
) -> pd.DataFrame:
    """
    Train a fresh model once per seed, saving a checkpoint per seed and one
    CSV with the per-epoch history of all seeds.

    Arguments:
    model_class: model type, instantiated as model_class(**model_kwargs)
    model_kwargs: constructor keyword arguments (also stored in the checkpoint)
    train_loader: training loader shared by all seeds
    val_loader: validation loader shared by all seeds
    seeds: seeds to run; None means [0, 1, 2]
    epochs: epochs per run
    lr: learning rate passed to train_model
    device: device the model is moved to
    experiment_name: sub-directory name for checkpoints and results
    checkpoint_dir: root directory for checkpoints
    data_output_dir: root directory for result CSVs

    Returns the concatenated per-epoch history as a DataFrame.
    """
    seeds = [0, 1, 2] if seeds is None else seeds
    checkpoint_root = Path(checkpoint_dir)
    results_root = Path(data_output_dir)
    model_name = model_class.__name__

    logging.info(f'Running experiment for model: {model_name} with seeds: {seeds}')

    rows = []
    for seed in seeds:
        logging.info(f'Starting training with seed: {seed}')
        set_seed(seed)

        # Build a fresh model so every seed starts from its own initialisation.
        model = model_class(**model_kwargs).to(device)
        history, best_state = train_model(
            model,
            train_loader,
            val_loader,
            epochs=epochs,
            lr=lr,
            save_best=True,
        )

        # Persist the best weights together with what is needed to rebuild the model.
        ckpt_path = checkpoint_root / experiment_name / f'{model_name}_seed{seed}.pt'
        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        checkpoint = {
            'model_class': model_name,
            'model_kwargs': model_kwargs,
            'seed': seed,
            'best_state_dict': best_state,
        }
        torch.save(checkpoint, ckpt_path)
        logging.info(f'Saved checkpoint to {ckpt_path}')

        # Tag every epoch record with the run it belongs to.
        run_tags = {'seed': seed, 'model': model_name, 'experiment': experiment_name}
        for epoch_row in history:
            epoch_row.update(run_tags)
        rows.extend(history)

    # Restore global seed
    set_seed(0)

    df = pd.DataFrame(rows)
    csv_path = results_root / experiment_name / f'{model_name}_results.csv'
    # parents=True also creates the results root itself when missing.
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(csv_path, index=False)
    logging.info(f'Saved results to {csv_path}')
    return df

## Main Model Architecture

In [None]:
class ICLTransformer(nn.Module):
    """
    Transformer encoder for in-context binary classification.

    Each context pair (x_i, y_i) becomes a single token (sum of an x
    projection and a y projection); the query x is appended as the last token
    and the logit is read from that position. No attention mask is used.

    Arguments:
    d_in: input dimension of x
    n_ctx: number of context pairs the positional table is sized for
    d_model: token width
    n_layers: number of encoder layers
    n_heads: attention heads per layer
    """
    def __init__(self, d_in, n_ctx, d_model=128, n_layers=2, n_heads=4):
        super().__init__()
        self.d_model, self.n_ctx = d_model, n_ctx

        # Token embeddings (creation order is kept so seeded init is unchanged)
        self.x_proj = nn.Linear(d_in, d_model)
        self.y_proj = nn.Linear(1, d_model)
        self.query_proj = nn.Linear(d_in, d_model)
        # Learned absolute positions: n_ctx context slots plus one query slot
        self.pos_emb = nn.Parameter(torch.randn(1, n_ctx + 1, d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=0.0,  # synthetic tasks: no dropout
            activation='gelu',
            batch_first=True,
            norm_first=False,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, context_x, context_y, query_x):
        """
        Input:
            context_x: (B, N, d_in)
            context_y: (B, N)
            query_x:   (B, d_in)
        Returns logits of shape (B, 1).
        """
        # One token per labelled context point: x and y projections are summed.
        label_feats = context_y.unsqueeze(-1).float()
        context_tokens = self.x_proj(context_x) + self.y_proj(label_feats)
        # The query has no label, so it gets its own projection.
        query_token = self.query_proj(query_x).unsqueeze(1)  # [B, 1, D]

        # [Ctx_1, ..., Ctx_N, Query] plus positions
        tokens = torch.cat([context_tokens, query_token], dim=1)  # [B, N+1, D]
        tokens = tokens + self.pos_emb[:, :tokens.shape[1], :]

        encoded = self.transformer(tokens)
        # Read the prediction off the query token (last position).
        return self.head(encoded[:, -1, :])

## PART I: Recovery of Optimal Tests

In [None]:
# # Experiment: Task A (In sample validation)
# run_multiseed_experiment(
#     model_class=ICLTransformer,
#     model_kwargs={
#         'd_in': data_cfg.d,
#         'n_ctx': data_cfg.n_ctx,
#         'd_model': 128,
#         'n_layers': 2,
#         'n_heads': 4,
#     },
#     train_loader=train_loader_A,
#     val_loader=val_loader_A,
#     seeds=train_cfg.seeds,
#     epochs=train_cfg.epochs,
#     lr=train_cfg.lr,
#     device=train_cfg.device,
#     experiment_name='task_A_regular',
#     checkpoint_dir='checkpoints',
#     data_output_dir='results',
# )

# # Experiment: Task A (OOD validation)
# run_multiseed_experiment(
#     model_class=ICLTransformer,
#     model_kwargs={
#         'd_in': data_cfg.d,
#         'n_ctx': data_cfg.n_ctx,
#         'd_model': 128,
#         'n_layers': 2,
#         'n_heads': 4,
#     },
#     train_loader=train_loader_A,
#     val_loader=val_loader_A_OOD,
#     seeds=train_cfg.seeds,
#     epochs=train_cfg.epochs,
#     lr=train_cfg.lr,
#     device=train_cfg.device,
#     experiment_name='task_A_OOD',
#     checkpoint_dir='checkpoints',
#     data_output_dir='results',
# )

# # Experiment: Task B
# run_multiseed_experiment(
#     model_class=ICLTransformer,
#     model_kwargs={
#         'd_in': data_cfg.d,
#         'n_ctx': data_cfg.n_ctx,
#         'd_model': 128,
#         'n_layers': 2,
#         'n_heads': 4,
#     },
#     train_loader=train_loader_B,
#     val_loader=val_loader_B,
#     seeds=train_cfg.seeds,
#     epochs=train_cfg.epochs,
#     lr=train_cfg.lr,
#     device=train_cfg.device,
#     experiment_name='task_B_regular',
#     checkpoint_dir='checkpoints',
#     data_output_dir='results',
# )

## PART II: Failure Modalities

### Model Ablation Architectures

In [None]:
class ICLTransformerInterleavedEmbeddings(nn.Module):
    """
    Context provided as (x, y) pairs interleaved in sequence.
    Tests the inductive bias of using (x, y) together, which is broken by interleaving.

    The sequence is [x1, y1, x2, y2, ..., xN, yN, x_query]; the logit is read
    from the last token. The encoder is called without an attention mask, so
    attention is bidirectional.

    Arguments:
    d_in: input dimension of x
    n_ctx: number of context pairs
    d_model: token width
    n_layers: number of encoder layers
    n_heads: attention heads per layer
    max_len: size of the learned positional table; must be >= 2*n_ctx + 1

    Raises ValueError if max_len is too small for n_ctx.
    """
    def __init__(self, d_in, n_ctx, d_model=128, n_layers=2, n_heads=4, max_len=512):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if max_len < 2 * n_ctx + 1:
            raise ValueError('max_len must be at least 2*n_ctx + 1')
        super().__init__()
        self.d_model = d_model

        # Embeddings
        # Projects input data x -> d_model
        self.x_embed = nn.Linear(d_in, d_model)
        # Embeds binary labels y -> d_model
        self.y_embed = nn.Embedding(2, d_model)
        # Learned positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

        # Transformer encoder (no causal mask is applied in forward)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4*d_model,
            dropout=0.0, # No dropout for synthetic tasks required
            batch_first=True,
            activation='gelu',
            norm_first=False,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, context_x, context_y, query_x):
        """
        Input:
            context_x: (B, N, d)
            context_y: (B, N) with integer labels in {0, 1}
            query_x:   (B, d)
        Returns logits of shape (B, 1).

        Raises ValueError if 2*N + 1 exceeds the positional table.
        """
        B, N, _ = context_x.shape
        seq_len = 2 * N + 1
        if seq_len > self.pos_embed.shape[1]:
            raise ValueError(
                f'sequence length {seq_len} exceeds positional table size {self.pos_embed.shape[1]}'
            )

        # Embed inputs
        ctx_x_emb = self.x_embed(context_x)  # (B, N, d_model)
        qry_x_emb = self.x_embed(query_x.unsqueeze(1))  # (B, 1, d_model)
        ctx_y_emb = self.y_embed(context_y.long())  # (B, N, d_model)

        # Interleave by stacking each (x_i, y_i) on a new axis and flattening
        # it: (B, N, 2, D) -> (B, 2N, D). Even positions hold x, odd hold y.
        # Unlike a pre-allocated float32 buffer, this keeps the dtype and
        # device of the embeddings.
        pairs = torch.stack([ctx_x_emb, ctx_y_emb], dim=2)
        interleaved = pairs.reshape(B, 2 * N, self.d_model)
        # Last token is Query X
        seq_emb = torch.cat([interleaved, qry_x_emb], dim=1)

        # Add Position & Forward
        seq_emb = seq_emb + self.pos_embed[:, :seq_len, :]
        out = self.transformer(seq_emb)
        # Predict on last token (repr of x_query)
        last_token = out[:, -1, :]
        return self.head(last_token)
    

class ICLTransformerNoLabels(nn.Module):
    """
    Ablation: the context is provided but its labels y are never embedded.
    Tests whether performance relies on (x, y) pairing.

    `context_y` is accepted for interface parity and ignored.
    """
    def __init__(self, d_in, n_ctx, d_model=128, n_layers=2, n_heads=4):
        super().__init__()
        # Creation order is kept so seeded initialisation is unchanged.
        self.x_proj = nn.Linear(d_in, d_model)
        self.query_proj = nn.Linear(d_in, d_model)
        # n_ctx context slots plus one query slot
        self.pos_emb = nn.Parameter(torch.randn(1, n_ctx + 1, d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, context_x, context_y, query_x):
        """Return (B, 1) logits computed from context_x and query_x only."""
        context_tokens = self.x_proj(context_x)
        query_token = self.query_proj(query_x).unsqueeze(1)

        tokens = torch.cat([context_tokens, query_token], dim=1)
        tokens = tokens + self.pos_emb[:, :tokens.shape[1], :]

        encoded = self.transformer(tokens)
        # The query sits at the last position.
        return self.head(encoded[:, -1])


class ICLTransformerShuffledLabels(ICLTransformer):
    """
    Ablation: context labels are randomly permuted along the context axis.
    Preserves the label distribution but breaks the x-y association.

    One permutation is drawn per forward call and shared by every episode in
    the batch; it is applied in both training and evaluation mode.
    """
    def forward(self, context_x, context_y, query_x):
        _, n_ctx = context_y.shape
        order = torch.randperm(n_ctx, device=context_y.device)
        # Only the labels move; the x positions stay where they are.
        permuted_labels = context_y[:, order]
        return super().forward(context_x, permuted_labels, query_x)


class ICLTransformerShuffledContext(ICLTransformer):
    """
    Ablation: context (x, y) pairs are randomly reordered, keeping each pair intact.
    Tests whether the model relies on the order of context points.

    One permutation is drawn per forward call and shared by every episode in
    the batch; it is applied in both training and evaluation mode.
    """
    def forward(self, context_x, context_y, query_x):
        _, n_ctx, _ = context_x.shape
        order = torch.randperm(n_ctx, device=context_x.device)
        # The same order is applied to x and y so pairs stay aligned.
        reordered_x = context_x[:, order]
        reordered_y = context_y[:, order]
        return super().forward(reordered_x, reordered_y, query_x)
    

class ICLTransformerFrozenAttention(ICLTransformer):
    """
    Ablation: every self-attention parameter is frozen at initialization.

    The fused Q/K/V input projection and the attention output projection of
    each encoder layer have requires_grad set to False; all other parameters
    of the model keep training.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        for layer in self.transformer.layers:
            attn = layer.self_attn
            attention_params = (
                attn.in_proj_weight,   # fused Q, K, V weights
                attn.in_proj_bias,     # fused Q, K, V biases
                attn.out_proj.weight,  # output projection
                attn.out_proj.bias,
            )
            for param in attention_params:
                if param is not None:
                    param.requires_grad = False


class ICLTransformerFrozenQK(ICLTransformer):
    """
    Ablation: query and key projections receive no gradient; values stay trainable.
    Tests whether matching is essential or only aggregation.

    Q, K and V share one fused parameter, so requires_grad cannot be switched
    off for Q/K alone. Instead a gradient hook zeroes the first 2*embed_dim
    rows (the Q and K parts) of the fused weight and bias gradients.

    NOTE(review): a zeroed gradient is not the same as requires_grad=False.
    An optimizer with decoupled weight decay (train_model uses AdamW with
    weight_decay=1e-4) presumably still shrinks the Q/K rows slightly each
    step — confirm whether that matters for this ablation.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def make_hook(n_frozen_rows):
            def zero_leading_rows(grad):
                # grad is (3*d, d) for the weight and (3*d,) for the bias.
                # Work on a copy so the incoming gradient is left untouched.
                masked = grad.clone()
                masked[:n_frozen_rows] = 0.0
                return masked
            return zero_leading_rows

        for layer in self.transformer.layers:
            attn = layer.self_attn
            qk_rows = 2 * attn.embed_dim
            for param in (attn.in_proj_weight, attn.in_proj_bias):
                if param is not None:
                    param.register_hook(make_hook(qk_rows))


class ICLTransformerFrozenPos(ICLTransformer):
    """
    Ablation: positional embeddings keep their random initial values.
    Tests reliance on learned absolute position.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Exclude the positional table from gradient computation.
        self.pos_emb.requires_grad_(False)


class ICLTransformerNoPos(ICLTransformer):
    """
    Ablation: no positional information at all.
    Tests permutation-invariant aggregation.

    The positional table is zeroed and frozen, so adding it is a no-op.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Freeze first: with gradients off, the leaf can be zeroed in place.
        self.pos_emb.requires_grad_(False)
        self.pos_emb.zero_()


class ICLTransformerNoisyLabels(ICLTransformer):
    """
    Ablation: flips each context label independently with probability noise_p.
    Tests robustness of evidence aggregation.

    Noise is injected only in training mode; evaluation sees clean labels.
    """
    def __init__(self, *args, noise_p=0.2, **kwargs):
        super().__init__(*args, **kwargs)
        self.noise_p = noise_p

    def forward(self, context_x, context_y, query_x):
        if not (self.training and self.noise_p > 0):
            return super().forward(context_x, context_y, query_x)

        # One uniform draw per label; True marks a label to flip.
        flip_mask = torch.rand_like(context_y.float()) < self.noise_p
        # Builds a new tensor, so the caller's labels are never modified.
        noisy_labels = torch.where(flip_mask, 1 - context_y, context_y)
        return super().forward(context_x, noisy_labels, query_x)

### Ablation Runs on Task A

In [None]:
# ABLATIONS_TASK_A = [
#     ('interleaved', ICLTransformerInterleavedEmbeddings, {}),
#     ('no_labels', ICLTransformerNoLabels, {}),
#     ('shuffled_labels', ICLTransformerShuffledLabels, {}),
#     ('shuffled_context', ICLTransformerShuffledContext, {}),
#     ('frozen_attention', ICLTransformerFrozenAttention, {}),
#     ('frozen_qk', ICLTransformerFrozenQK, {}),
#     ('frozen_pos', ICLTransformerFrozenPos, {}),
#     ('no_pos', ICLTransformerNoPos, {}),
# ]
# for ablation_name, model_class, model_kwargs in ABLATIONS_TASK_A:
#     run_multiseed_experiment(
#         model_class=model_class,
#         model_kwargs={
#             'd_in': data_cfg.d,
#             'n_ctx': data_cfg.n_ctx,
#             'd_model': 128,
#             'n_layers': 2,
#             'n_heads': 4,
#             **model_kwargs,
#         },
#         train_loader=train_loader_A,
#         val_loader=val_loader_A,
#         seeds=train_cfg.seeds,
#         epochs=train_cfg.epochs,
#         lr=train_cfg.lr,
#         device=train_cfg.device,
#         experiment_name=f'task_A_ablation_{ablation_name}',
#         checkpoint_dir='checkpoints',
#         data_output_dir='results',
#     )


# run_multiseed_experiment(
#     model_class=ICLTransformer,
#     model_kwargs={
#         'd_in': data_cfg.d,
#         'n_ctx': 2 * data_cfg.n_ctx,
#         'd_model': 128,
#         'n_layers': 2,
#         'n_heads': 4,
#     },
#     train_loader=train_loader_A,
#     val_loader=val_loader_A,
#     seeds=train_cfg.seeds,
#     epochs=train_cfg.epochs,
#     lr=train_cfg.lr,
#     device=train_cfg.device,
#     experiment_name='task_A_ablation_context_increase',
#     checkpoint_dir='checkpoints',
#     data_output_dir='results',
# )


# NOISE_LEVELS = [0.1, 0.2, 0.4]
# for p in NOISE_LEVELS:
#     run_multiseed_experiment(
#         model_class=ICLTransformerNoisyLabels,
#         model_kwargs={
#             'd_in': data_cfg.d,
#             'n_ctx': data_cfg.n_ctx,
#             'd_model': 128,
#             'n_layers': 2,
#             'n_heads': 4,
#             'noise_p': p,
#         },
#         train_loader=train_loader_A,
#         val_loader=val_loader_A,
#         seeds=train_cfg.seeds,
#         epochs=train_cfg.epochs,
#         lr=train_cfg.lr,
#         device=train_cfg.device,
#         experiment_name=f'task_A_ablation_noisy_labels_p{p}',
#         checkpoint_dir='checkpoints',
#         data_output_dir='results',
#     )

## Analysis of Part I and II

### Summary Statistics and History Plots

In [None]:
def load_all_results_csvs(results_root='results') -> pd.DataFrame:
    """
    Load every CSV below `results_root` (recursively) into one DataFrame.

    Each frame gets a 'source_path' column; 'experiment' falls back to the
    name of the file's parent directory when the column is absent; 'seed' and
    'epoch' are cast to int when present. Files that cannot be read are
    skipped with a warning.

    Raises RuntimeError if no CSV could be loaded.
    """
    paths = sorted(glob.glob(os.path.join(results_root, '**', '*.csv'), recursive=True))
    dfs = []
    for p in paths:
        try:
            df = pd.read_csv(p)
        except Exception as e:
            # Best-effort loading: one unreadable file must not abort the
            # analysis. Reported at WARNING level so it is not filtered out
            # with routine INFO output.
            logging.warning(f'[WARN] failed to read {p}: {e}')
            continue

        df['source_path'] = p

        if 'experiment' not in df.columns:
            df['experiment'] = os.path.basename(os.path.dirname(p))

        for int_col in ('seed', 'epoch'):
            if int_col in df.columns:
                df[int_col] = df[int_col].astype(int)

        dfs.append(df)

    if not dfs:
        raise RuntimeError(f'No CSVs found under: {results_root}')

    return pd.concat(dfs, ignore_index=True)


def mean_ci95_halfwidth(x: pd.Series):
    """
    Return (mean, half_width, n) for the non-NaN values of x.

    half_width is 1.96 * sample_sd / sqrt(n), i.e. the normal-approximation
    95% interval. It is NaN when fewer than two values are available, and
    the mean is NaN as well when there are none.
    """
    values = x.dropna().to_numpy(dtype=float)
    count = values.size

    if count == 0:
        return np.nan, np.nan, 0

    center = float(values.mean())
    if count < 2:
        # A single observation has no sample standard deviation.
        return center, np.nan, 1

    spread = float(values.std(ddof=1))
    return center, 1.96 * spread / math.sqrt(count), count


def final_train_val_table(df_all: pd.DataFrame):
    """
    Summarise final-epoch accuracy per (experiment, model) across seeds.

    For every (experiment, model, seed) the row with the largest epoch is
    taken; train and validation accuracy are then reported as
    'mean ± half-width' (or just the mean when only one seed is available).

    Arguments:
    df_all: per-epoch results with at least the columns experiment, model,
        seed, epoch, train_acc and val_acc

    Returns a DataFrame with columns experiment, model, n_seeds,
    train_acc_ci95 and val_acc_ci95, sorted by experiment and model. An input
    with no rows gives an empty table with those columns.

    Raises ValueError if a required column is missing.
    """
    required = {'experiment', 'model', 'seed', 'epoch', 'train_acc', 'val_acc'}
    missing = required - set(df_all.columns)
    if missing:
        raise ValueError(f'Missing required columns: {missing}')

    columns = ['experiment', 'model', 'n_seeds', 'train_acc_ci95', 'val_acc_ci95']
    if df_all.empty:
        # Nothing to aggregate; without this guard, sorting the column-less
        # empty frame below raises KeyError.
        return pd.DataFrame(columns=columns)

    def fmt(mean, half):
        # Omit the interval when it is undefined (single seed).
        return f'{mean:.4f}' if np.isnan(half) else f'{mean:.4f} ± {half:.4f}'

    last_rows = (
        df_all.sort_values(['experiment', 'model', 'seed', 'epoch'])
              .groupby(['experiment', 'model', 'seed'], as_index=False)
              .tail(1)
    )

    out = []
    for (exp, model), g in last_rows.groupby(['experiment', 'model']):
        tr_m, tr_h, n1 = mean_ci95_halfwidth(g['train_acc'])
        va_m, va_h, n2 = mean_ci95_halfwidth(g['val_acc'])

        out.append({
            'experiment': exp,
            'model': model,
            'n_seeds': min(n1, n2),
            'train_acc_ci95': fmt(tr_m, tr_h),
            'val_acc_ci95': fmt(va_m, va_h),
        })

    return (
        pd.DataFrame(out, columns=columns)
          .sort_values(['experiment', 'model'])
          .reset_index(drop=True)
    )


def plot_learning_curves_ci95(
    df_all: pd.DataFrame,
    experiment: str,
    model: str | None = None,
    out_path: str | None = None,
):
    required = {'experiment', 'epoch', 'seed', 'train_acc', 'val_acc'}
    missing = required - set(df_all.columns)
    if missing:
        raise ValueError(f'Missing required columns: {missing}')

    df = df_all[df_all['experiment'] == experiment].copy()
    if model is not None:
        if 'model' not in df.columns:
            raise ValueError('model column missing; cannot filter by model.')
        df = df[df['model'] == model].copy()
    if df.empty:
        raise ValueError(f'No rows found for experiment={experiment} model={model}')

    def agg_mean_ci(x):  # 95% CI 
        x = x.dropna().to_numpy(dtype=float)
        n = len(x)
        m = float(np.mean(x)) if n else np.nan
        if n <= 1:
            return pd.Series({'mean': m, 'ci': np.nan})
        sd = float(np.std(x, ddof=1))
        ci = 1.96 * sd / math.sqrt(n)
        return pd.Series({'mean': m, 'ci': ci})
    
    train_stats = df.groupby('epoch')['train_acc'].apply(agg_mean_ci).reset_index()
    val_stats = df.groupby('epoch')['val_acc'].apply(agg_mean_ci).reset_index()

    # flatten output 
    train_stats = train_stats.pivot(index='epoch', columns='level_1', values='train_acc').reset_index()
    val_stats = val_stats.pivot(index='epoch', columns='level_1', values='val_acc').reset_index()

    plt.figure()
    plt.plot(train_stats['epoch'], train_stats['mean'], label='Train Accuracy')
    plt.fill_between(
        train_stats['epoch'],
        train_stats['mean'] - train_stats['ci'].fillna(0),
        train_stats['mean'] + train_stats['ci'].fillna(0),
        alpha=0.2,
    )

    plt.plot(val_stats['epoch'], val_stats['mean'], label='Validation Accuracy')
    plt.fill_between(
        val_stats['epoch'],
        val_stats['mean'] - val_stats['ci'].fillna(0),
        val_stats['mean'] + val_stats['ci'].fillna(0),
        alpha=0.2,
    )

    clean_experiment_name = experiment.replace('_', ' ').title()
    title = f'{clean_experiment_name}' + (f' | {model}' if model is not None else '')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim(0.0, 1.0)
    plt.legend()

    if out_path is not None:
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, dpi=200, bbox_inches='tight')
        plt.close()
        logging.info(f'Saved plot to: {out_path}')
    else:
        plt.show()

In [None]:
# Gather every results CSV and persist the final-epoch summary table.
df_all = load_all_results_csvs('results')
summary = final_train_val_table(df_all)
logging.info(summary.to_string(index=False))
os.makedirs('analysis_out', exist_ok=True)
summary.to_csv('analysis_out/final_train_val_acc_ci95.csv', index=False)
logging.info('Saved summary to analysis_out/final_train_val_acc_ci95.csv')

# One learning-curve figure per (experiment, model) combination.
experiments = df_all['experiment'].unique()
models = df_all['model'].unique()
for exp, mod in ((e, m) for e in experiments for m in models):
    try:
        plot_learning_curves_ci95(
            df_all,
            experiment=exp,
            model=mod,
            out_path=f'analysis_out/{exp}_{mod}_learning_curve_ci95.png',
        )
    except ValueError:
        # Raised e.g. when this combination has no rows; move on to the next.
        pass

### Logit vs LLR Regression

In [None]:
def load_model_from_checkpoint(ckpt_path, device: torch.device = DEVICE) -> nn.Module:
    """Rebuild a model from a saved checkpoint and return it in eval mode.

    The checkpoint is read as a mapping with the keys ``'model_class'`` (class
    name string), ``'model_kwargs'`` (constructor keyword arguments),
    ``'best_state_dict'`` (weights to load) and ``'seed'`` (logged only).

    Args:
        ckpt_path: Path passed to ``torch.load``.
        device: Device used both as ``map_location`` and as the model's final
            location.

    Returns:
        The instantiated model with weights loaded, moved to ``device`` and
        switched to eval mode.

    Raises:
        ValueError: If the stored class name is not one of the known models.
    """
    # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)

    class_name = ckpt['model_class']
    init_kwargs = ckpt['model_kwargs']
    weights = ckpt['best_state_dict']

    logging.info(f'Loading {class_name} from seed {ckpt["seed"]} from {ckpt_path}')

    # Registry translating the stored class name into the class object.
    registry = {
        'ICLTransformer': ICLTransformer,
        'ICLTransformerNoLabels': ICLTransformerNoLabels,
        'ICLTransformerFrozenPos': ICLTransformerFrozenPos,
        'ICLTransformerNoPos': ICLTransformerNoPos,
        'ICLTransformerFrozenQK': ICLTransformerFrozenQK,
        'ICLTransformerShuffledLabels': ICLTransformerShuffledLabels,
        'ICLTransformerShuffledContext': ICLTransformerShuffledContext,
        'ICLTransformerNoisyLabels': ICLTransformerNoisyLabels,
    }

    model_cls = registry.get(class_name)
    if model_cls is None:
        raise ValueError(f'Unknown model class: {class_name}')

    model = model_cls(**init_kwargs)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    return model

In [None]:
def extract_logits_and_llr(model: nn.Module, dataloader: DataLoader, llr_function: callable) -> pd.DataFrame:
    """Collect model logits, reference LLRs and labels for every query.

    Each batch is expected to be a dict with the keys ``'context_x'``,
    ``'context_y'``, ``'query_x'``, ``'query_y'`` and ``'task_params'`` (a dict
    of tensors). The model is called with the first three as keyword
    arguments; ``llr_function`` is called as
    ``llr_function(query_x_numpy, **task_params_numpy)``.

    Args:
        model: Model returning one logit per query.
        dataloader: Iterable of batch dicts as described above.
        llr_function: Callable producing one reference LLR per query.

    Returns:
        DataFrame with one row per query and the columns ``model_logit``,
        ``bayes_llr`` and ``true_label``. An empty loader yields an empty
        frame with those columns.
    """
    columns = ['model_logit', 'bayes_llr', 'true_label']

    # Inputs have to live where the weights are; the loader may yield CPU
    # tensors while the model sits on an accelerator.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    def on_model_device(tensor):
        return tensor if device is None else tensor.to(device)

    logit_chunks, llr_chunks, label_chunks = [], [], []
    for batch in dataloader:
        with torch.no_grad():
            logits = model(
                context_x=on_model_device(batch['context_x']),
                context_y=on_model_device(batch['context_y']),
                query_x=on_model_device(batch['query_x']),
            )
        # Flatten to one value per query; assumes a single logit per query.
        logit_chunks.append(logits.cpu().numpy().reshape(-1))
        label_chunks.append(batch['query_y'].cpu().numpy().reshape(-1))

        # The reference LLR is computed in NumPy from the batch's task parameters.
        task_params = {k: v.cpu().numpy() for k, v in batch['task_params'].items()}
        llrs = llr_function(batch['query_x'].cpu().numpy(), **task_params)
        llr_chunks.append(np.asarray(llrs).reshape(-1))

    if not logit_chunks:
        return pd.DataFrame(columns=columns)

    return pd.DataFrame({
        'model_logit': np.concatenate(logit_chunks),
        'bayes_llr': np.concatenate(llr_chunks),
        'true_label': np.concatenate(label_chunks),
    })

### Logit vs LLR Regression

In [None]:
# Load the seed-0 ICLTransformer checkpoints of the two regular-task runs.
model_A = load_model_from_checkpoint('checkpoints/task_A_regular/ICLTransformer_seed0.pt', device=train_cfg.device)
model_B = load_model_from_checkpoint('checkpoints/task_B_regular/ICLTransformer_seed0.pt', device=train_cfg.device)

# Pair each model's logits with the LLR from the matching task's LLR function.
logging.info('Extracting Task A Data...')
df_A = extract_logits_and_llr(model_A, val_loader_A, task_A_llr)
logging.info('Extracting Task B Data...')
df_B = extract_logits_and_llr(model_B, val_loader_B, task_B_llr)

def analyze_and_plot_logits_regression(df, task_name, ax):
    """Scatter model logits against the reference LLR and report correlations.

    Pearson and Spearman coefficients are computed on every row of ``df``;
    only the scatter/regression plot is drawn on a random subsample when the
    frame has more than 2000 rows.

    Args:
        df: Frame with the columns ``bayes_llr`` and ``model_logit``.
        task_name: Label used in the axis title and the log lines.
        ax: Matplotlib axes to draw on.
    """
    llr_values = df['bayes_llr']
    logit_values = df['model_logit']
    pearson_r, _ = pearsonr(llr_values, logit_values)
    spearman_rho, _ = spearmanr(llr_values, logit_values)

    # Cap the number of plotted points to keep the scatter readable.
    shown = df if len(df) <= 2000 else df.sample(2000)

    sns.regplot(
        x='bayes_llr',
        y='model_logit',
        data=shown,
        ax=ax,
        scatter_kws=dict(alpha=0.3, s=10),
        line_kws=dict(color='red'),
    )

    ax.set_title(
        f'{task_name}\nPearson $r={pearson_r:.3f}$ | Spearman $\\rho={spearman_rho:.3f}$'
    )
    ax.set_xlabel('True Analytical LLR')
    ax.set_ylabel('Transformer Logit')
    ax.grid(True, alpha=0.3)

    for line in (
        f'--- {task_name} ---',
        f'Pearson r: {pearson_r:.4f}',
        f'Spearman rho: {spearman_rho:.4f}',
    ):
        logging.info(line)


# Side-by-side logit-vs-LLR panels: Task A on the left, Task B on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
analyze_and_plot_logits_regression(df_A, 'Task A: Shifted Mean (Linear)', axes[0])
analyze_and_plot_logits_regression(df_B, 'Task B: Variance (Quadratic)', axes[1])
plt.tight_layout()
# NOTE(review): savefig does not create 'analysis_out'; it must already exist.
plt.savefig('analysis_out/logits_regression_A_B.png')
plt.show()

In [None]:
# Seed-0 ICLTransformer checkpoint of the task_A_OOD run.
model_A_OOD = load_model_from_checkpoint('checkpoints/task_A_OOD/ICLTransformer_seed0.pt', device=train_cfg.device)

logging.info('Extracting OOD Logits for Task A')
# NOTE(review): evaluates the OOD checkpoint loaded on the line above, which
# was otherwise unused (model_A was passed here). Confirm this is the intended
# model for the OOD analysis.
df_ood = extract_logits_and_llr(model_A_OOD, val_loader_A_OOD, task_A_llr)

plt.figure(figsize=(6, 5))
analyze_and_plot_logits_regression(df_ood, fr'Task A (OOD): $\sigma_k={data_cfg.sigma_k_ood}$', plt.gca())
plt.savefig('analysis_out/logits_regression_A_OOD.png')
plt.show()

## PART III: Mechanistic Interpretability

### Logit Lens

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import pearsonr, spearmanr


def run_logit_lens(model, val_loader, task_type, num_batches=100):
    """Logit-lens probe: decode the query token after every layer with the final head.

    For each batch, forward hooks capture the last-position output of every
    module in ``model.transformer.layers``; the query embedding from
    ``model.query_proj`` serves as "Layer 0 (Input)". Each representation is
    passed through ``model.head`` and the resulting virtual logit is correlated
    (Pearson and Spearman, per batch) with a reference signal:

    * ``task_type == 'A'``: ``2 * mu^T (x_q - k)`` using ``mu`` and ``k`` from
      ``batch['task_params']``;
    * any other value: the query label ``batch['query_y']`` as a proxy.

    Mean correlations are logged and a bar plot is saved to
    ``analysis_out/logit_lens_{task_type}.png`` and shown. Nothing is plotted
    when no batch was processed.

    Args:
        model: Model exposing ``transformer.layers``, ``query_proj`` and
            ``head``, callable as ``model(context_x, context_y, query_x)``.
        val_loader: Iterable of batch dicts; it is restarted when exhausted.
        task_type: Selects the reference signal and names the output file.
        num_batches: Number of batches to evaluate.

    Returns:
        None.
    """
    model.eval()
    device = next(model.parameters()).device

    input_name = 'Layer 0 (Input)'
    layer_outputs = {}

    def hook_fn(name):
        def fn(module, inputs, output):
            # assumes output is [Batch, SeqLen, D] with the query token in the
            # last position — TODO confirm against the model's sequence layout
            layer_outputs[name] = output[:, -1, :].detach()
        return fn

    hooks = []
    layer_names = [input_name]
    all_results = []

    logging.info(f'Logit Lens ({task_type}): Running {num_batches} batches')

    try:
        # Registration happens inside the try so that a failure part-way
        # through still removes the hooks that were already attached.
        for i, layer in enumerate(model.transformer.layers):
            name = f'Layer {i + 1}'
            hooks.append(layer.register_forward_hook(hook_fn(name)))
            layer_names.append(name)

        loader_iter = iter(val_loader)
        for batch_idx in range(num_batches):
            # Restart the loader when it runs out before num_batches is reached.
            try:
                batch = next(loader_iter)
            except StopIteration:
                loader_iter = iter(val_loader)
                batch = next(loader_iter)

            ctx_x = batch['context_x'].to(device)
            ctx_y = batch['context_y'].to(device)
            qry_x = batch['query_x'].to(device)

            # Forget the previous batch so a hook that did not fire cannot
            # contribute stale activations.
            layer_outputs.clear()
            with torch.no_grad():
                _ = model(ctx_x, ctx_y, qry_x)  # triggers the hooks
                layer_outputs[input_name] = model.query_proj(qry_x)

            if task_type == 'A':
                p = batch['task_params']
                mu = p['mu'].to(device)
                k = p['k'].to(device)
                # Vectorized LLR A: 2 * mu^T (x - k)
                llr = 2 * (mu * (qry_x - k)).sum(dim=1)
            else:
                # Task B: use the label as proxy (0 vs 1)
                llr = batch['query_y'].to(device).float()
            llr = llr.cpu().numpy()

            for name in layer_names:
                if name not in layer_outputs:
                    continue

                # Decode the intermediate representation with the final head.
                with torch.no_grad():
                    virtual_logit = model.head(layer_outputs[name])
                virtual_logit = virtual_logit.cpu().numpy().flatten()

                r, _ = pearsonr(llr, virtual_logit)
                rho, _ = spearmanr(llr, virtual_logit)

                all_results.append({
                    'Layer': name,
                    'Metric': 'Pearson',
                    'Value': r,
                    'Batch': batch_idx
                })
                all_results.append({
                    'Layer': name,
                    'Metric': 'Spearman',
                    'Value': rho,
                    'Batch': batch_idx
                })
    finally:
        # Always detach the hooks so later forward passes are unaffected.
        for h in hooks:
            h.remove()

    if not all_results:
        return

    df = pd.DataFrame(all_results)

    logging.info('\nMean Correlations:')
    logging.info(df.groupby(['Layer', 'Metric'])['Value'].mean())

    plt.figure(figsize=(6, 4))
    sns.barplot(
        data=df,
        x='Layer',
        y='Value',
        hue='Metric',
        palette='Blues_d',
        errorbar=('ci', 95),
        capsize=0.1
    )

    plt.title('Logit Lens: When is the decision made?')
    plt.ylabel('Correlation with True LLR')
    plt.ylim(-0.2, 1.1)
    plt.grid(axis='y', alpha=0.3)
    plt.legend(loc='upper left')

    save_path = f'analysis_out/logit_lens_{task_type}.png'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    logging.info(f'Saved plot to {save_path}')
    plt.show()


# Logit-lens probe of model_A on val_loader_A using the task type 'A' reference signal.
run_logit_lens(model_A, val_loader_A, 'A')

### Kernel Regression Correlation

In [None]:
def extract_kernel_regression_comparison(
    model, val_loader, task_type='A', num_batches=20
):
    """
    Compares the Transformer's output against a hard-coded Nadaraya-Watson
    Kernel Regression (Soft Nearest Neighbor) estimator.

    The estimator is ``sum_i softmax_i(x_q . x_i / sqrt(d_head)) * (2 * y_i - 1)``
    computed on the raw inputs (no learned projections). Logits and estimator
    outputs are pooled over ``num_batches`` batches, their Pearson correlation
    is logged, and a scatter plot with a linear fit is saved to
    ``analysis_out/kernel_regression_{task_type}.png`` and shown.

    Args:
        model: Model exposing ``d_model`` and callable as
            ``model(context_x, context_y, query_x)``.
        val_loader: Iterable of batch dicts with ``'context_x'``,
            ``'context_y'`` and ``'query_x'``; restarted when exhausted.
        task_type: Used only in the log message and the output file name.
        num_batches: Number of batches to evaluate.

    Returns:
        None. Returns early (after logging) when the extraction loop fails or
        no batch was processed.
    """
    logging.info(f'Kernel Regression ({task_type}): Running {num_batches} batches')

    model.eval()
    device = next(model.parameters()).device

    all_nn_logits = []
    all_algo_outputs = []

    # Head count is only needed for the 1/sqrt(d_head) scaling; fall back to 4
    # when it cannot be read off the first layer.
    n_heads = 4
    try:
        layer0 = model.transformer.layers[0]
        if hasattr(layer0, 'self_attn'):
            n_heads = layer0.self_attn.num_heads
        elif hasattr(layer0, 'attn'):
            n_heads = layer0.attn.num_heads
    except AttributeError:
        pass

    d_head = model.d_model // n_heads
    scale = 1.0 / np.sqrt(d_head)

    loader_iter = iter(val_loader)
    try:
        for _ in range(num_batches):
            # Restart the loader when it runs out before num_batches is reached.
            try:
                batch = next(loader_iter)
            except StopIteration:
                loader_iter = iter(val_loader)
                batch = next(loader_iter)

            xc = batch['context_x'].to(device)  # [B, N, D]
            yc = batch['context_y'].to(device)  # [B, N]
            xq = batch['query_x'].to(device)  # [B, D]

            with torch.no_grad():
                # Transformer logits
                all_nn_logits.append(model(xc, yc, xq).cpu().numpy().flatten())

                # Symbolic algorithm (Kernel Regression)
                # Hypothesis: Prediction = Sum( Softmax(q @ k.T / sqrt(d)) * y )
                # [B, 1, D] @ [B, D, N] -> [B, 1, N]; the raw dot product is a
                # proxy for the learned kernel.
                raw_scores = torch.bmm(xq.unsqueeze(1), xc.transpose(1, 2)).squeeze(1) * scale  # [B, N]
                weights = F.softmax(raw_scores, dim=-1)  # [B, N]

                # Convert 0/1 labels to -1/+1 polarity before aggregating.
                y_polar = (2 * yc - 1).float()
                all_algo_outputs.append((weights * y_polar).sum(dim=-1).cpu().numpy())

    except Exception as e:
        # Best-effort analysis: report the failure (with traceback) and stop.
        logging.exception(f'Error during extraction loop: {e}')
        return

    if not all_nn_logits:
        # np.concatenate raises on an empty list; there is nothing to compare.
        logging.warning('Kernel Regression: no batches processed; skipping.')
        return

    # Aggregate and compute correlation
    nn_logits_flat = np.concatenate(all_nn_logits)
    algo_outputs_flat = np.concatenate(all_algo_outputs)
    corr = np.corrcoef(nn_logits_flat, algo_outputs_flat)[0, 1]

    plt.figure(figsize=(6, 5))
    # Downsample for plotting if we have too many points (e.g., > 2000)
    if len(nn_logits_flat) > 2000:
        indices = np.random.choice(len(nn_logits_flat), 2000, replace=False)
        plot_x = algo_outputs_flat[indices]
        plot_y = nn_logits_flat[indices]
    else:
        plot_x = algo_outputs_flat
        plot_y = nn_logits_flat
    plt.scatter(plot_x, plot_y, alpha=0.3, s=15, c='purple', edgecolors='none')
    if len(plot_x) > 1:
        m, b = np.polyfit(plot_x, plot_y, 1)
        plt.plot(
            plot_x, m * plot_x + b,
            'k--', lw=1.5,
            label='Linear Fit'
        )
        # The fit line is the only labelled artist, so the legend goes with it.
        plt.legend(loc='upper left')
    plt.title(
        f'Transformer vs. Kernel Regression (R={corr:.4f})'
    )
    plt.xlabel(r'Symbolic Output: $\sum \text{softmax}(x_q^\top x_i) \cdot y_i$')
    plt.ylabel('Transformer Logit')
    plt.grid(True, alpha=0.3)

    save_path = f'analysis_out/kernel_regression_{task_type}.png'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    logging.info(f'Extraction Correlation: {corr:.4f}')
    plt.show()


# Compare model_A's logits with the kernel-regression estimator over 50 batches of val_loader_A.
extract_kernel_regression_comparison(model_A, val_loader_A, 'A', num_batches=50)

### OV-Circuit Alignment

In [None]:
def analyze_ov_circuit_alignment(model, task_type='A'):
    """
    Measure, per attention head, how the OV circuit (W_O @ W_V) maps the label
    embedding direction onto the output head direction.

    For every layer and head the normalised ``y_proj`` weight vector is pushed
    through that head's ``W_O_h @ W_V_h`` and compared to the normalised
    ``head`` weight vector by cosine similarity. Only the weight matrices are
    used (no biases, no other sub-layers). The per-head table is logged and
    drawn as a heatmap saved to
    ``analysis_out/ov_circuit_alignment_{task_type}.png``.

    Returns None; logs a warning and returns early if the model has no
    ``y_proj`` attribute.
    """
    logging.info(f'OV Circuit Alignment Analysis ({task_type})')

    model.eval()
    if not hasattr(model, 'y_proj'):
        logging.warning('Model has no y_proj; skipping.')
        return

    # Everything is analysed on CPU to avoid device mismatches.
    # Direction written into the embedding by the label input.
    label_dir = F.normalize(model.y_proj.weight.detach().cpu().squeeze().t(), dim=0)
    # Direction read out by the final head.
    readout_dir = F.normalize(model.head.weight.detach().cpu().squeeze(), dim=0)

    # assumes every layer exposes a `self_attn` with packed in_proj_weight and
    # out_proj, and the same head count as layer 0 — TODO confirm
    layers = model.transformer.layers
    n_heads = layers[0].self_attn.num_heads
    d_model = model.d_model
    head_dim = d_model // n_heads

    rows = []
    for layer_idx, layer in enumerate(layers):
        attn = layer.self_attn

        # in_proj_weight is [3*D, D] -> [Q | K | V]; take the V third.
        w_v = attn.in_proj_weight.detach().cpu()[2 * d_model:3 * d_model, :]  # [D, D]
        w_o = attn.out_proj.weight.detach().cpu()  # [D, D]

        for head_idx in range(n_heads):
            span = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
            # W_O_h [D, head_dim] @ W_V_h [head_dim, D]: what the head writes
            # to the residual stream for an attended vector.
            ov = w_o[:, span] @ w_v[span, :]
            written = ov @ label_dir
            # +1 means the written vector points along the readout direction,
            # -1 means it points against it.
            score = F.cosine_similarity(written, readout_dir, dim=0).item()
            rows.append({'Layer': layer_idx, 'Head': head_idx, 'Alignment': score})

    pivot = pd.DataFrame(rows).pivot(index='Layer', columns='Head', values='Alignment')
    logging.info(f'Flattened OV table: {pivot}')

    plt.figure(figsize=(6, 3))
    sns.heatmap(pivot, annot=True, cmap='RdBu_r', center=0, fmt='.2f')
    plt.title(f'OV-Circuit Alignment ({task_type})\n(Does Head map Labels $\\to$ Logits?)')
    plt.ylabel('Layer')
    plt.xlabel('Head')
    plt.tight_layout()
    plt.savefig(f'analysis_out/ov_circuit_alignment_{task_type}.png', dpi=300)
    plt.show()

In [None]:
# Reload the seed-0 Task A checkpoint, then inspect its per-head OV alignment.
model_A = load_model_from_checkpoint('checkpoints/task_A_regular/ICLTransformer_seed0.pt', device=train_cfg.device)
analyze_ov_circuit_alignment(model_A, 'A')