# Implicit Statistical Reasoning in Transformers

In [None]:
import logging
import os
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset


# Logging setup: records at INFO and above go both to logs/experiment.log and to the console
os.makedirs('logs', exist_ok=True)
_log_handlers = [
    logging.FileHandler('logs/experiment.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)


# Reproducibility: seed every random number generator used in this notebook
def set_seed(seed: int = 0):
    """
    Seeds NumPy and PyTorch (including all CUDA devices when CUDA is available)
    and switches cuDNN to deterministic, non-benchmarking mode.

    Arguments:
    seed: value passed to every seeding function
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(0)


def set_plotting_style():
    """Applies the notebook-wide matplotlib rcParams (fonts, figure size, grid, line width)."""
    fonts = {
        'text.usetex': False,
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 11,
    }
    layout = {
        'figure.figsize': (10, 7),
        'axes.grid': True,
        'grid.alpha': 0.3,
        'lines.linewidth': 2.5,
    }
    plt.rcParams.update({**fonts, **layout})


set_plotting_style()


# Device selection: prefer CUDA, then Apple MPS, and fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
logging.info(f'Using device: {DEVICE}')

In [None]:
@dataclass
class DataConfig:
    """Hyperparameters controlling the generation of the synthetic tasks (Task A and Task B)."""
    d: int = 16  # Input dimension of each sample x
    n_ctx: int = 32  # Number of labelled (x, y) context pairs per episode
    sigma_k: float = 3.0  # Task A: standard deviation of the shift vector k
    sigma_min: float = 0.5  # Task B: lower bound of the uniform range for sigma_0 / sigma_1
    sigma_max: float = 3.0  # Task B: upper bound of the uniform range for sigma_0 / sigma_1
    train_episodes: int = 50000  # Number of episodes in each training dataset
    val_episodes: int = 5000  # Number of episodes in each validation dataset


@dataclass
class TrainConfig:
    """Controls the optimization loop."""
    # Random seeds, one per independent training run. A default_factory is
    # required here: dataclasses reject a bare list default with ValueError at
    # class-creation time, and the factory gives every instance its own list.
    seeds: list[int] = field(default_factory=lambda: [0, 1, 2])
    batch_size: int = 64  # Episodes per mini-batch
    epochs: int = 10  # Number of passes over the training loader
    lr: float = 3e-4  # Optimizer learning rate
    device: torch.device = DEVICE  # Device selected at notebook start-up


# Global config instances (all defaults); read by the data-loader construction cells
data_cfg = DataConfig()
train_cfg = TrainConfig()

## Sampling Tasks

### Dataset Wrapper

In [None]:
def make_episode(
    sample_task_params_fn: callable,
    sample_data_fn: callable,
    n_ctx: int,
    d: int,
    **task_kwargs
):
    """
    Builds one episode: task parameters, a labelled context set and one query.

    Arguments:
    sample_task_params_fn: called as fn(d=d, **task_kwargs); its result is handed
        unchanged to sample_data_fn and returned under 'task_params'
    sample_data_fn: called as fn(task_params, n, d); must return an (x, y) pair
    n_ctx: number of context examples to request
    d: data dimension, forwarded to both samplers
    task_kwargs: extra keyword arguments for sample_task_params_fn

    Returns a dict with keys 'context_x', 'context_y', 'query_x', 'query_y'
    and 'task_params'. The query is the first element of a size-1 draw.
    """
    params = sample_task_params_fn(d=d, **task_kwargs)

    # Context is drawn first, then a single query from the same task
    ctx_inputs, ctx_labels = sample_data_fn(params, n_ctx, d)
    qry_inputs, qry_labels = sample_data_fn(params, 1, d)

    return dict(
        context_x=ctx_inputs,
        context_y=ctx_labels,
        query_x=qry_inputs[0],
        query_y=qry_labels[0],
        task_params=params,
    )


class EpisodeDataset(Dataset):
    """
    PyTorch Dataset whose items are episodes sampled on-the-fly.

    The index given to __getitem__ is ignored: every access calls make_episode
    again, so the same index yields a different episode on each access.
    """
    def __init__(
        self,
        sample_task_params_fn,
        sample_data_fn,
        n_ctx: int,
        d: int,
        num_episodes: int,
        device: torch.device = torch.device('cpu'),
        **task_kwargs
    ):
        """
        Arguments:
        sample_task_params_fn: task-parameter sampler forwarded to make_episode
        sample_data_fn: data sampler forwarded to make_episode
        n_ctx: number of context pairs per episode
        d: data dimension
        num_episodes: value reported by __len__
        device: device on which the returned tensors are created
        task_kwargs: extra keyword arguments forwarded to make_episode
        """
        # Sampling configuration
        self.sample_task_params_fn = sample_task_params_fn
        self.sample_data_fn = sample_data_fn
        self.task_kwargs = dict(task_kwargs)  # private copy of the extra kwargs

        # Episode geometry and placement
        self.n_ctx = n_ctx
        self.d = d
        self.num_episodes = num_episodes
        self.device = device

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, idx):
        """
        Samples one episode and converts it to tensors.

        Returns a dict with:
        'context_x': float32 tensor
        'context_y': long tensor
        'query_x':   float32 tensor
        'query_y':   long tensor
        'task_params': the untouched object returned by the task-parameter sampler

        Assuming the samplers return arrays of shape (n, d) and (n,), a DataLoader
        batch of size B has shapes (B, n_ctx, d), (B, n_ctx), (B, d) and (B,).
        """
        raw = make_episode(
            self.sample_task_params_fn,
            self.sample_data_fn,
            self.n_ctx,
            self.d,
            **self.task_kwargs
        )

        # Inputs become float32, labels become long
        tensor_dtypes = {
            'context_x': torch.float32,
            'context_y': torch.long,
            'query_x': torch.float32,
            'query_y': torch.long,
        }
        item = {
            key: torch.tensor(raw[key], dtype=dtype, device=self.device)
            for key, dtype in tensor_dtypes.items()
        }
        item['task_params'] = raw['task_params']  # keep as Python object
        return item

### Task A: Shifted Mean Discrimination

In [None]:
def sample_task_A_params(d: int, sigma_k: float) -> dict:
    """
    Draws the parameters of one Task A (mean discrimination) instance.

    Arguments:
    d: data dimension
    sigma_k: standard deviation of each coordinate of the shift k

    Returns {'mu': unit-norm Gaussian direction, 'k': Gaussian shift vector}.
    """
    # mu is drawn first, then k, so the RNG stream is consumed in that order
    direction = np.random.randn(d)
    unit_mu = direction / np.linalg.norm(direction)

    shift = sigma_k * np.random.randn(d)
    return {'mu': unit_mu, 'k': shift}


def sample_task_A_data(task_params: dict, n: int, d: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Samples n labelled points from the mean discrimination task.

    Labels are uniform on {0, 1}. A point with label 1 has mean mu + k, a point
    with label 0 has mean -mu + k; standard normal noise is added in both cases.

    Returns (x, y) with x of shape (n, d) and y of shape (n,).
    """
    mu = task_params['mu']
    k = task_params['k']

    labels = np.random.randint(0, 2, size=n)
    signs = 2 * labels - 1  # label 1 -> +1, label 0 -> -1
    class_means = k + signs[:, None] * mu

    noise = np.random.randn(n, d)
    return class_means + noise, labels


def task_A_llr(x: np.ndarray, mu: np.ndarray, k: np.ndarray) -> np.ndarray:
    """
    Log-likelihood ratio log p(x | y=1) - log p(x | y=0) for mean discrimination.

    With class means k + mu and k - mu and identity covariance this reduces to
    2 * <x - k, mu>.
    """
    centred = x - k
    return (2.0 * centred) @ mu

### Task B: Variance Discrimination

In [15]:
def sample_task_B_params(d: int, sigma_min: float = 0.5, sigma_max: float = 3.0) -> dict:
    """
    Draws the two class scales of one Task B (variance discrimination) instance.

    sigma_0 and sigma_1 are drawn independently and uniformly from
    [sigma_min, sigma_max). The argument d is accepted but not used.

    Returns {'sigma_0': ..., 'sigma_1': ...}.
    """
    # Entries are evaluated in order, so sigma_0 is drawn before sigma_1
    return {
        name: np.random.uniform(sigma_min, sigma_max)
        for name in ('sigma_0', 'sigma_1')
    }


def sample_task_B_data(task_params: dict, n: int, d: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Samples n labelled points from the variance discrimination task.

    Labels are uniform on {0, 1}. Each point is zero-mean Gaussian noise scaled
    by sigma_1 when its label is 1 and by sigma_0 otherwise.

    Returns (x, y) with x of shape (n, d) and y of shape (n,).
    """
    labels = np.random.randint(0, 2, size=n)

    # Index 0 -> sigma_0, index 1 -> sigma_1, selected per sample by its label
    scale_by_label = np.array([task_params['sigma_0'], task_params['sigma_1']], dtype=float)
    per_sample_scale = scale_by_label[labels]

    samples = np.random.randn(n, d) * per_sample_scale[:, None]
    return samples, labels


def task_B_llr(x: np.ndarray, sigma_0: float, sigma_1: float) -> np.ndarray:
    """
    Log-likelihood ratio log p(x | y=1) - log p(x | y=0) for variance discrimination.

    Arguments:
    x: array of shape (n, d); the second axis is read as the data dimension
    sigma_0: scale of class 0
    sigma_1: scale of class 1

    Returns an array of shape (n,).
    """
    dim = x.shape[1]
    sq_norm = np.square(x).sum(axis=1)

    # Quadratic term from the Gaussian exponents, constant term from the normalisers
    slope = 0.5 * (1.0 / sigma_0**2 - 1.0 / sigma_1**2)
    offset = dim * np.log(sigma_0 / sigma_1)
    return slope * sq_norm + offset

### Training and Validation Dataset

In [None]:
set_seed(0)


def _episode_loader(sample_task_params_fn, sample_data_fn, num_episodes, is_train, **task_kwargs):
    """
    Wraps an EpisodeDataset built from the global configs in a DataLoader.

    Training loaders shuffle and drop the last incomplete batch; validation
    loaders do neither. task_kwargs are forwarded to the task-parameter sampler.
    """
    dataset = EpisodeDataset(
        sample_task_params_fn=sample_task_params_fn,
        sample_data_fn=sample_data_fn,
        n_ctx=data_cfg.n_ctx,
        d=data_cfg.d,
        num_episodes=num_episodes,
        device=train_cfg.device,
        **task_kwargs
    )
    return DataLoader(
        dataset,
        batch_size=train_cfg.batch_size,
        shuffle=is_train,
        drop_last=is_train,
    )


_task_A_fns = (sample_task_A_params, sample_task_A_data)
_task_B_fns = (sample_task_B_params, sample_task_B_data)
_task_B_range = {'sigma_min': data_cfg.sigma_min, 'sigma_max': data_cfg.sigma_max}

# Task A: in-distribution train / validation loaders
train_loader_A = _episode_loader(
    *_task_A_fns, data_cfg.train_episodes, True, sigma_k=data_cfg.sigma_k
)
val_loader_A = _episode_loader(
    *_task_A_fns, data_cfg.val_episodes, False, sigma_k=data_cfg.sigma_k
)
# Task A: OOD shift magnitude (much larger than training)
val_loader_A_OOD = _episode_loader(
    *_task_A_fns, data_cfg.val_episodes, False, sigma_k=100
)

# Task B: train / validation loaders
train_loader_B = _episode_loader(
    *_task_B_fns, data_cfg.train_episodes, True, **_task_B_range
)
val_loader_B = _episode_loader(
    *_task_B_fns, data_cfg.val_episodes, False, **_task_B_range
)

## Training and Evaluation

In [None]:
def train_step(
    model: nn.Module,
    batch: dict,
    optimizer: optim.Optimizer,
    loss_fn: nn.Module,
) -> tuple[float, float]:
    """
    Runs one optimisation step on a single batch.

    Arguments:
    model: called with keyword arguments context_x, context_y, query_x; its
        output is treated as one logit per episode
    batch: dict providing 'context_x', 'context_y', 'query_x' and 'query_y'
    optimizer: optimizer holding the model parameters
    loss_fn: loss applied to (logits, float targets of shape (B, 1))

    Returns (loss, accuracy) for the batch as Python floats. Accuracy is
    computed from the logits produced before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(
        context_x=batch['context_x'],
        context_y=batch['context_y'],
        query_x=batch['query_x'],
    )

    target = batch['query_y'].float().view(-1, 1)
    loss = loss_fn(logits, target)

    loss.backward()
    optimizer.step()

    with torch.no_grad():
        preds = (logits > 0).long()
        # Compare against the (B, 1) target, not the raw (B,) labels: a (B, 1)
        # vs (B,) comparison broadcasts to (B, B) and yields a wrong accuracy.
        acc = (preds == target.long()).float().mean().item()

    return loss.item(), acc


@torch.no_grad()
def eval_step(model: nn.Module, batch: dict, loss_fn: nn.Module):
    """
    Evaluates the model on one batch without tracking gradients.

    Returns (loss, accuracy, logits on CPU, query labels on CPU), where loss and
    accuracy are Python floats and a logit above zero counts as predicting 1.
    """
    model.eval()

    inputs = {name: batch[name] for name in ('context_x', 'context_y', 'query_x')}
    outputs = model(**inputs)

    labels = batch['query_y']
    targets = labels.float().view(-1, 1)  # column shape, matching the logits
    batch_loss = loss_fn(outputs, targets).item()

    hits = (outputs > 0).long() == targets.long()
    batch_acc = hits.float().mean().item()

    return batch_loss, batch_acc, outputs.cpu(), labels.cpu()


def run_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer | None = None,
) -> dict:
    """
    Runs one pass over the dataloader, training if an optimizer is given and
    evaluating otherwise.

    Arguments:
    model: model handed to train_step / eval_step
    dataloader: yields batch dicts containing at least 'query_y'
    optimizer: when not None the pass performs parameter updates

    Returns a dict with 'loss' and 'acc', averaged per episode. In evaluation
    mode it also holds 'logits' and 'labels', the concatenated CPU tensors of
    all batches.

    Raises ValueError if the dataloader yields no batches.
    """
    is_train = optimizer is not None
    loss_fn = nn.BCEWithLogitsLoss()

    loss_sum, acc_sum, n_episodes = 0.0, 0.0, 0
    all_logits, all_labels = [], []

    for batch in dataloader:
        if is_train:
            loss, acc = train_step(model, batch, optimizer, loss_fn)
        else:
            loss, acc, logits, labels = eval_step(model, batch, loss_fn)
            all_logits.append(logits)
            all_labels.append(labels)

        # Per-batch values are means, so weight them by batch size: otherwise a
        # short final batch (drop_last=False) counts as much as a full one.
        batch_size = batch['query_y'].shape[0]
        loss_sum += loss * batch_size
        acc_sum += acc * batch_size
        n_episodes += batch_size

    if n_episodes == 0:
        raise ValueError('run_epoch received a dataloader that yielded no episodes')

    metrics = {'loss': loss_sum / n_episodes, 'acc': acc_sum / n_episodes}

    if not is_train:
        metrics['logits'] = torch.cat(all_logits)
        metrics['labels'] = torch.cat(all_labels)

    return metrics


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> list[dict]:
    """
    Trains the model with AdamW, validating after every epoch.

    Arguments:
    model: model to optimise in place
    train_loader: loader used for the training pass of each epoch
    val_loader: loader used for the validation pass of each epoch
    epochs: number of epochs
    lr: AdamW learning rate

    Returns one dict per epoch with keys 'epoch', 'train_loss', 'train_acc',
    'val_loss' and 'val_acc'. Each epoch is also logged at INFO level.
    """
    adamw = torch.optim.AdamW(model.parameters(), lr=lr)
    records = []

    for epoch in range(epochs):
        # Training pass first, then validation on the updated weights
        per_split = (
            ('train', run_epoch(model, train_loader, adamw)),
            ('val', run_epoch(model, val_loader)),
        )

        record = {'epoch': epoch}
        for split, split_metrics in per_split:
            for metric in ('loss', 'acc'):
                record[f'{split}_{metric}'] = split_metrics[metric]
        records.append(record)

        logging.info(
            f"[Epoch {epoch:02d}] "
            f"Train loss={record['train_loss']:.4f}, acc={record['train_acc']:.3f} | "
            f"Val loss={record['val_loss']:.4f}, acc={record['val_acc']:.3f}"
        )

    return records


def run_multiseed_experiment(
    model_class: type[nn.Module],
    model_kwargs: dict,
    train_loader: DataLoader,
    val_loader: DataLoader,
    seeds: list[int] | None,
    epochs: int = 10,
    lr: float = 1e-3,
    device: torch.device = torch.device('cpu')
) -> pd.DataFrame:
    """
    Trains a freshly initialised model once per seed and collects the histories.

    Arguments:
    model_class: model constructor, called as model_class(**model_kwargs)
    model_kwargs: keyword arguments for the constructor
    train_loader: training loader shared by all runs
    val_loader: validation loader shared by all runs
    seeds: seeds to run; None selects [0, 1, 2]
    epochs: epochs per run
    lr: learning rate per run
    device: device the model is moved to

    Returns a DataFrame with one row per (seed, epoch): the per-epoch metrics of
    train_model plus 'seed' and 'model' columns. The global seed is reset to 0
    before returning.
    """
    seeds = [0, 1, 2] if seeds is None else seeds
    model_name = model_class.__name__
    rows = []

    logging.info(f'Running experiment for model: {model_name} with seeds: {seeds}')

    for seed in seeds:
        logging.info(f'Starting training with seed: {seed}')
        set_seed(seed)

        # Seeding happens before construction, so the initial weights depend on the seed
        fresh_model = model_class(**model_kwargs).to(device)
        history = train_model(fresh_model, train_loader, val_loader, epochs=epochs, lr=lr)

        # Tag every epoch record with its run metadata
        rows.extend({**record, 'seed': seed, 'model': model_name} for record in history)

    set_seed(0)
    return pd.DataFrame(rows)

## Model Architectures

In [None]:
class StaticBaseline(nn.Module):
    """Context-blind baseline: a three-layer ReLU MLP applied to the query alone."""

    def __init__(self, d_in, d_hidden=128):
        """
        Arguments:
        d_in: input dimension of the query
        d_hidden: width of both hidden layers
        """
        super().__init__()
        widths = (d_in, d_hidden, d_hidden)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        layers.append(nn.Linear(d_hidden, 1))  # single output logit
        self.net = nn.Sequential(*layers)

    def forward(self, context_x, context_y, query_x):
        """Returns one logit per query; context_x and context_y are never read."""
        return self.net(query_x)


class ICLTransformer(nn.Module):
    """
    In-context classifier: a Transformer encoder over the interleaved sequence
    [x1, y1, ..., xN, yN, x_query] that emits one logit read off the query token.
    """

    def __init__(self, d_in, d_model=128, n_layers=2, n_heads=4, max_len=512):
        """
        Arguments:
        d_in: dimension of each input x
        d_model: embedding / hidden width
        n_layers: number of encoder layers
        n_heads: attention heads per layer
        max_len: number of learned positions; forward needs 2*N + 1 <= max_len
        """
        super().__init__()
        self.d_model = d_model

        # Embeddings
        # Projects input data x -> d_model
        self.x_embed = nn.Linear(d_in, d_model)
        # Embeds binary labels y -> d_model
        self.y_embed = nn.Embedding(2, d_model)
        # Learned positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

        # Pre-norm Transformer encoder. No attention mask is passed in forward,
        # so every token attends to the whole sequence.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4*d_model,
            dropout=0.0,  # No dropout for synthetic tasks required
            batch_first=True,
            activation='gelu',
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)

        # Prediction Head
        self.head = nn.Linear(d_model, 1)

    def forward(self, context_x, context_y, query_x):
        """
        Input:
            context_x: (B, N, d)
            context_y: (B, N), integer labels in {0, 1}
            query_x:   (B, d)

        Returns logits of shape (B, 1).

        Raises ValueError if the sequence length 2*N + 1 exceeds max_len.
        """
        B, N, _ = context_x.shape

        # Sequence: [x1, y1, x2, y2, ..., xN, yN, x_query]
        seq_len = 2 * N + 1
        max_len = self.pos_embed.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f'Sequence length {seq_len} (2*N + 1 with N={N}) exceeds max_len={max_len}'
            )

        # Embed inputs
        ctx_x_emb = self.x_embed(context_x)  # (B, N, d_model)
        ctx_y_emb = self.y_embed(context_y)  # (B, N, d_model)
        qry_x_emb = self.x_embed(query_x.unsqueeze(1))  # (B, 1, d_model)

        # Interleave by stacking each (x_i, y_i) pair and flattening the pair
        # axis. This keeps the dtype and device of the embeddings instead of
        # copying them into a default-dtype buffer.
        pairs = torch.stack((ctx_x_emb, ctx_y_emb), dim=2)  # (B, N, 2, d_model)
        pairs = pairs.reshape(B, 2 * N, self.d_model)
        seq_emb = torch.cat((pairs, qry_x_emb), dim=1)  # (B, 2N + 1, d_model)

        # Add Position & Forward
        seq_emb = seq_emb + self.pos_embed[:, :seq_len, :]
        out = self.transformer(seq_emb)

        # Predict from the last token, i.e. the representation of x_query
        last_token = out[:, -1, :]
        return self.head(last_token)

## PART I: The Necessity of Context

## PART II: Recovery of Optimal Tests