# Implicit Statistical Reasoning in Transformers

In [None]:
import torch
from torch.utils.data import Dataset

import numpy as np
from numpy.linalg import norm

## Sampling Tasks

In [38]:
def make_episode(
    sample_task_params_fn: callable,
    sample_data_fn: callable,
    n_ctx: int,
    d: int,
    **task_kwargs
):
    """
    Build one episode: a task, a labeled context set, and a single query.

    Arguments:
    sample_task_params_fn: called as fn(d=d, **task_kwargs); its return value
        is passed unchanged to sample_data_fn
    sample_data_fn: called as fn(task_params, n, d); must return a pair (x, y)
    n_ctx: number of context examples requested
    d: data dimension forwarded to both callables
    task_kwargs: extra keyword arguments for sample_task_params_fn

    Returns a dict with keys context_x, context_y, query_x, query_y and
    task_params.
    """
    params = sample_task_params_fn(d=d, **task_kwargs)

    # The context set and the query come from the same task but from two
    # separate calls to the data sampler.
    ctx_inputs, ctx_labels = sample_data_fn(params, n_ctx, d)
    query_inputs, query_labels = sample_data_fn(params, 1, d)

    episode = {}
    episode['context_x'] = ctx_inputs
    episode['context_y'] = ctx_labels
    # The query batch has size one; strip the leading batch axis.
    episode['query_x'] = query_inputs[0]
    episode['query_y'] = query_labels[0]
    episode['task_params'] = params
    return episode


class EpisodeDataset(Dataset):
    """
    Map-style dataset that samples a fresh episode on every access.

    Each item is built by make_episode from the stored sampler callables, so
    the index only serves as a bounds check: fetching the same index twice
    generally yields different episodes.
    """

    def __init__(
        self,
        sample_task_params_fn,
        sample_data_fn,
        n_ctx: int,
        d: int,
        num_episodes: int,
        device: str = 'cpu',
        **task_kwargs
    ):
        """
        Arguments:
        sample_task_params_fn: task-parameter sampler forwarded to make_episode
        sample_data_fn: data sampler forwarded to make_episode
        n_ctx: number of context examples per episode
        d: data dimension
        num_episodes: nominal dataset length reported by __len__
        device: device on which the returned tensors are created
        task_kwargs: extra keyword arguments for sample_task_params_fn
        """
        self.sample_task_params_fn = sample_task_params_fn
        self.sample_data_fn = sample_data_fn
        self.n_ctx = n_ctx
        self.d = d
        self.num_episodes = num_episodes
        self.task_kwargs = dict(task_kwargs)
        self.device = device

    def __len__(self):
        """Return the nominal number of episodes."""
        return self.num_episodes

    def __getitem__(self, idx):
        """
        Sample a new episode and convert its arrays to torch tensors.

        Inputs become float32 tensors and labels become long tensors on
        self.device; task_params is passed through untouched.

        Raises IndexError when idx is outside [-num_episodes, num_episodes).
        """
        # Without this check plain iteration (`for ep in dataset`) would never
        # stop, because Python's sequence protocol relies on IndexError.
        if not -self.num_episodes <= idx < self.num_episodes:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_episodes}"
            )

        episode = make_episode(
            sample_task_params_fn=self.sample_task_params_fn,
            sample_data_fn=self.sample_data_fn,
            n_ctx=self.n_ctx,
            d=self.d,
            **self.task_kwargs
        )

        # Convert to torch tensors where appropriate
        return {
            'context_x': torch.tensor(episode['context_x'], dtype=torch.float32, device=self.device),
            'context_y': torch.tensor(episode['context_y'], dtype=torch.long, device=self.device),
            'query_x': torch.tensor(episode['query_x'], dtype=torch.float32, device=self.device),
            'query_y': torch.tensor(episode['query_y'], dtype=torch.long, device=self.device),
            'task_params': episode['task_params'],  # keep as Python object
        }

### Task A: Shifted Mean Discrimination

