# Mixture of Factor Analyzers for LLM Activation Decomposition

**Paper**: [From Directions to Regions: Decomposing Activations in Language Models via Local Geometry](https://arxiv.org/abs/2602.02464) (Shafran et al., 2025)

## Core Idea

Many interpretability methods (linear probes, SAE features, DiffMean) assume each concept lives along a **single global direction**. But many concepts have **nonlinear** structure: they are spread across several clusters, each with its own local orientation.

**Mixture of Factor Analyzers (MFA)** models activation space as a collection of **Gaussian regions**, each with its own centroid and local low-rank subspace. Every activation decomposes into:

1. **Centroid** $\mu_k$: *which region* the activation belongs to (broad semantic category)
2. **Local offset** $W_k z_k$: *variation within* that region (fine-grained distinctions)

### Why this matters

- **Steering**: Centroid interpolation steers toward broad concepts; local offsets refine within a concept. MFA often outperforms SAEs on causal steering benchmarks.
- **Interpretability**: MFA decompositions have ~96% interpretable feature mass (vs ~29% for SAEs).
- **Structure discovery**: Reveals that concepts organize into multi-Gaussian neighborhoods (e.g., an "emotions" cluster contains sub-Gaussians for happiness, surprise, anger).

### The Model

**Factor Analysis (single component):**
$$x = \mu + Wz + \epsilon, \quad z \sim \mathcal{N}(0, I_R), \quad \epsilon \sim \mathcal{N}(0, \Psi)$$

where $W \in \mathbb{R}^{D \times R}$ maps $R$ latent factors to $D$-dimensional activation space, and $\Psi$ is a diagonal noise covariance (an independent noise variance for each of the $D$ dimensions).
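To see what the generative model implies, one can sample from it and compare the empirical covariance with $WW^\top + \Psi$. A minimal sketch with arbitrary random parameters (nothing here is fitted to activations):

```python
import torch

torch.manual_seed(0)
D, R, N = 8, 2, 200_000

# Illustrative FA parameters (random, not fitted to anything)
mu = torch.randn(D)
W = torch.randn(D, R)
psi = torch.rand(D) + 0.1          # diagonal of Psi, kept strictly positive

# Ancestral sampling: z ~ N(0, I_R), eps ~ N(0, Psi), x = mu + W z + eps
z = torch.randn(N, R)
eps = torch.randn(N, D) * psi.sqrt()
x = mu + z @ W.T + eps

# The empirical covariance should approach C = W W^T + Psi as N grows
C_true = W @ W.T + torch.diag(psi)
C_emp = torch.cov(x.T)             # torch.cov expects (variables, observations)
max_err = (C_emp - C_true).abs().max().item()
print(f"max |C_emp - C_true| = {max_err:.3f}")
```

All correlation between dimensions comes from the shared low-rank term $WW^\top$; $\Psi$ only adds independent per-dimension variance.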

**Mixture of Factor Analyzers:**
$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, W_k W_k^\top + \Psi)$$

Each component $k$ has its own centroid $\mu_k$ and loading matrix $W_k$, but they share noise $\Psi$.
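Sampling from the mixture adds one step to the FA recipe: first draw a component index $k \sim \text{Categorical}(\pi)$, then sample from that component's factor analyzer. A minimal sketch with random, illustrative parameters:

```python
import torch

torch.manual_seed(0)
D, R, K, N = 8, 2, 3, 10_000

# Illustrative MFA parameters: per-component centroids and loadings, shared Psi
pi = torch.tensor([0.5, 0.3, 0.2])
mus = torch.randn(K, D) * 5.0       # well-separated centroids
Ws = torch.randn(K, D, R)
psi = torch.rand(D) + 0.1

# 1) pick a component k ~ Categorical(pi); 2) sample from that component's FA
ks = torch.multinomial(pi, N, replacement=True)      # (N,)
z = torch.randn(N, R)
eps = torch.randn(N, D) * psi.sqrt()                 # shared noise for every component
local = torch.einsum("ndr,nr->nd", Ws[ks], z)        # W_k z for each sample
x = mus[ks] + local + eps

counts = torch.bincount(ks, minlength=K)
print("fraction per component:", (counts.float() / N).tolist())
```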

**Decomposition** of an activation $x$:
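$$x = \underbrace{\mu_k}_{\text{centroid}} + \underbrace{W_k \hat{z}_k}_{\text{local offset}} + \underbrace{\epsilon}_{\text{noise}}$$

where $k = \arg\max_j R_j(x)$ is the most likely component under the posterior responsibilities $R_j(x) = p(j \mid x)$, $\hat{z}_k$ is the posterior mean of that component's latent, and $\epsilon = x - \mu_k - W_k \hat{z}_k$ is the part the first two terms do not explain.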

### Exercise Overview

You will:
1. **Extract activations** from a small LLM
2. **Implement Factor Analysis** (the single-component building block)
3. **Implement MFA** (mixture of FAs with K-means initialization)
4. **Decompose activations** into centroid + local offset + noise
5. **Interpret components**: inspect centroids and loading directions via logit lens
6. **Steer with MFA**: compare centroid intervention vs local offset intervention

Estimated time: ~45 minutes

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from jaxtyping import Float
from torch import Tensor
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: Extract Activations

We extract residual stream activations from an early-to-middle layer (one third of the way through the network) of SmolLM2-135M using `nnsight`. The paper uses Llama-3.1-8B and Gemma-2-2B with 100M activations; we use a much smaller model and a few tens of thousands of activations.

In [None]:
from nnsight import LanguageModel
from datasets import load_dataset

model_name = "HuggingFaceTB/SmolLM2-135M"
model = LanguageModel(model_name, device_map=device, dispatch=True)
tokenizer = model.tokenizer
tokenizer.pad_token = tokenizer.eos_token

n_layers = model.config.num_hidden_layers
d_model = model.config.hidden_size
target_layer = n_layers // 3  # Early-mid layer

print(f"Model: {model_name}")
print(f"Layers: {n_layers}, d_model: {d_model}")
print(f"Extracting from layer {target_layer}")

In [None]:
ds = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True)

all_activations = []
all_tokens = []
num_texts = 300
max_len = 128

texts = []
for i, example in enumerate(ds):
    if i >= num_texts:
        break
    texts.append(example["text"][:512])

print(f"Collected {len(texts)} texts, tokenizing...")

extract_batch_size = 20
for i in tqdm(range(0, len(texts), extract_batch_size), desc="Extracting activations"):
    batch_texts = texts[i:i + extract_batch_size]
    tokens = tokenizer(
        batch_texts, return_tensors="pt", padding=True,
        truncation=True, max_length=max_len
    )
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].to(device)

    with torch.no_grad():
        with model.trace(input_ids, attention_mask=attention_mask):
            hidden = model.model.layers[target_layer].output.save()

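    # Depending on the transformers/nnsight versions, a decoder layer returns either
    # the hidden-state tensor or a tuple whose first element is that tensor.
    if isinstance(hidden, tuple):
        hidden = hidden[0]
    acts = hidden.float().cpu()
    mask = attention_mask.cpu().bool()

    for b in range(acts.shape[0]):
        valid = mask[b].clone()
        # Skip position 0: the first token's residual stream often has an
        # outsized norm that would otherwise dominate the fit.
        valid[0] = False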
        if valid.sum() > 0:
            all_activations.append(acts[b, valid])
            all_tokens.append(input_ids[b, valid].cpu())

activations = torch.cat(all_activations, dim=0).float()
token_ids = torch.cat(all_tokens, dim=0)

print(f"\nCollected {activations.shape[0]:,} activations of dim {activations.shape[1]}")
print(f"Stats: mean={activations.mean():.3f}, std={activations.std():.3f}")

In [None]:
# Get the unembedding matrix for logit lens later
unembed = model.lm_head.weight.detach().cpu().float()  # (vocab_size, d_model)
# Gain (elementwise weight) of the final RMSNorm, applied inside logit_lens below
ln_weight = model.model.norm.weight.detach().cpu().float()

def logit_lens(
    direction: Float[Tensor, "dim"],
    top_k: int = 10,
) -> list[tuple[str, float]]:
    """Project a direction through the unembedding to see what tokens it promotes."""
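    # Apply only the RMSNorm gain. Skipping the 1/RMS rescaling multiplies all logits
    # by the same positive factor, so the top-k ranking is unchanged (values are unnormalized).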
    normed = direction * ln_weight
    logits = unembed @ normed  # (vocab_size,)
    topk = logits.topk(top_k)
    results = []
    for idx, val in zip(topk.indices, topk.values):
        tok = tokenizer.decode([idx.item()])
        results.append((tok, val.item()))
    return results

# Quick test
print("Logit lens on mean activation:")
for tok, val in logit_lens(activations.mean(dim=0), top_k=5):
    print(f"  {val:7.2f}  '{tok}'")

## Part 2: Implement Factor Analysis

Factor Analysis (FA) is the single-component building block. The generative model is:

$$x = \mu + Wz + \epsilon$$

where:
- $z \sim \mathcal{N}(0, I_R)$ are $R$ latent factors
- $\epsilon \sim \mathcal{N}(0, \Psi)$ with diagonal $\Psi$ (per-dimension noise)
- $W \in \mathbb{R}^{D \times R}$ is the loading matrix

The covariance of $x$ is $C = WW^\top + \Psi$.
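Evaluating $\log \mathcal{N}(x \mid \mu, C)$ directly requires factorizing the $D \times D$ matrix $C$. The `log_prob` exercise below instead uses the Woodbury identity and the matrix determinant lemma, which only need the $R \times R$ matrix $M = I_R + W^\top \Psi^{-1} W$. A quick numerical sanity check of both identities, with small random parameters (float64 keeps round-off out of the comparison):

```python
import torch

torch.manual_seed(0)
D, R = 12, 4
dt = torch.float64
W = torch.randn(D, R, dtype=dt)
psi = torch.rand(D, dtype=dt) + 0.1
psi_inv = 1.0 / psi

C = W @ W.T + torch.diag(psi)
M = torch.eye(R, dtype=dt) + W.T @ (psi_inv[:, None] * W)   # I_R + W^T Psi^{-1} W

# Woodbury: C^{-1} = Psi^{-1} - Psi^{-1} W M^{-1} W^T Psi^{-1}
PsiInvW = psi_inv[:, None] * W                               # Psi^{-1} W, shape (D, R)
C_inv_wood = torch.diag(psi_inv) - PsiInvW @ torch.linalg.solve(M, PsiInvW.T)

# Matrix determinant lemma: log|C| = log|M| + sum(log psi)
logdet_lemma = torch.logdet(M) + psi.log().sum()

print(torch.allclose(C_inv_wood, torch.linalg.inv(C)))     # expected: True
print(torch.allclose(logdet_lemma, torch.logdet(C)))        # expected: True
```

The only matrix solved or factorized in the Woodbury route is $M$, which is why the cost stays manageable even when $D$ is large and there are many components.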

**Key insight from the paper**: The meaningful object is the **subspace** $\text{span}(W)$, not the individual columns of $W$. The model is invariant to rotations of the latent space: replacing $W$ with $W' = WQ$ for any orthogonal $Q$ leaves $W'W'^\top = WW^\top$, and hence the likelihood, unchanged.
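A small numerical check of this invariance, using random illustrative parameters: rotating $W$ changes its columns but leaves both the covariance and the orthogonal projector onto $\text{span}(W)$ unchanged.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
D, R = 6, 3
W = torch.randn(D, R, dtype=dt)
psi = torch.rand(D, dtype=dt) + 0.1

# Random orthogonal Q from the QR decomposition of a Gaussian matrix
Q, _ = torch.linalg.qr(torch.randn(R, R, dtype=dt))
W_rot = W @ Q

C = W @ W.T + torch.diag(psi)
C_rot = W_rot @ W_rot.T + torch.diag(psi)
print(torch.allclose(C, C_rot))          # expected: True (same covariance)

# Projectors onto span(W) and span(W Q) coincide
P = W @ torch.linalg.pinv(W)
P_rot = W_rot @ torch.linalg.pinv(W_rot)
print(torch.allclose(P, P_rot))          # expected: True (same subspace)
print(torch.allclose(W, W_rot))          # expected: False (different columns)
```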

**Posterior of latent given observation**:
$$\hat{z} = (I_R + W^\top \Psi^{-1} W)^{-1} W^\top \Psi^{-1} (x - \mu)$$
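This form only inverts an $R \times R$ matrix. It agrees with the standard Gaussian-conditioning form $\hat{z} = W^\top C^{-1}(x - \mu)$, because $W^\top \Psi^{-1} C = (I_R + W^\top \Psi^{-1} W) W^\top$. A quick check with random illustrative parameters, computing both routes for a batch of row vectors:

```python
import torch

torch.manual_seed(0)
D, R, B = 10, 3, 5
dt = torch.float64
mu = torch.randn(D, dtype=dt)
W = torch.randn(D, R, dtype=dt)
psi = torch.rand(D, dtype=dt) + 0.1
x = torch.randn(B, D, dtype=dt)

dx = x - mu                                   # (B, D), one row per observation
psi_inv = 1.0 / psi

# R x R route: z_hat = M^{-1} W^T Psi^{-1} (x - mu)
M = torch.eye(R, dtype=dt) + W.T @ (psi_inv[:, None] * W)
z_small = torch.linalg.solve(M, W.T @ (psi_inv * dx).T).T      # (B, R)

# D x D route: z_hat = W^T C^{-1} (x - mu), with C = W W^T + Psi
C = W @ W.T + torch.diag(psi)
z_big = (W.T @ torch.linalg.solve(C, dx.T)).T                  # (B, R)

print(torch.allclose(z_small, z_big))  # expected: True
```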

In [None]:
class FactorAnalysis(nn.Module):
    """
    Factor Analysis: x = mu + W @ z + epsilon
    
    z ~ N(0, I_R), epsilon ~ N(0, diag(psi))
    Covariance: C = W @ W^T + diag(psi)
    """

    def __init__(self, d: int, rank: int):
        super().__init__()
        self.d = d
        self.rank = rank

        self.mu = nn.Parameter(torch.zeros(d))
        self.W = nn.Parameter(torch.randn(d, rank) * 0.01)
        # log_psi for positivity constraint
        self.log_psi = nn.Parameter(torch.zeros(d))

    @property
    def psi(self) -> Float[Tensor, "d"]:
        return self.log_psi.exp()

    def covariance(self) -> Float[Tensor, "d d"]:
        """C = W @ W^T + diag(psi)"""
        return self.W @ self.W.T + torch.diag(self.psi)

    def log_prob(
        self, x: Float[Tensor, "batch d"]
    ) -> Float[Tensor, "batch"]:
        """
        Log probability under this FA model: log N(x | mu, C).
        
        Use the Woodbury identity for efficient computation:
        C^{-1} = Psi^{-1} - Psi^{-1} W (I + W^T Psi^{-1} W)^{-1} W^T Psi^{-1}
        
        And the matrix determinant lemma:
        |C| = |I + W^T Psi^{-1} W| * |Psi|
        """
        # TODO: Compute the log probability efficiently using Woodbury identity
        # 
        # Steps:
        # 1. Compute psi_inv = 1 / psi  (element-wise, since Psi is diagonal)
        # 2. Compute M = I_R + W^T @ diag(psi_inv) @ W  (R x R matrix)
        # 3. Compute log_det = log|M| + sum(log_psi)  (matrix determinant lemma)
        # 4. Center data: dx = x - mu
        # 5. Compute C^{-1} @ dx using Woodbury. In (batch, d) row form, with M symmetric:
        #    C_inv_dx = psi_inv * dx - psi_inv * ((psi_inv * dx) @ W @ M^{-1} @ W^T)
        # 6. Mahalanobis: mahal = sum(dx * C_inv_dx, dim=-1)
        # 7. log_prob = -0.5 * (D * log(2pi) + log_det + mahal)
        
        pass

    def posterior_z(
        self, x: Float[Tensor, "batch d"]
    ) -> Float[Tensor, "batch rank"]:
        """
        Posterior mean of latent z given x.
        
        z_hat = (I_R + W^T Psi^{-1} W)^{-1} W^T Psi^{-1} (x - mu)
        """
        # TODO: Implement the posterior mean computation
        # 1. psi_inv = 1 / psi
        # 2. M = I_R + W^T @ diag(psi_inv) @ W
        # 3. dx = x - mu
        # 4. z_hat = M^{-1} @ W^T @ (psi_inv * dx)^T, transposed back to (batch, rank);
        #    equivalently (psi_inv * dx) @ W @ M^{-1}, since M is symmetric
        
        pass

In [None]:
# Test Factor Analysis
torch.manual_seed(42)
fa = FactorAnalysis(d=16, rank=3)
test_x = torch.randn(32, 16)

lp = fa.log_prob(test_x)
assert lp.shape == (32,), f"Wrong shape: {lp.shape}"
assert torch.isfinite(lp).all(), "Non-finite log probs"

z_hat = fa.posterior_z(test_x)
assert z_hat.shape == (32, 3), f"Wrong shape: {z_hat.shape}"

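# Posterior-mean reconstruction mu + W @ z_hat. Only the shape is checked: test_x is
# random noise rather than a sample from this FA, so it need not match closely.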
x_recon = fa.mu + z_hat @ fa.W.T
assert x_recon.shape == test_x.shape

print("✓ Factor Analysis tests passed")
print(f"  Log prob range: [{lp.min():.1f}, {lp.max():.1f}]")
print(f"  Posterior z range: [{z_hat.min():.3f}, {z_hat.max():.3f}]")

## Part 3: Implement Mixture of Factor Analyzers

MFA extends FA with $K$ components, each having its own centroid $\mu_k$ and loading matrix $W_k$, but sharing noise $\Psi$:

$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, W_k W_k^\top + \Psi)$$

**Responsibilities** (posterior component assignment):
$$R_k(x) = \frac{\pi_k \, \mathcal{N}(x \mid \mu_k, C_k)}{\sum_j \pi_j \, \mathcal{N}(x \mid \mu_j, C_j)}$$

The paper trains via gradient descent on negative log-likelihood (not EM), initialized with K-means.
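Both the mixture log-likelihood and the responsibilities should be computed in log space. Component log-densities for high-dimensional activations are large negative numbers, and exponentiating them underflows to zero. The sketch below uses made-up log-densities (not values from this model) to contrast the naive ratio with `torch.logsumexp` and `torch.softmax`:

```python
import torch

# Made-up log N(x | mu_k, C_k) for 2 inputs and K = 3 components; magnitudes like
# these are plausible once D is in the hundreds.
log_dens = torch.tensor([[-1500.0, -1502.0, -1510.0],
                         [-900.0, -899.0, -950.0]])
log_pi = torch.log(torch.tensor([0.5, 0.3, 0.2]))

# Naive: exp() underflows to 0 in float32, so the ratio is 0/0 = nan
naive = log_pi.exp() * log_dens.exp()
naive = naive / naive.sum(dim=-1, keepdim=True)
print(naive)                                  # all nan

# Log space: log p(x) = logsumexp_k(log pi_k + log N_k), R_k(x) = softmax over k
joint = log_pi + log_dens                     # (batch, K)
log_px = torch.logsumexp(joint, dim=-1)       # (batch,)
resp = torch.softmax(joint, dim=-1)           # equals exp(joint - log_px[:, None])
print(log_px)
print(resp.sum(dim=-1))                       # each row sums to 1
```

The training loss is then simply `-log_px.mean()`, which matches the NLL objective used for gradient descent.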

In [None]:
class MFA(nn.Module):
    """
    Mixture of Factor Analyzers.
    
    K components, each with centroid mu_k and loading W_k.
    Shared diagonal noise Psi.
    """

    def __init__(self, d: int, K: int, rank: int):
        super().__init__()
        self.d = d
        self.K = K
        self.rank = rank

        # Component parameters
        self.mus = nn.Parameter(torch.randn(K, d) * 0.1)       # centroids
        self.Ws = nn.Parameter(torch.randn(K, d, rank) * 0.01)  # loadings

        # Shared noise (log-scale for positivity)
        self.log_psi = nn.Parameter(torch.zeros(d))

        # Mixture weights (log-scale, softmax to get pi)
        self.log_pi = nn.Parameter(torch.zeros(K))

    @property
    def psi(self) -> Float[Tensor, "d"]:
        return self.log_psi.exp()

    @property
    def pi(self) -> Float[Tensor, "K"]:
        return F.softmax(self.log_pi, dim=0)

    def component_log_prob(
        self, x: Float[Tensor, "batch d"], k: int
    ) -> Float[Tensor, "batch"]:
        """
        Log probability under component k: log N(x | mu_k, W_k W_k^T + Psi).
        Uses Woodbury identity for efficiency.
        """
        # TODO: Same as FA.log_prob but using self.mus[k] and self.Ws[k]
        pass

    def log_prob(
        self, x: Float[Tensor, "batch d"]
    ) -> Float[Tensor, "batch"]:
        """
        Mixture log probability: log sum_k pi_k * N(x | mu_k, C_k)
        
        Use logsumexp for numerical stability:
        log p(x) = logsumexp_k(log pi_k + log N(x | mu_k, C_k))
        """
        # TODO: Compute log p(x) using logsumexp over components
        # 1. For each k, compute log_pi_k + component_log_prob(x, k)
        # 2. Stack into (batch, K) tensor
        # 3. logsumexp over K dimension
        
        pass

    def responsibilities(
        self, x: Float[Tensor, "batch d"]
    ) -> Float[Tensor, "batch K"]:
        """
        Posterior component probabilities: R_k(x) = p(k | x)
        
        R_k(x) = pi_k * N(x|mu_k,C_k) / sum_j pi_j * N(x|mu_j,C_j)
        
        In log space: log R_k = log pi_k + log N_k - logsumexp
        """
        # TODO: Compute responsibilities using softmax in log-space
        
        pass

    def posterior_z(
        self, x: Float[Tensor, "batch d"], k: int
    ) -> Float[Tensor, "batch rank"]:
        """
        Posterior mean of latent z for component k.
        z_hat_k = (I_R + W_k^T Psi^{-1} W_k)^{-1} W_k^T Psi^{-1} (x - mu_k)
        """
        # TODO: Implement (same formula as FA but with component-specific params)
        
        pass

    def decompose(
        self, x: Float[Tensor, "batch d"]
    ) -> dict:
        """
        Decompose activations into centroid + local offset + noise.
        
        For each x, find the most likely component k*, then:
          centroid = mu_{k*}
          local_offset = W_{k*} @ z_hat_{k*}
          noise = x - centroid - local_offset
        
        Returns dict with: assignments, centroids, local_offsets, noise, z_hat
        """
        # TODO: Implement the decomposition
        # 1. Get responsibilities -> hard assignment k* = argmax
        # 2. For each unique k*, compute posterior z and local offset
        # 3. Residual = x - mu_{k*} - W_{k*} @ z_hat
        
        pass

In [None]:
# Test MFA
torch.manual_seed(42)
test_mfa = MFA(d=16, K=4, rank=3)
test_x = torch.randn(32, 16)

lp = test_mfa.log_prob(test_x)
assert lp.shape == (32,), f"Wrong log_prob shape: {lp.shape}"

resp = test_mfa.responsibilities(test_x)
assert resp.shape == (32, 4), f"Wrong responsibilities shape: {resp.shape}"
assert torch.allclose(resp.sum(dim=-1), torch.ones(32), atol=1e-5), "Responsibilities must sum to 1"

z_hat = test_mfa.posterior_z(test_x, k=0)
assert z_hat.shape == (32, 3), f"Wrong posterior_z shape: {z_hat.shape}"

decomp = test_mfa.decompose(test_x)
assert decomp["assignments"].shape == (32,)
# Check reconstruction: centroid + local_offset + noise ≈ x
recon = decomp["centroids"] + decomp["local_offsets"] + decomp["noise"]
assert torch.allclose(recon, test_x, atol=1e-5), "Decomposition must reconstruct x"

print("✓ MFA tests passed")
print(f"  Component assignments: {decomp['assignments'].unique().tolist()}")
print(f"  Responsibilities entropy: {torch.special.entr(resp).sum(-1).mean():.3f}")  # entr(0) = 0, avoids 0 * log(0) = nan

## Part 4: Train the MFA

The paper initializes centroids with K-means, then trains all parameters via gradient descent on negative log-likelihood:

$$\mathcal{L}(\theta) = -\frac{1}{B} \sum_{i=1}^{B} \log \left( \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_i \mid \mu_k, C_k) \right)$$

We use a small $K$ here for tractability.

In [None]:
def kmeans_init(
    data: Float[Tensor, "n d"],
    K: int,
    n_iters: int = 30,
) -> Float[Tensor, "K d"]:
    """
    K-means clustering to initialize MFA centroids.
    Returns cluster centers.
    """
    n = data.shape[0]
    # Random initialization
    indices = torch.randperm(n)[:K]
    centers = data[indices].clone()

    for _ in range(n_iters):
        # Assign to nearest center
        dists = torch.cdist(data, centers)  # (n, K)
        assignments = dists.argmin(dim=-1)  # (n,)

        # Update centers
        new_centers = torch.zeros_like(centers)
        for k in range(K):
            mask = assignments == k
            if mask.sum() > 0:
                new_centers[k] = data[mask].mean(dim=0)
            else:
                new_centers[k] = data[torch.randint(n, (1,))]
        centers = new_centers

    counts = [(assignments == k).sum().item() for k in range(K)]
    print(f"K-means cluster sizes: {counts}")
    return centers

In [None]:
# Hyperparameters
K = 64          # Number of components (paper uses 1K-32K; we use fewer)
rank = 10       # Latent rank per component (same as paper)
batch_size = 512
lr = 1e-3
n_epochs = 40

# Initialize MFA with K-means centroids
torch.manual_seed(42)
print(f"Running K-means with K={K}...")
init_centers = kmeans_init(activations, K)

mfa = MFA(d=d_model, K=K, rank=rank).to(device)
with torch.no_grad():
    mfa.mus.copy_(init_centers.to(device))

n_params = sum(p.numel() for p in mfa.parameters())
print(f"\nMFA: {n_params:,} params ({K} components, rank {rank})")
print(f"  Centroids: {K} × {d_model} = {K * d_model:,}")
print(f"  Loadings:  {K} × {d_model} × {rank} = {K * d_model * rank:,}")

In [None]:
train_loader = DataLoader(
    TensorDataset(activations), batch_size=batch_size, shuffle=True, drop_last=True
)
optimizer = torch.optim.Adam(mfa.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

losses = []

for epoch in range(n_epochs):
    mfa.train()
    epoch_loss = 0.0
    n_batches = 0

    for (x_batch,) in train_loader:
        x_batch = x_batch.to(device)

        # TODO: Compute negative log-likelihood loss
        # loss = -mfa.log_prob(x_batch).mean()
        
        # TODO: Backprop and optimizer step
        
        epoch_loss += loss.item()
        n_batches += 1

    scheduler.step()
    avg_loss = epoch_loss / n_batches
    losses.append(avg_loss)

    if (epoch + 1) % 5 == 0 or epoch == 0:
        # Check component utilization
        with torch.no_grad():
            sample = activations[:2000].to(device)
            resp = mfa.responsibilities(sample)
            assignments = resp.argmax(dim=-1)
            n_used = assignments.unique().shape[0]
        print(f"Epoch {epoch+1:3d}/{n_epochs} | NLL: {avg_loss:.2f} | Components used: {n_used}/{K}")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 3.5))
ax.plot(losses, linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Negative Log-Likelihood")
ax.set_title("MFA Training")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.tight_layout()
plt.show()

## Part 5: Interpret Components via Logit Lens

The paper's key finding: MFA components are **naturally interpretable**.

- **Centroids** ($\mu_k$) encode broad thematic regions (e.g., "genres", "emotions", "sports")
- **Loading columns** ($W_k$) encode fine-grained distinctions within a region (e.g., within "genres": fantasy vs. thriller vs. sitcom)

We inspect components by projecting centroids and loading directions through the unembedding matrix (logit lens).

In [None]:
# Find the largest components (most assigned activations)
with torch.no_grad():
    all_resp = []
    for i in range(0, activations.shape[0], 2000):
        batch = activations[i:i+2000].to(device)
        all_resp.append(mfa.responsibilities(batch).cpu())
    all_resp = torch.cat(all_resp, dim=0)
    all_assignments = all_resp.argmax(dim=-1)

# Count assignments per component
counts = torch.zeros(K)
for k in range(K):
    counts[k] = (all_assignments == k).sum()

top_components = counts.argsort(descending=True)[:10]
print("Top 10 components by size:")
for rank_idx, k in enumerate(top_components):
    k = k.item()
    print(f"\n{'='*60}")
    print(f"Component {k} ({counts[k].int().item()} activations, {100*counts[k]/len(all_assignments):.1f}%)")
    
    # Centroid: what tokens does this region promote?
    centroid = mfa.mus[k].detach().cpu()
    print(f"  Centroid top tokens: ", end="")
    for tok, val in logit_lens(centroid, top_k=8):
        print(f"'{tok.strip()}'({val:.1f})", end="  ")
    print()
    
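    # Loading directions: what fine-grained distinctions?
    # Columns of W_k are only defined up to rotation, so inspect an orthonormal
    # basis of span(W_k), ordered by singular value, instead of raw columns.
    W_k = mfa.Ws[k].detach().cpu()  # (d, rank)
    U, S, _ = torch.linalg.svd(W_k, full_matrices=False)  # U: (d, rank)
    for r in range(min(3, rank)):  # top-3 directions of the loading subspace
        direction = U[:, r]
        # Show tokens promoted by +direction and -direction (SVD signs are arbitrary)
        pos_tokens = logit_lens(direction, top_k=5)
        neg_tokens = logit_lens(-direction, top_k=5)
        pos_str = ", ".join(f"'{t.strip()}'" for t, v in pos_tokens[:4])
        neg_str = ", ".join(f"'{t.strip()}'" for t, v in neg_tokens[:4])
        print(f"  Subspace dir {r} (σ={S[r]:.2f}): [{neg_str}] ←→ [{pos_str}]")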

In [None]:
# Show actual tokens assigned to a few components
print("Tokens assigned to top components:\n")
for k in top_components[:5]:
    k = k.item()
    mask = all_assignments == k
    component_tokens = token_ids[mask]
    # Sample up to 30 tokens
    sample_idx = torch.randperm(component_tokens.shape[0])[:30]
    decoded = [tokenizer.decode([t.item()]).strip() for t in component_tokens[sample_idx]]
    # Show unique tokens
    unique_toks = list(dict.fromkeys(decoded))[:20]
    print(f"Component {k} ({counts[k].int().item()} tokens):")
    print(f"  {', '.join(repr(t) for t in unique_toks)}")
    print()

## Part 6: Decompose Activations

Every activation decomposes into three parts:

$$x = \underbrace{\mu_k}_{\text{centroid}} + \underbrace{W_k \hat{z}_k}_{\text{local offset}} + \underbrace{\epsilon}_{\text{residual noise}}$$

The paper shows this decomposition has a simple 2-segment trajectory in PCA space (centroid, then local refinement), unlike SAEs which accumulate many small features.

The paper also measures **Interpretability Fraction (IF)**: what fraction of the reconstruction magnitude comes from interpretable features. MFA achieves IF ≈ 0.96 vs SAE ≈ 0.29.

In [None]:
# Decompose a batch of activations
with torch.no_grad():
    test_x = activations[:2000].to(device)
    decomp = mfa.decompose(test_x)

centroids = decomp["centroids"].cpu()
local_offsets = decomp["local_offsets"].cpu()
noise = decomp["noise"].cpu()
x_cpu = test_x.cpu()

# Verify reconstruction
recon = centroids + local_offsets + noise
assert torch.allclose(recon, x_cpu, atol=1e-4), "Decomposition must reconstruct x!"

# Measure the size of each part. This is a rough proxy, not the paper's IF metric:
# vector norms are not additive, so the percentages below need not sum to 100%.
centroid_norm = centroids.norm(dim=-1).mean()
offset_norm = local_offsets.norm(dim=-1).mean()
noise_norm = noise.norm(dim=-1).mean()
total_norm = x_cpu.norm(dim=-1).mean()

print("Decomposition magnitude analysis (mean L2 norms):")
print(f"  ||x||            = {total_norm:.3f}")
print(f"  ||centroid||     = {centroid_norm:.3f} ({100*centroid_norm/total_norm:.1f}%)")
print(f"  ||local_offset|| = {offset_norm:.3f} ({100*offset_norm/total_norm:.1f}%)")
print(f"  ||noise||        = {noise_norm:.3f} ({100*noise_norm/total_norm:.1f}%)")
print(f"\n(||centroid|| + ||local_offset||) / ||x||: {100*(centroid_norm + offset_norm)/total_norm:.1f}%")

In [None]:
# Visualize decomposition trajectories in PCA space
# Paper Figure 4: MFA has a simple 2-segment path (origin -> centroid -> x)
from sklearn.decomposition import PCA

n_show = 200
pca = PCA(n_components=2).fit(x_cpu[:1000].numpy())

origin_2d = pca.transform(np.zeros((1, d_model)))
centroids_2d = pca.transform(centroids[:n_show].numpy())
full_2d = pca.transform(x_cpu[:n_show].numpy())

fig, ax = plt.subplots(1, 1, figsize=(7, 5))

# Draw trajectories: origin -> centroid -> full activation
for i in range(n_show):
    # Segment 1: origin to centroid (broad region)
    ax.plot(
        [origin_2d[0, 0], centroids_2d[i, 0]],
        [origin_2d[0, 1], centroids_2d[i, 1]],
        c="steelblue", alpha=0.1, linewidth=0.5
    )
    # Segment 2: centroid to full (local refinement)
    ax.plot(
        [centroids_2d[i, 0], full_2d[i, 0]],
        [centroids_2d[i, 1], full_2d[i, 1]],
        c="coral", alpha=0.15, linewidth=0.5
    )

ax.scatter(*origin_2d.T, c="black", s=100, zorder=5, marker="*", label="Origin")
ax.scatter(centroids_2d[:, 0], centroids_2d[:, 1], c="steelblue", s=5, alpha=0.4, label="Centroids")
ax.scatter(full_2d[:, 0], full_2d[:, 1], c="coral", s=5, alpha=0.4, label="Activations")

ax.legend(frameon=False)
ax.set_title("MFA Decomposition Trajectories in PCA Space")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.tight_layout()
plt.show()

print("Blue:  origin → centroid (which region?)")
print("Coral: centroid → activation (local variation within region)")

## Part 7: Steering with MFA

The paper defines two types of interventions:

**Centroid intervention** (steer toward a region):
$$f_\mu(x) = (1 - \alpha) \cdot x + \alpha \cdot \mu_k$$

This interpolates the activation toward the centroid of component $k$.
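Writing the intervention relative to the centroid makes its effect explicit:

$$f_\mu(x) - \mu_k = (1 - \alpha)\,(x - \mu_k), \qquad \hat{z}_k\big(f_\mu(x)\big) = (1 - \alpha)\,\hat{z}_k(x)$$

The second identity holds because $\hat{z}_k$ is linear in $x - \mu_k$. Centroid steering therefore shrinks an activation's displacement from $\mu_k$, including its component-$k$ local offset, uniformly by a factor $(1-\alpha)$; at $\alpha = 1$ every activation is mapped exactly onto $\mu_k$.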

**Local offset intervention** (refine within a region):
$$f_w(x) = x + W_k v$$

This adds a displacement along the component's loading directions.

We'll demonstrate both and show that they produce different effects.
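One way to see why the local offset intervention "refines within a region" is to ask how it moves the posterior latent. Since $\hat{z}_k$ is linear in $x$, adding $W_k v$ shifts it by $M_k^{-1} W_k^\top \Psi^{-1} W_k v = (I_R - M_k^{-1})\,v$, where $M_k = I_R + W_k^\top \Psi^{-1} W_k$. The displacement stays in $\text{span}(W_k)$, and when the loadings dominate the noise, $M_k^{-1}$ is small and the latent moves by almost exactly $v$. A numerical check with random illustrative parameters (the helper `posterior_z` here is a standalone function, not the `MFA` method):

```python
import torch

torch.manual_seed(0)
D, R, B = 16, 3, 4
dt = torch.float64
mu = torch.randn(D, dtype=dt)
W = torch.randn(D, R, dtype=dt) * 3.0          # strong loadings relative to the noise
psi = torch.rand(D, dtype=dt) * 0.1 + 0.05
x = torch.randn(B, D, dtype=dt)
v = torch.tensor([2.0, 0.0, -1.0], dtype=dt)

M = torch.eye(R, dtype=dt) + W.T @ (W / psi[:, None])      # I_R + W^T Psi^{-1} W

def posterior_z(x):
    # z_hat = M^{-1} W^T Psi^{-1} (x - mu), computed for each row of x
    return torch.linalg.solve(M, W.T @ ((x - mu) / psi).T).T

shift = posterior_z(x + W @ v) - posterior_z(x)             # (B, R), identical rows
predicted = (torch.eye(R, dtype=dt) - torch.linalg.inv(M)) @ v

print(torch.allclose(shift, predicted.expand_as(shift)))   # expected: True
print(shift[0], v)                                          # shift is close to v
```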

In [None]:
def centroid_steer(
    x: Float[Tensor, "batch d"],
    centroid: Float[Tensor, "d"],
    alpha: float,
) -> Float[Tensor, "batch d"]:
    """
    Centroid intervention: interpolate toward a component's centroid.
    f(x) = (1 - alpha) * x + alpha * mu_k
    """
    # TODO: Implement centroid interpolation
    pass


def local_offset_steer(
    x: Float[Tensor, "batch d"],
    W_k: Float[Tensor, "d rank"],
    v: Float[Tensor, "rank"],
    alpha: float = 1.0,
) -> Float[Tensor, "batch d"]:
    """
    Local offset intervention: add displacement along loading directions.
    f(x) = x + alpha * W_k @ v
    """
    # TODO: Implement local offset steering
    pass

In [None]:
# Pick an interesting component to steer toward
# Use one of the top components
target_k = top_components[0].item()

centroid_k = mfa.mus[target_k].detach().cpu()
W_k = mfa.Ws[target_k].detach().cpu()

print(f"Steering toward component {target_k}")
print(f"Centroid promotes: ", end="")
for tok, val in logit_lens(centroid_k, top_k=8):
    print(f"'{tok.strip()}'({val:.1f})", end="  ")
print()

# Take a batch of activations NOT assigned to this component
other_mask = all_assignments != target_k
other_acts = activations[other_mask][:500]

# Centroid steering at different strengths
print(f"\n--- Centroid Steering (interpolate toward mu_{target_k}) ---")
alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
for alpha in alphas:
    steered = centroid_steer(other_acts, centroid_k, alpha)
    mean_steered = steered.mean(dim=0)
    top_toks = logit_lens(mean_steered, top_k=5)
    tok_str = ", ".join(f"'{t.strip()}'" for t, v in top_toks)
    print(f"  α={alpha:.1f}: {tok_str}")

# Local offset steering along first loading direction
print(f"\n--- Local Offset Steering (along loading 0 of component {target_k}) ---")
# Get activations that ARE assigned to this component
in_mask = all_assignments == target_k
in_acts = activations[in_mask][:500]

strengths = [-3.0, -1.0, 0.0, 1.0, 3.0]
for s in strengths:
    v = torch.zeros(rank)
    v[0] = s  # push along first loading direction
    steered = local_offset_steer(in_acts, W_k, v)
    mean_steered = steered.mean(dim=0)
    top_toks = logit_lens(mean_steered, top_k=5)
    tok_str = ", ".join(f"'{t.strip()}'" for t, v in top_toks)
    print(f"  v[0]={s:+.1f}: {tok_str}")

In [None]:
# Compare: centroid steering vs naive direction steering (DiffMean baseline)
# DiffMean: steer along the difference in means between component k and everything else

in_mean = activations[in_mask].mean(dim=0)
out_mean = activations[other_mask].mean(dim=0)
diff_mean_dir = in_mean - out_mean
diff_mean_dir = diff_mean_dir / diff_mean_dir.norm()

# Measure how "on-manifold" steered activations are
# by computing their log-likelihood under the MFA
test_acts = other_acts[:200].to(device)

print("Steering comparison: MFA centroid vs DiffMean direction")
print(f"{'Alpha':<8} {'Centroid NLL':>14} {'DiffMean NLL':>14}")
print("-" * 40)

for alpha in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:
    with torch.no_grad():
        # Centroid steering
        steered_c = centroid_steer(test_acts, centroid_k.to(device), alpha)
        nll_c = -mfa.log_prob(steered_c).mean().item()

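        # DiffMean steering: shift along the unit DiffMean direction by alpha times the
        # distance from the batch mean to the centroid, matching the size of the mean
        # shift that centroid steering produces at the same alpha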
        scale = alpha * (centroid_k.to(device) - test_acts.mean(dim=0)).norm()
        steered_d = test_acts + scale * diff_mean_dir.to(device)
        nll_d = -mfa.log_prob(steered_d).mean().item()

    print(f"{alpha:<8.1f} {nll_c:>14.2f} {nll_d:>14.2f}")

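print("\nLower NLL = more on-manifold. We expect centroid steering to stay more on-manifold")
print("because it pulls activations toward a high-density point of the model rather than")
print("along a fixed global direction. Caveat: the MFA is also the judge here, which favours its own centroids.")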

## Part 8: Multi-Gaussian Concept Neighborhoods

The paper finds that concepts aren't captured by single components but by **neighborhoods of nearby Gaussians**. For example, "emotions" might be a cluster of sub-Gaussians for happiness, surprise, anger, etc.

We verify this by finding which components are neighbors in centroid space.

In [None]:
# Compute pairwise distances between centroids
with torch.no_grad():
    centroids_all = mfa.mus.detach().cpu()
    centroid_dists = torch.cdist(centroids_all, centroids_all)

# For each of top-5 components, show their nearest neighbors
print("Component neighborhoods (nearest centroids in activation space):\n")
for k in top_components[:5]:
    k = k.item()
    dists_k = centroid_dists[k]
    nearest = dists_k.argsort()[1:6]  # skip self

    # Parent centroid
    parent_toks = logit_lens(centroids_all[k], top_k=5)
    parent_str = ", ".join(f"'{t.strip()}'" for t, v in parent_toks)
    print(f"Component {k}: {parent_str}")

    for neighbor in nearest:
        n = neighbor.item()
        d = dists_k[n].item()
        neighbor_toks = logit_lens(centroids_all[n], top_k=5)
        neighbor_str = ", ".join(f"'{t.strip()}'" for t, v in neighbor_toks)
        print(f"  → {n} (dist={d:.2f}): {neighbor_str}")
    print()

In [None]:
# Visualize component structure: PCA of centroids, colored by neighborhood
from sklearn.decomposition import PCA

pca_centroids = PCA(n_components=2).fit_transform(centroids_all.numpy())

# Color by cluster (use hierarchical clustering on centroids)
from scipy.cluster.hierarchy import fcluster, linkage

Z = linkage(centroids_all.numpy(), method="ward")
cluster_labels = fcluster(Z, t=8, criterion="maxclust")

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
scatter = ax.scatter(
    pca_centroids[:, 0], pca_centroids[:, 1],
    c=cluster_labels, cmap="tab10",
    s=counts.numpy() / counts.max().item() * 200 + 10,  # size by count
    alpha=0.7, edgecolors="white", linewidth=0.5
)

# Annotate top components
for k in top_components[:8]:
    k = k.item()
    top_tok = logit_lens(centroids_all[k], top_k=1)[0][0].strip()
    ax.annotate(
        f"{k}:'{top_tok}'", (pca_centroids[k, 0], pca_centroids[k, 1]),
        fontsize=7, ha="center", va="bottom",
        bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8, edgecolor="none")
    )

ax.set_title(f"MFA Centroids in PCA Space ({K} components)")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.tight_layout()
plt.show()

print("Bubble size ∝ number of assigned activations.")
print("Color = hierarchical cluster of centroids.")
print("Nearby centroids with same color = concept neighborhoods.")

## Summary

You've implemented the core components of the MFA paper:

1. **Factor Analysis** — the generative model $x = \mu + Wz + \epsilon$ with Woodbury-efficient log-likelihood
2. **Mixture of Factor Analyzers** — K components with centroids, loadings, shared noise
3. **Component assignment** — responsibilities $R_k(x)$ via Bayes' rule
4. **Activation decomposition** — centroid + local offset + noise (3-part split)
5. **Logit lens interpretation** — centroids encode broad themes, loadings encode fine distinctions
6. **Two types of steering** — centroid interpolation (change region) vs local offset (refine within region)
7. **Concept neighborhoods** — related Gaussians cluster together in centroid space

### Key takeaways from the paper:

- **Beyond single directions**: Concepts have nonlinear, multi-cluster structure. MFA captures this; linear probes and SAEs don't.
- **Interpretability fraction**: MFA decompositions are ~96% interpretable (vs ~29% for SAEs) because the centroid + offset structure is inherently meaningful.
- **Competitive steering**: MFA often outperforms SAEs on causal steering benchmarks, especially for broad concepts.
- **Scaling**: The paper uses K=1K to 32K components. More components split broad Gaussians into finer sub-concepts rather than discovering entirely new structure.

### Things we didn't cover:
- RAVEL/MCQA localization benchmarks with DBM
- Full causal steering evaluation with LLM-as-judge
- Comparison at full scale (100M activations, 32K components)
- Narrow vs broad Gaussian classification
- The SAE decomposition trajectory comparison (Figure 4)