# TransformerCSSM Interpretability

This notebook analyzes the **TransformerCSSM** - a simplified, transformer-inspired variant of CSSM that makes the attention analogy explicit.

**Model Performance:** 88.50% accuracy on Pathfinder-14 (64-dim model)

**Contents:**
1. Mathematical foundations - Q/K/A dynamics
2. Comparison with HGRUBilinearCSSM
3. Sequential vs Parallel execution
4. Temporal gradient attribution
5. Mechanism attribution

**Key Insight:** TransformerCSSM makes the attention mechanism explicit:
- **Q (Query)** ≈ current representation seeking context
- **K (Key)** ≈ stored information to match against
- **A (Attention)** ≈ accumulated Q-K correlation that feeds back into Q

---
## Setup (Run First)

In [None]:
#@title Setup: Install dependencies and download checkpoint { display-mode: "form" }
import os
import sys

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install JAX with GPU support
    !pip install -q jax[cuda12] flax optax tensorflow
    
    # Clone CSSM repo
    if not os.path.exists('CSSM'):
        !git clone https://github.com/your-repo/CSSM.git
        sys.path.insert(0, 'CSSM')
    
    # Download checkpoint (--no-check-certificate for expired SSL cert)
    CHECKPOINT_URL = "https://connectomics.clps.brown.edu/tf_records/transformer_cssm_kqv64_epoch20.pkl"
    CHECKPOINT_PATH = "transformer_cssm_checkpoint.pkl"
    
    if not os.path.exists(CHECKPOINT_PATH):
        print(f"Downloading checkpoint from {CHECKPOINT_URL}...")
        !wget -q --no-check-certificate {CHECKPOINT_URL} -O {CHECKPOINT_PATH}
        print("Download complete!")
else:
    # Local paths
    CHECKPOINT_PATH = "checkpoints/KQV_64/epoch_20/checkpoint.pkl"

print(f"Checkpoint path: {CHECKPOINT_PATH}")
print(f"Running in {'Colab' if IN_COLAB else 'local'} environment")

---
# Part I: Mathematical Foundations
---

## 1. The TransformerCSSM Equations

### 1.1 State Variables

TransformerCSSM has three states with explicit transformer naming:

$$\mathbf{s}_t = (Q_t, K_t, A_t) \in \mathbb{R}^{H \times W \times C}$$

| State | Name | Role | Transformer Analogy |
|-------|------|------|--------------------|
| $Q_t$ | Query | Current representation | Queries seeking relevant info |
| $K_t$ | Key | Symmetric partner to Q | Keys to match against |
| $A_t$ | Attention | Accumulated Q-K history | Attention weights over time |

### 1.2 Dynamics

$$\boxed{Q_{t+1} = \underbrace{\lambda_Q \cdot Q_t}_{\text{decay}} + \underbrace{w \cdot (\mathcal{K} * K_t)}_{\text{K input via kernel}} + \underbrace{\alpha \cdot (\mathcal{K} * A_t)}_{\text{attention feedback}} + U_Q}$$

$$\boxed{K_{t+1} = \underbrace{w \cdot (\mathcal{K} * Q_t)}_{\text{Q input via kernel}} + \underbrace{\lambda_K \cdot K_t}_{\text{decay}} + U_K}$$

$$\boxed{A_{t+1} = \underbrace{\gamma \cdot Q_t}_{\text{Q accumulation}} + \underbrace{\gamma \cdot K_t}_{\text{K accumulation}} + \underbrace{\lambda_A \cdot A_t}_{\text{decay}} + U_A}$$

where:
- $\mathcal{K}$ is a **single learned spatial kernel** (not two like HGRUBilinearCSSM)
- $*$ denotes spatial convolution (via FFT)
- $w$ is the **Q↔K coupling weight** (symmetric!)
- $\alpha$ is the **attention feedback strength**
- $\gamma$ is the **attention accumulation rate**
- $\lambda_Q, \lambda_K, \lambda_A \in (0.1, 0.99)$ are decay rates

### 1.3 Matrix Form

$$\begin{bmatrix} Q \\ K \\ A \end{bmatrix}_{t+1} = 
\begin{bmatrix} 
\lambda_Q & w \cdot \mathcal{K} & \alpha \cdot \mathcal{K} \\
w \cdot \mathcal{K} & \lambda_K & 0 \\
\gamma & \gamma & \lambda_A
\end{bmatrix}
\begin{bmatrix} Q \\ K \\ A \end{bmatrix}_t + 
\begin{bmatrix} U_Q \\ U_K \\ U_A \end{bmatrix}$$