In [39]:
def sample_task_A_params(d: int, sigma_k: float) -> dict:
    """
    Draw the parameters of one Task A (mean discrimination) instance.

    Arguments:
    d: data dimension
    sigma_k: standard deviation of each coordinate of the shift k

    Returns a dict with
    'mu': a standard-normal draw rescaled to unit Euclidean norm
    'k': a standard-normal draw scaled by sigma_k
    """
    direction = np.random.randn(d)
    unit_mu = direction / np.linalg.norm(direction)

    shift = sigma_k * np.random.randn(d)
    return dict(mu=unit_mu, k=shift)


def sample_task_A_data(task_params: dict, n: int, d: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Samples labeled data from mean discrimination task.

    Labels are drawn uniformly from {0, 1}. Given the label, x is the class
    mean plus standard-normal noise, where the class means are
        y = 1:  k + mu
        y = 0:  k - mu
    These are the means for which 2 * (x - k) @ mu (see task_A_llr) is the
    log-likelihood ratio of class 1 against class 0.

    Arguments:
    task_params: dict with entries 'mu' and 'k'
    n: number of samples
    d: data dimension used for the noise draw

    Returns {x, y}_1^n with y in {0,1}.
    """
    mu, k = task_params['mu'], task_params['k']
    y = np.random.randint(0, 2, size=n)

    # The two classes sit symmetrically around the shift k. Using the same
    # mean for both branches would make the labels independent of x.
    signs = np.where(y == 1, 1.0, -1.0)[:, None]
    means = k + signs * mu

    x = means + np.random.randn(n, d)
    return x, y


def task_A_llr(x: np.ndarray, mu: np.ndarray, k: np.ndarray) -> np.ndarray:
    """
    Bayes-optimal log-likelihood ratio for mean discrimination.

    Evaluates 2 * (x - k) . mu: the shift k is subtracted from x and the
    result is projected onto mu.
    """
    centered = x - k
    projection = centered @ mu
    return 2.0 * projection

### Task B: Variance Discrimination

In [40]:
def sample_task_B_params(d: int, sigma_min: float = 0.5, sigma_max: float = 3.0) -> dict:
    """
    Draw the two class scales for Task B (variance discrimination).

    sigma_0 and sigma_1 are sampled independently, in that order, from a
    uniform distribution between sigma_min and sigma_max. The argument d is
    accepted but not used here.

    Returns a dict with keys 'sigma_0' and 'sigma_1'.
    """
    return {
        name: np.random.uniform(sigma_min, sigma_max)
        for name in ('sigma_0', 'sigma_1')
    }


def sample_task_B_data(task_params: dict, n: int, d: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Samples labeled data from variance discrimination task.

    Labels are drawn uniformly from {0, 1}; each row of x is standard-normal
    noise of dimension d multiplied by sigma_1 when its label is 1 and by
    sigma_0 otherwise.

    Returns (x, y).
    """
    labels = np.random.randint(0, 2, size=n)

    # One scale per sample, chosen by that sample's label.
    per_sample_scale = np.where(
        labels == 1, task_params['sigma_1'], task_params['sigma_0']
    )

    noise = np.random.randn(n, d)
    samples = per_sample_scale[:, None] * noise
    return samples, labels


def task_B_llr(x: np.ndarray, sigma_0: float, sigma_1: float) -> np.ndarray:
    """
    Bayes-optimal log-likelihood ratio for variance discrimination.

    Computes
        0.5 * (1 / sigma_0**2 - 1 / sigma_1**2) * ||x||^2
        + d * log(sigma_0 / sigma_1)
    where d is the size of the last axis of x and the squared norm is taken
    over that axis.

    Arguments:
    x: a batch of shape (n, d), or a single sample of shape (d,)
    sigma_0: scale of class 0
    sigma_1: scale of class 1

    Returns one value per sample: shape (n,) for a batch, a scalar for a
    single 1-D sample.
    """
    # Work on the last axis so that a single 1-D sample is accepted as well
    # as a 2-D batch.
    d = x.shape[-1]
    quad = np.sum(x**2, axis=-1)

    return (
        0.5 * (1.0 / sigma_0**2 - 1.0 / sigma_1**2) * quad
        + d * np.log(sigma_0 / sigma_1)
    )