**Key observations:**
1. **Symmetric Q↔K coupling**: Both use $w \cdot \mathcal{K}$ (same weight)
2. **A is pure memory**: No kernel in A's row - just scalar accumulation
3. **A feeds back to Q only**: The $\alpha \cdot \mathcal{K} \cdot A$ term

### 1.4 The Attention Interpretation

Think of the dynamics as an **iterative attention mechanism**:

1. **Q and K interact symmetrically** via the spatial kernel $\mathcal{K}$:
   - Q receives K through the kernel (sees K's spatial context)
   - K receives Q through the kernel (symmetric)
   
2. **A accumulates Q-K correlation**:
   - $A \leftarrow \gamma (Q + K) + \lambda_A A$
   - This is like building up attention weights over time
   
3. **A feeds back into Q**:
   - $Q \leftarrow ... + \alpha \cdot \mathcal{K} * A$
   - The accumulated attention modulates Q (like applying attention weights)

**As receptive fields grow** (kernel compounds over timesteps), A accumulates Q-K correlations across increasingly large spatial regions - achieving attention-like global context integration with O(T) or O(log T) complexity instead of O(T²).

In [None]:
# Visualize the equations
print("="*75)
print("TransformerCSSM STATE UPDATE EQUATIONS")
print("="*75)
print()
print("Q_{t+1} = λ_Q·Q_t + w·(K * K_t) + α·(K * A_t) + U_Q")
print("         ────────   ───────────   ───────────")
print("         decay      K input       attention feedback")
print("                    (via kernel)  (A modulates Q!)")
print()
print("K_{t+1} = w·(K * Q_t) + λ_K·K_t + U_K")
print("         ───────────   ────────")
print("         Q input       decay")
print("         (symmetric!)")
print()
print("A_{t+1} = γ·Q_t + γ·K_t + λ_A·A_t + U_A")
print("         ──────   ──────   ────────")
print("         accumulate Q+K   decay (pure memory, NO kernel)")
print()
print("="*75)
print("K = single spatial kernel (applied via FFT convolution)")
print("w = Q↔K coupling weight (same for both directions)")
print("α = attention feedback strength (A → Q)")
print("γ = attention accumulation rate (Q,K → A)")
print("="*75)

## 2. Comparison: TransformerCSSM vs HGRUBilinearCSSM

| Aspect | HGRUBilinearCSSM | TransformerCSSM |
|--------|------------------|------------------|
| **States** | X (excit), Y (inhib), Z (interact) | Q (query), K (key), A (attention) |
| **Kernels** | 2: K_E (excit), K_I (inhib) | 1: K (shared) |
| **Coupling** | Asymmetric (α_E, α_I, μ_E, μ_I) | Symmetric (single w) |
| **Bilinear term** | X * Z (multiplicative) | α·K·A (gated feedback) |
| **Memory state** | Z has kernel: δ·(X-Y) | A is pure memory: γ·(Q+K) |
| **Gates** | ~13 | ~10 |
| **Accuracy (PF14)** | 85.50% (32-dim) | **88.50%** (64-dim) |

### Key Simplifications

1. **Single kernel** instead of E/I pair → Simpler, but loses explicit E-I dynamics
2. **Symmetric coupling** → Cleaner Q↔K relationship
3. **Pure memory A** → No spatial kernel on A, just accumulates Q+K
4. **Linear A feedback** → $\alpha \cdot K * A$ instead of $\gamma \cdot X \odot Z$ (bilinear)

The TransformerCSSM with 64 dimensions **outperforms** the 32-dim HGRUBilinearCSSM, suggesting the simpler attention-like structure scales well with increased capacity.

In [None]:
print("Matrix comparison:")
print()
print("HGRUBilinearCSSM:                    TransformerCSSM:")
print("┌                              ┐    ┌                        ┐")
print("│ λ_x + α_E·K_E + γ·Z  -α_I·K_I  0 │    │ λ_Q      w·K      α·K │")
print("│ μ_E            λ_y - μ_I   0 │    │ w·K      λ_K       0   │")
print("│ δ                  -δ     λ_z │    │ γ        γ        λ_A │")
print("└                              ┘    └                        ┘")
print()
print("Key differences:")
print("  • HGRUBi has γ·Z in diagonal (state-dependent, bilinear)")
print("  • Transformer has single kernel K (simpler)")
print("  • Transformer has symmetric w (same for Q→K and K→Q)")
print("  • Transformer A row has NO kernel (pure scalar memory)")

---
# Part II: Sequential vs Parallel Execution
---

## 3. Implementation Comparison

TransformerCSSM uses the **same associative scan** as HGRUBilinearCSSM, but with a cleaner matrix structure.

In [None]:
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from typing import Dict, Tuple

def spectral_conv_2d(x: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    """FFT-based 2D convolution."""
    B, H, W, C = x.shape
    k = kernel.shape[1]
    
    pad_h, pad_w = (H - k) // 2, (W - k) // 2
    kernel_padded = jnp.pad(kernel, ((0, 0), (pad_h, H - k - pad_h), (pad_w, W - k - pad_w)))
    
    X_fft = jnp.fft.rfft2(x, axes=(1, 2))
    K_fft = jnp.fft.rfft2(kernel_padded, axes=(1, 2))
    K_fft = jnp.moveaxis(K_fft, 0, -1)
    
    result_fft = X_fft * K_fft[None, ...]
    return jnp.fft.irfft2(result_fft, s=(H, W), axes=(1, 2))

In [None]:
def transformer_cssm_step_sequential(
    Q: jnp.ndarray, K_state: jnp.ndarray, A: jnp.ndarray,
    U_Q: jnp.ndarray, U_K: jnp.ndarray, U_A: jnp.ndarray,
    K_kernel: jnp.ndarray,
    lambda_Q: float, lambda_K: float, lambda_A: float,
    w: float, alpha: float, gamma: float,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Single TransformerCSSM timestep (SEQUENTIAL version for interpretability).
    
    Implements:
        Q_{t+1} = λ_Q·Q + w·(K * K_state) + α·(K * A) + U_Q
        K_{t+1} = w·(K * Q) + λ_K·K_state + U_K
        A_{t+1} = γ·Q + γ·K_state + λ_A·A + U_A
    """
    # Spatial convolutions through the SINGLE kernel
    K_conv_Q = spectral_conv_2d(Q, K_kernel)        # K * Q
    K_conv_K = spectral_conv_2d(K_state, K_kernel)  # K * K_state
    K_conv_A = spectral_conv_2d(A, K_kernel)        # K * A
    
    # State updates
    Q_new = lambda_Q * Q + w * K_conv_K + alpha * K_conv_A + U_Q
    K_new = w * K_conv_Q + lambda_K * K_state + U_K
    A_new = gamma * Q + gamma * K_state + lambda_A * A + U_A
    
    return Q_new, K_new, A_new


def transformer_cssm_forward_sequential(
    U: jnp.ndarray,
    K_kernel: jnp.ndarray,
    gates: Dict[str, float],
    T: int = 8,
) -> Tuple[list, list, list]:
    """
    Full sequential forward pass.
    
    Complexity: O(T) - must process each timestep in sequence.
    """
    B, H, W, C = U.shape
    
    # Initialize states to zero
    Q = jnp.zeros((B, H, W, C))
    K_state = jnp.zeros((B, H, W, C))
    A = jnp.zeros((B, H, W, C))
    
    Q_history, K_history, A_history = [Q], [K_state], [A]
    
    for t in range(T):
        Q, K_state, A = transformer_cssm_step_sequential(
            Q, K_state, A, U, U, U, K_kernel,
            gates['lambda_Q'], gates['lambda_K'], gates['lambda_A'],
            gates['w'], gates['alpha'], gates['gamma'],
        )
        Q_history.append(Q)
        K_history.append(K_state)
        A_history.append(A)
    
    return Q_history, K_history, A_history

print("Sequential implementation defined.")
print("  - transformer_cssm_step_sequential(): Single timestep")
print("  - transformer_cssm_forward_sequential(): O(T) forward pass")

In [None]:
print("="*75)
print("SEQUENTIAL vs PARALLEL COMPUTATION")
print("="*75)
print()
print("SEQUENTIAL (for loop):")
print("  for t in range(T):")
print("      Q[t+1] = λ_Q·Q[t] + w·(K * K[t]) + α·(K * A[t]) + U")
print("      K[t+1] = w·(K * Q[t]) + λ_K·K[t] + U")
print("      A[t+1] = γ·Q[t] + γ·K[t] + λ_A·A[t] + U")
print()
print("  Time: O(T), Space: O(1)")
print()
print("-"*75)
print()
print("PARALLEL (associative scan):")
print("  # Rewrite as: s[t+1] = A_mat[t] · s[t] + b[t]")
print("  # where s = [Q, K, A]^T")
print()
print("  A_mat = [[λ_Q,  w·K,  α·K],")
print("           [w·K,  λ_K,   0 ],")
print("           [ γ,    γ,   λ_A]]")
print()
print("  # Use associative scan with composition:")
print("  # (A₂, b₂) ∘ (A₁, b₁) = (A₂·A₁, A₂·b₁ + b₂)")
print()
print("  Time: O(log T), Space: O(T)")
print()
print("="*75)

---
# Part III: Hands-On Analysis
---

## 4. Load Model and Data

In [None]:
import pickle
import matplotlib.pyplot as plt

# Load checkpoint (uses CHECKPOINT_PATH from setup cell)
with open(CHECKPOINT_PATH, 'rb') as f:
    ckpt = pickle.load(f)

params = ckpt['params']
print(f"Loaded checkpoint from epoch {ckpt.get('epoch', 'unknown')}")
print(f"CSSM params: {list(params['cssm_0'].keys())}")

In [None]:
# Extract TransformerCSSM parameters
cssm_params = params['cssm_0']

# Single spatial kernel (NOT two like HGRUBi)
K_kernel = jnp.array(cssm_params['kernel'])  # (C, k, k)
print(f"Spatial kernel shape: {K_kernel.shape}")

# Extract gate values - handle Dense layer dict structure
def get_gate(name):
    """Extract gate value, handling Dense layer dict format."""
    gate_data = cssm_params[name]
    # TransformerCSSM stores gates as Dense layers: {'kernel': ..., 'bias': ...}
    if isinstance(gate_data, dict):
        kernel = gate_data['kernel']
        bias = gate_data.get('bias', 0)
        val = kernel.mean() + (bias.mean() if hasattr(bias, 'mean') else bias)
    else:
        val = gate_data.mean()
    return float(jax.nn.sigmoid(val))

gates = {
    'lambda_Q': get_gate('decay_Q'),
    'lambda_K': get_gate('decay_K'),
    'lambda_A': get_gate('decay_A'),
    'w': get_gate('w_qk'),
    'alpha': get_gate('alpha'),
    'gamma': get_gate('gamma'),
}

print("\nExtracted gate values:")
for name, val in gates.items():
    print(f"  {name:>10}: {val:.4f}")

In [None]:
# Visualize the single kernel (64 channels for KQV_64 model)
fig, axes = plt.subplots(4, 8, figsize=(16, 8))

K_np = np.array(K_kernel)
vmax = np.abs(K_np).max()

for row in range(4):
    for col in range(8):
        idx = row * 8 + col
        if idx < K_np.shape[0]:
            axes[row, col].imshow(K_np[idx], cmap='RdBu_r', vmin=-vmax, vmax=vmax)
            axes[row, col].set_title(f'K[{idx}]', fontsize=8)
        axes[row, col].axis('off')

plt.suptitle('TransformerCSSM: Single Spatial Kernel (64 channels, shared for Q↔K and A→Q)', fontsize=14)
plt.tight_layout()
plt.show()

# Mean kernel
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(K_np.mean(axis=0), cmap='RdBu_r')
ax.set_title('Mean Kernel (averaged over 64 channels)', fontsize=12)
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)
plt.tight_layout()
plt.show()

In [None]:
# Load model for gradient computation
from src.models.simple_cssm import SimpleCSSM

model = SimpleCSSM(
    num_classes=2,
    embed_dim=64,  # 64-dim model (KQV_64)
    depth=1,
    cssm_type='transformer',  # TransformerCSSM!
    kernel_size=15,
    pos_embed='spatiotemporal',
    seq_len=8,
)
print("TransformerCSSM model loaded (embed_dim=64).")

In [None]:
# Load sample images
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')

TFRECORD_DIR = '/home/dlinsley/pathfinder_tfrecord/difficulty_14/val'

def parse_example(example):
    features = tf.io.parse_single_example(example, {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    })
    image = tf.io.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [224, 224, 3])
    return image, features['label']

val_files = sorted(tf.io.gfile.glob(f'{TFRECORD_DIR}/*.tfrecord'))
ds = tf.data.TFRecordDataset(val_files[:1]).map(parse_example)

pos_img, neg_img = None, None
for img, label in ds:
    if label.numpy() == 1 and pos_img is None:
        pos_img = img.numpy()
    elif label.numpy() == 0 and neg_img is None:
        neg_img = img.numpy()
    if pos_img is not None and neg_img is not None:
        break

print(f"Loaded positive image: {pos_img.shape}")
print(f"Loaded negative image: {neg_img.shape}")

In [None]:
# Verify predictions
def forward_single(img):
    x = jnp.array(img)[None, ...]
    x_temporal = jnp.repeat(x[:, None, ...], 8, axis=1)
    logits = model.apply({'params': params}, x_temporal, training=False)
    return logits[0]

pos_logits = forward_single(pos_img)
neg_logits = forward_single(neg_img)

print("TransformerCSSM Predictions:")
print(f"  Positive: {pos_logits} → {'Connected' if pos_logits.argmax() == 1 else 'Disconnected'}")
print(f"  Negative: {neg_logits} → {'Connected' if neg_logits.argmax() == 1 else 'Disconnected'}")

## 5. Temporal Gradient Attribution

In [None]:
def compute_temporal_gradients(img, target_class):
    x = jnp.array(img)[None, ...]
    x_temporal = jnp.repeat(x[:, None, ...], 8, axis=1)
    
    def forward_fn(x_t):
        logits = model.apply({'params': params}, x_t, training=False)
        return logits[0, target_class]
    
    grads = jax.grad(forward_fn)(x_temporal)
    grad_magnitude = jnp.abs(grads).sum(axis=(0, 2, 3, 4))
    spatial_grads = jnp.abs(grads[0]).sum(axis=-1)
    return grad_magnitude, spatial_grads

pos_grad_mag, pos_spatial = compute_temporal_gradients(pos_img, 1)
neg_grad_mag, neg_spatial = compute_temporal_gradients(neg_img, 0)

print("Gradient magnitude per timestep:")
print(f"  Positive: {np.array(pos_grad_mag).round(2)}")
print(f"  Negative: {np.array(neg_grad_mag).round(2)}")

In [None]:
# Plot temporal importance
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

timesteps = np.arange(8)
width = 0.35
axes[0].bar(timesteps - width/2, np.array(pos_grad_mag), width, 
           label='Positive', color='green', alpha=0.7)
axes[0].bar(timesteps + width/2, np.array(neg_grad_mag), width,
           label='Negative', color='red', alpha=0.7)
axes[0].set_xlabel('Timestep', fontsize=12)
axes[0].set_ylabel('Gradient Magnitude', fontsize=12)
axes[0].set_title('TransformerCSSM: Temporal Importance', fontsize=14)
axes[0].legend()
axes[0].set_xticks(timesteps)

# Cumulative
pos_cumsum = np.cumsum(np.array(pos_grad_mag))
neg_cumsum = np.cumsum(np.array(neg_grad_mag))
axes[1].plot(timesteps, pos_cumsum / pos_cumsum[-1], 'g-o', label='Positive', linewidth=2)
axes[1].plot(timesteps, neg_cumsum / neg_cumsum[-1], 'r-o', label='Negative', linewidth=2)
axes[1].set_xlabel('Timestep', fontsize=12)
axes[1].set_ylabel('Cumulative Importance', fontsize=12)
axes[1].set_title('Information Integration Over Time', fontsize=14)
axes[1].legend()
axes[1].set_xticks(timesteps)
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Mechanism Attribution: Which Gates Matter?

In [None]:
def parameter_gradients(img, target_class):
    x = jnp.array(img)[None, ...]
    x_temporal = jnp.repeat(x[:, None, ...], 8, axis=1)
    
    def loss_fn(p):
        logits = model.apply({'params': p}, x_temporal, training=False)
        return logits[0, target_class]
    
    return jax.grad(loss_fn)(params)

pos_param_grads = parameter_gradients(pos_img, 1)
neg_param_grads = parameter_gradients(neg_img, 0)

In [None]:
# TransformerCSSM-specific gates
mechanisms = [
    ('w_qk', 'w (Q↔K coupling)', 'Symmetric Q-K interaction'),
    ('alpha', 'α (A→Q feedback)', 'Attention feeds back to Q'),
    ('gamma', 'γ (Q,K→A accum)', 'Attention accumulation rate'),
    ('decay_Q', 'λ_Q (Q memory)', 'Query state persistence'),
    ('decay_K', 'λ_K (K memory)', 'Key state persistence'),
    ('decay_A', 'λ_A (A memory)', 'Attention memory persistence'),
]

print("="*80)
print("TransformerCSSM MECHANISM ATTRIBUTION")
print("="*80)
print(f"{'Mechanism':<25} {'Description':<30} {'Pos':>10} {'Neg':>10}")
print("-"*80)

for key, name, desc in mechanisms:
    if key in pos_param_grads['cssm_0']:
        pos_mag = np.abs(np.array(pos_param_grads['cssm_0'][key])).mean()
        neg_mag = np.abs(np.array(neg_param_grads['cssm_0'][key])).mean()
        print(f"{name:<25} {desc:<30} {pos_mag:>10.6f} {neg_mag:>10.6f}")

In [None]:
# Kernel gradient
K_grad_pos = np.array(pos_param_grads['cssm_0']['kernel'])
K_grad_neg = np.array(neg_param_grads['cssm_0']['kernel'])

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

im0 = axes[0].imshow(K_grad_pos.mean(axis=0), cmap='RdBu_r')
axes[0].set_title('Kernel grad (Positive)', fontsize=12)
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0], fraction=0.046)

im1 = axes[1].imshow(K_grad_neg.mean(axis=0), cmap='RdBu_r')
axes[1].set_title('Kernel grad (Negative)', fontsize=12)
axes[1].axis('off')
plt.colorbar(im1, ax=axes[1], fraction=0.046)

diff = K_grad_pos.mean(axis=0) - K_grad_neg.mean(axis=0)
im2 = axes[2].imshow(diff, cmap='RdBu_r')
axes[2].set_title('Difference (Pos - Neg)', fontsize=12)
axes[2].axis('off')
plt.colorbar(im2, ax=axes[2], fraction=0.046)

plt.suptitle('Kernel Gradients: How Should K Change for Each Decision?', fontsize=14)
plt.tight_layout()
plt.show()

---
# Part IV: Summary
---

## Key Findings

### 1. TransformerCSSM Equations

$$Q_{t+1} = \lambda_Q Q_t + w (\mathcal{K} * K_t) + \alpha (\mathcal{K} * A_t) + U_Q$$
$$K_{t+1} = w (\mathcal{K} * Q_t) + \lambda_K K_t + U_K$$
$$A_{t+1} = \gamma Q_t + \gamma K_t + \lambda_A A_t + U_A$$

### 2. Comparison with HGRUBilinearCSSM

| | HGRUBilinearCSSM (32-dim) | TransformerCSSM (64-dim) |
|--|---------------------------|--------------------------|
| States | X, Y, Z | Q, K, A |
| Kernels | 2 (K_E, K_I) | 1 (shared) |
| Coupling | Asymmetric | Symmetric |
| Gates | ~13 | ~10 |
| Accuracy | 85.50% | **88.50%** |

### 3. The Attention Mechanism

TransformerCSSM makes the attention analogy explicit:
- **Q** seeks relevant context (like queries)
- **K** provides context to match (like keys)
- **A** accumulates Q-K correlation (like attention weights)
- **A feeds back to Q** (like applying attention)

### 4. Why It Works

Despite being simpler than HGRUBilinearCSSM, TransformerCSSM achieves **better** accuracy (88.50% vs 85.50%) because:
1. The **single kernel** is sufficient for contour spreading
2. **Symmetric Q↔K** captures the essential bidirectional information flow
3. **A as pure memory** (no kernel) still accumulates the necessary context
4. The **attention feedback** (α·K·A → Q) provides the key modulation
5. **64-dim embedding** provides more representational capacity

In [None]:
print("="*70)
print("NOTEBOOK COMPLETE")
print("="*70)
print()
print("TransformerCSSM Key Insights:")
print("  1. Achieves 88.50% accuracy (beats HGRUBilinearCSSM's 85.50%!)")
print("  2. Uses explicit Q/K/A naming (transformer-like)")
print("  3. Single kernel instead of E/I pair")
print("  4. Symmetric Q↔K coupling via shared weight w")
print("  5. A is pure memory that feeds back into Q (attention mechanism)")
print("  6. 64-dim embedding provides more capacity than 32-dim")