# TransformerCSSM: Iterative Attention for Visual Reasoning

This notebook introduces **TransformerCSSM**, a model that combines the parallelizability of state space models with the expressiveness of transformer-style attention.

**Model Performance:** 88.50% accuracy on Pathfinder-14

---

## Table of Contents

1. **What is CSSM?** - Convolutions in spectral domain, log-space stability
2. **The Transformer's Trick** - Instant attention at all positions
3. **When Iterative Attention Helps** - The Pathfinder task
4. **CSSM as a Solution** - Parallel RNNs, but limited expressiveness
5. **Insights from the hGRU** - Bilinear terms and growing receptive fields
6. **TransformerCSSM** - Bringing it all together
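7. **Loading the Model** - Checkpoint, learned kernel, and gate values
8. **Temporal Gradient Attribution** - Which timesteps drive the decision
9. **Mechanism Attribution** - Which components matter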

In [None]:
#@title Setup: Install dependencies and download checkpoint { display-mode: "form" }
import os
import sys

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
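    # Install JAX with GPU support (quoted so the shell does not glob the brackets)
    !pip install -q "jax[cuda12]" flax optax tensorflow
    
    # Clone CSSM repo
    if not os.path.exists('CSSM'):
        !git clone https://github.com/your-repo/CSSM.git
    # Add the repo to the import path on every run, not only right after cloning
    if 'CSSM' not in sys.path:
        sys.path.insert(0, 'CSSM')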
    
    # Download checkpoint (--no-check-certificate for expired SSL cert)
    CHECKPOINT_URL = "https://connectomics.clps.brown.edu/tf_records/transformer_cssm_kqv64_epoch20.pkl"
    CHECKPOINT_PATH = "transformer_cssm_checkpoint.pkl"
    
    if not os.path.exists(CHECKPOINT_PATH):
        print(f"Downloading checkpoint from {CHECKPOINT_URL}...")
        !wget -q --no-check-certificate {CHECKPOINT_URL} -O {CHECKPOINT_PATH}
        print("Download complete!")
else:
    # Local paths
    CHECKPOINT_PATH = "checkpoints/KQV_64/epoch_20/checkpoint.pkl"

print(f"Checkpoint path: {CHECKPOINT_PATH}")
print(f"Running in {'Colab' if IN_COLAB else 'local'} environment")

---
# Part I: Background
---

## 1. What is CSSM?

**CSSM (Cepstral State Space Model)** is a recurrent neural network designed for efficient visual processing. Think of it as an RNN that can run in parallel.

### The Core Idea

A standard RNN updates its hidden state like this:

$$h_{t+1} = A \cdot h_t + B \cdot x_t$$

The problem? Each timestep depends on the previous one, so you must compute them **sequentially**. For T timesteps, that's O(T) serial operations—slow on GPUs that thrive on parallelism.

### Why Spectral Domain?

CSSM performs spatial convolutions using the **Fast Fourier Transform (FFT)**:

$$\text{Conv}(K, X) = \mathcal{F}^{-1}\left(\mathcal{F}(K) \odot \mathcal{F}(X)\right)$$

**In plain English:** Instead of sliding a kernel across an image (expensive), we:
1. Transform both kernel and image to frequency domain (FFT)
2. Multiply them element-wise (cheap!)
3. Transform back (inverse FFT)

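For an image with N pixels and a kernel of comparable size, this reduces spatial convolution from O(N²) to O(N log N). Note that the FFT computes a circular (wrap-around) convolution.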
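The three steps above can be checked numerically. The following is a minimal NumPy sketch, not the CSSM implementation: it assumes a kernel zero-padded to the image size and circular boundaries, and the function names are made up for illustration.

```python
import numpy as np

def circular_conv2d_direct(kernel, image):
    """Circular 2-D convolution by explicit summation (slow reference)."""
    H, W = image.shape
    out = np.zeros((H, W))
    for i in range(H):
        for j in range(W):
            for a in range(H):
                for b in range(W):
                    out[i, j] += kernel[a, b] * image[(i - a) % H, (j - b) % W]
    return out

def circular_conv2d_fft(kernel, image):
    """Same result via the convolution theorem: FFT, multiply, inverse FFT."""
    return np.real(np.fft.ifft2(np.fft.fft2(kernel) * np.fft.fft2(image)))

rng = np.random.default_rng(0)
image = rng.standard_normal((8, 8))
kernel = np.zeros((8, 8))
kernel[:3, :3] = rng.standard_normal((3, 3))  # 3x3 kernel zero-padded to image size

direct = circular_conv2d_direct(kernel, image)
via_fft = circular_conv2d_fft(kernel, image)
print(np.allclose(direct, via_fft))  # True
```

The direct version touches every kernel entry for every pixel, while the FFT version does a fixed number of transforms plus one element-wise product.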

### Why Log-Space (GOOM)?

When you multiply many numbers together over time (like decay rates), values can explode or vanish:

$$h_T = \lambda^T \cdot h_0 \quad \text{(exponential growth/decay)}$$

**GOOM (Generalized Orders of Magnitude)** solves this by working in log-space:

$$\log(a \cdot b) = \log(a) + \log(b)$$

Multiplications become additions, keeping values numerically stable even over hundreds of timesteps.
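A small sketch of why this matters, using plain Python floats. It only covers positive values; how GOOM represents signs is not shown here, and the numbers `lam` and `T` are illustrative assumptions.

```python
import math

lam, T = 0.5, 2000

# Direct product: 0.5**2000 is far below the smallest float64, so it underflows to 0.
direct = 1.0
for _ in range(T):
    direct *= lam

# Log-space: the product becomes a sum of logs, which stays well within range.
log_value = sum(math.log(lam) for _ in range(T))

print(direct)     # 0.0
print(log_value)  # about -1386.29
```

The direct product has lost all information, while the log-space value still identifies the magnitude exactly enough to keep computing with it.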

## 2. The Transformer's Trick: Instant Attention

Transformers revolutionized deep learning with **self-attention**:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$

### What This Means

Every position in the input can instantly "attend to" every other position:
- **Q (Query)**: "What am I looking for?"
- **K (Key)**: "What do I contain?"
- **V (Value)**: "What information do I carry?"

The $QK^T$ product computes similarity between all pairs of positions **in one shot**.
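As a concrete reference, here is a minimal single-head NumPy sketch of the formula above, with random inputs and no learned projections (both simplifications for illustration).

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(Q, K, V):
    """Scaled dot-product attention for N positions with d features each."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)       # (N, N): similarity of every pair of positions
    weights = softmax(scores, axis=-1)  # each row is a distribution over positions
    return weights @ V, weights

rng = np.random.default_rng(0)
N, d = 5, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out, weights = self_attention(Q, K, V)
print(weights.shape, out.shape)  # (5, 5) (5, 4)
```

The `(N, N)` score matrix is where both the global context and the quadratic cost come from.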

### The Good

- **Fully parallelizable**: No sequential dependencies
- **Global context**: Any position can see any other position
- **Scales well**: GPUs love matrix multiplications

### The Catch

- **O(N²) complexity**: Comparing all pairs of N positions is expensive
- **Single pass per layer**: Each layer computes attention once; refinement requires stacking more layers
- **No iteration within a layer**: A layer can't "change its mind" or backtrack

## 3. When Iterative Attention Helps: The Pathfinder Task

Some visual tasks require **iterative reasoning**—you can't solve them in one glance.

### The Pathfinder Challenge

The task: **Are the two dots connected by a continuous curve?**

The image contains:
- Two marker dots (endpoints)
- A potential connecting contour (curved path)
- Distractor curves (noise to confuse you)

**Why is this hard?**
- The contour can be long and winding
- Distractors look similar to the real path
- You need to "trace" the curve step by step

### Why Iterative Attention?

Imagine tracing a contour with your finger:

1. **Start at one dot**
2. **Follow the curve locally** (look at nearby pixels)
3. **Keep extending** your search
4. **Backtrack if stuck** (hit a dead end? try another direction)
5. **Succeed when you reach the other dot**

This is fundamentally **iterative**—you make local decisions that accumulate into a global answer.
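The tracing procedure can be mimicked with a toy routine that spreads activation one pixel per step along a binary curve mask. This is only an illustration of local decisions accumulating into a global answer, not how any of the models in this notebook work; `trace_connected` and the tiny test image are hypothetical.

```python
import numpy as np

def trace_connected(curve, start, goal, max_steps=100):
    """Spread activation from `start` along curve pixels, one neighbor ring per step."""
    reached = np.zeros_like(curve, dtype=bool)
    reached[start] = True
    H, W = curve.shape
    for step in range(1, max_steps + 1):
        # Dilate by one pixel in the 8-neighborhood (zero padding, no wrap-around).
        padded = np.pad(reached, 1)
        grown = np.zeros_like(reached)
        for dy in (0, 1, 2):
            for dx in (0, 1, 2):
                grown |= padded[dy:dy + H, dx:dx + W]
        grown &= curve  # activation can only live on curve pixels
        if grown[goal]:
            return True, step
        if (grown == reached).all():
            return False, step  # nothing new was reached: dead end
        reached = grown
    return False, max_steps

curve = np.zeros((5, 7), dtype=bool)
curve[2, :] = True                 # one horizontal contour
print(trace_connected(curve, (2, 0), (2, 6)))   # (True, 6)

broken = curve.copy()
broken[2, 3] = False               # cut the contour in the middle
print(trace_connected(broken, (2, 0), (2, 6)))  # (False, 3)
```

The number of steps needed scales with contour length, which is why longer Pathfinder contours demand more iterations.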

### The Transformer Problem

A standard transformer computes attention once per layer, then makes a decision. But for Pathfinder:
- **Early layers** can only see local structure
- **To trace a long curve**, you need many layers stacked (deep = slow)
- **No backtracking**: an early mistake can only be corrected by adding more layers, not by re-running the same computation

### The RNN Solution (and its flaw)

An RNN can iterate naturally:
- Each timestep refines the answer
- Information spreads gradually across the image
- More timesteps = larger effective receptive field

**But**: Traditional RNNs are sequential → O(T) time → slow on GPUs.

In [None]:
# Let's visualize the Pathfinder task
import matplotlib.pyplot as plt
import numpy as np

# We'll load actual examples later, but here's what the task looks like:
print("="*60)
print("THE PATHFINDER TASK")
print("="*60)
print()
print("  ┌─────────────────────────────────┐")
print("  │     ●                           │  ● = marker dots")
print("  │      ╲                          │")
print("  │       ╲    ╱╲                   │  The question:")
print("  │        ╲  ╱  ╲                  │  Are the two dots")
print("  │         ╲╱    ╲                 │  connected by a")
print("  │                 ╲    ╱╲         │  continuous curve?")
print("  │                  ╲  ╱  ╲        │")
print("  │    ╱╲             ╲╱    ●       │")
print("  │   ╱  ╲                          │  (ignore distractors)")
print("  │  ╱    ╲                         │")
print("  └─────────────────────────────────┘")
print()
print("Connected (positive): The curve links both dots")
print("Disconnected (negative): Dots are on separate curves")
print()
print("Difficulty levels:")
print("  • Level 9:  Short contours (easy)")
print("  • Level 14: Medium contours (this notebook)")
print("  • Level 20: Long, winding contours (hard)")
print("="*60)

## 4. CSSM: A Parallel RNN (But Limited)

CSSM offers a potential solution: **an RNN that runs in parallel**.

### The Associative Scan Trick

For linear recurrences of the form:

$$h_{t+1} = A \cdot h_t + b_t$$

We can use the **associative scan** to compute all timesteps in O(log T) parallel time instead of O(T) sequential time.

**How?** The operation $(A_2, b_2) \circ (A_1, b_1) = (A_2 A_1, A_2 b_1 + b_2)$ is associative, so we can reorganize the computation into a tree:

```
t=0    t=1    t=2    t=3    t=4    t=5    t=6    t=7
  \    /        \    /        \    /        \    /
   [0:1]         [2:3]         [4:5]         [6:7]      ← Level 1
      \          /                \          /
       [0:3]                       [4:7]                ← Level 2
           \                      /
                  [0:7]                                 ← Level 3
```

Instead of 8 sequential steps, we need only log₂ 8 = 3 parallel levels, because every combination within a level is independent of the others.
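The sketch below checks this on a scalar recurrence with NumPy. It uses a Hillis-Steele style scan (one of several ways to organize the tree) and is an illustration, not the CSSM code; in JAX the same pattern is available as `jax.lax.associative_scan`. The helper names and random inputs are assumptions.

```python
import numpy as np

def combine(later, earlier):
    """Compose two affine maps h -> A*h + b; `earlier` is applied first."""
    A2, b2 = later
    A1, b1 = earlier
    return A2 * A1, A2 * b1 + b2

def sequential(A, b, h0):
    """Reference: one step at a time."""
    h, out = h0, []
    for A_t, b_t in zip(A, b):
        h = A_t * h + b_t
        out.append(h)
    return np.array(out)

def parallel_scan(A, b, h0):
    """Inclusive scan in ceil(log2 T) rounds; each round is one vectorized op."""
    A, b = A.copy(), b.copy()
    T, shift, rounds = len(A), 1, 0
    while shift < T:
        # Element t absorbs the partial result that ends at t - shift.
        A_new, b_new = combine((A[shift:], b[shift:]), (A[:-shift], b[:-shift]))
        A[shift:], b[shift:] = A_new, b_new
        shift *= 2
        rounds += 1
    return A * h0 + b, rounds

rng = np.random.default_rng(0)
T = 8
A = rng.uniform(0.5, 1.0, size=T)
b = rng.standard_normal(T)

h_seq = sequential(A, b, h0=1.0)
h_par, rounds = parallel_scan(A, b, h0=1.0)
print(np.allclose(h_seq, h_par), rounds)  # True 3
```

Both versions return the hidden state after every timestep; only the dependency depth differs.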

### The Limitation

Basic CSSM computes attention as a function of a **single state variable**:

$$h_{t+1} = \lambda \cdot h_t + \text{input}$$

This is like a transformer with only **one** of Q, K, or V—severely limited expressiveness.

**Can we make a CSSM that benefits from the Query-Key interaction of transformers?**

## 5. Insights from the hGRU

A key hint came from the **hGRU (horizontal Gated Recurrent Unit)**—an RNN we developed that successfully solves Pathfinder.

### The hGRU's Secret: Bilinear Interactions

The hGRU has two interacting cell populations:
- **Excitatory cells (X)**: Spread activation along contours
- **Inhibitory cells (Y)**: Suppress distractors

The critical insight is the **bilinear term**:

$$Y_{t+1} \propto X_t \odot Y_t$$

**In plain English:** Inhibition is computed as the **product** of excitatory and inhibitory activity. This multiplicative interaction is much more expressive than simple addition.
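A toy comparison makes the difference visible. The activity vectors are made-up values, not hGRU outputs.

```python
import numpy as np

X = np.array([0.0, 1.0, 0.0, 1.0])  # excitatory activity at four locations
Y = np.array([0.0, 0.0, 1.0, 1.0])  # inhibitory activity at the same locations

additive = X + Y   # responds whenever either population is active
bilinear = X * Y   # responds only where both populations are active

print(additive)  # [0. 1. 1. 2.]
print(bilinear)  # [0. 0. 0. 1.]
```

The product acts like a soft AND: one population gates the other, which a weighted sum cannot express.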

### Growing Receptive Fields = Growing Attention

Each timestep, information spreads through a spatial kernel. After T timesteps:

- **t=1**: Each pixel sees its immediate neighbors
- **t=4**: Each pixel sees a moderate neighborhood  
- **t=8**: Each pixel sees a large region

The kernel **compounds over time**, creating an effective receptive field that grows with each iteration.

**This is like attention with a growing radius!**

At early timesteps, comparisons are local. At later timesteps, comparisons span the entire image. This allows the network to:
1. **Start with local edge detection**
2. **Gradually integrate** into longer contours
3. **Make global decisions** only when enough context is gathered
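The growth of the receptive field can be measured directly by applying a small kernel repeatedly to a single active pixel. This sketch assumes a 3×3 averaging kernel and a 21×21 grid purely for illustration; the trained model uses a larger, learned kernel.

```python
import numpy as np

def conv_same(x, k):
    """Same-size 2-D filtering with zero padding (kernel is symmetric here)."""
    p = k.shape[0] // 2
    xp = np.pad(x, p)
    out = np.zeros_like(x)
    for dy in range(k.shape[0]):
        for dx in range(k.shape[1]):
            out += k[dy, dx] * xp[dy:dy + x.shape[0], dx:dx + x.shape[1]]
    return out

kernel = np.ones((3, 3)) / 9.0
h = np.zeros((21, 21))
h[10, 10] = 1.0  # a single active pixel in the center

widths = {}
for t in range(1, 9):
    h = conv_same(h, kernel)
    if t in (1, 4, 8):
        widths[t] = int(np.count_nonzero(h[10]))  # extent along the center row

print(widths)  # {1: 3, 4: 9, 8: 17}
```

With a k×k kernel the reachable width after t steps is t·(k−1)+1, so a larger kernel covers the image in fewer timesteps.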

### The Challenge

The hGRU works great, but it's a traditional RNN—**sequential and slow**.

**Can we preserve these bilinear, growing-receptive-field dynamics in a parallelizable CSSM?**

---
# Part II: TransformerCSSM
---

## 6. TransformerCSSM: Bringing It Together

**TransformerCSSM** is our attempt to combine:
- ✅ **Parallel computation** (from CSSM's associative scan)
- ✅ **Query-Key interactions** (from transformers)
- ✅ **Growing receptive fields** (from hGRU)
- ✅ **Iterative refinement** (from RNNs)

### The Three States

We use transformer-inspired naming for three interacting state variables:

| State | Name | Role | Intuition |
|-------|------|------|-----------|
| **Q** | Query | "What am I looking for?" | Current representation seeking context |
| **K** | Key | "What do I contain?" | Information available to match against |
| **A** | Attention | "What have I found?" | Accumulated Q-K correlations over time |

### The Update Equations (ELI5 Version)

**Query Update:**
$$Q_{t+1} = \underbrace{\lambda_Q \cdot Q_t}_{\text{remember old Q}} + \underbrace{w \cdot (\mathcal{K} * K_t)}_{\text{look at K through kernel}} + \underbrace{\alpha \cdot (\mathcal{K} * A_t)}_{\text{attention feedback}} + \underbrace{U_Q}_{\text{new input}}$$

*"The Query remembers itself, looks at what the Key contains (through a spatial kernel), gets modulated by accumulated Attention, and receives new input."*

**Key Update:**
$$K_{t+1} = \underbrace{w \cdot (\mathcal{K} * Q_t)}_{\text{look at Q through kernel}} + \underbrace{\lambda_K \cdot K_t}_{\text{remember old K}} + \underbrace{U_K}_{\text{new input}}$$

*"The Key looks at what the Query is seeking (symmetric to above!) and remembers itself."*

**Attention Accumulator:**
$$A_{t+1} = \underbrace{\gamma \cdot (Q_t + K_t)}_{\text{accumulate Q-K activity}} + \underbrace{\lambda_A \cdot A_t}_{\text{remember old A}} + \underbrace{U_A}_{\text{new input}}$$

*"Attention accumulates the sum of Q and K over time, building up a running record of where Q and K were active."*

### Why This Works

1. **Q and K interact symmetrically** through a shared spatial kernel $\mathcal{K}$
2. **The kernel grows effective receptive field** over timesteps (like hGRU)
3. **A accumulates Q and K activity** over time (a linear stand-in for building attention weights)
4. **A feeds back into Q** (the accumulated attention modulates future queries)
5. **Everything is linear** → can use associative scan → **parallel!**
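To see why linearity is what enables the scan, note that if the convolutions are evaluated in the Fourier domain, each spatial frequency sees the kernel $\mathcal{K}$ as a single complex multiplier, and the three updates collapse into one 3×3 matrix-vector step. The sketch below verifies this for one frequency; all numeric values, including `k_hat`, are made up for illustration and do not come from the trained checkpoint.

```python
import numpy as np

# Hypothetical gate values and kernel coefficient (illustration only).
lam_Q, lam_K, lam_A = 0.9, 0.8, 0.7
w, alpha, gamma = 0.3, 0.2, 0.1
k_hat = 0.5 + 0.25j  # Fourier coefficient of the kernel at one frequency

# State transition matrix acting on the stacked state [Q, K, A].
M = np.array([
    [lam_Q,     w * k_hat, alpha * k_hat],
    [w * k_hat, lam_K,     0.0],
    [gamma,     gamma,     lam_A],
])

def step_explicit(Q, K, A, U_Q, U_K, U_A):
    """The three update equations written out term by term."""
    Q_next = lam_Q * Q + w * k_hat * K + alpha * k_hat * A + U_Q
    K_next = w * k_hat * Q + lam_K * K + U_K
    A_next = gamma * (Q + K) + lam_A * A + U_A
    return np.array([Q_next, K_next, A_next])

state = np.array([1.0 + 0j, -0.5 + 0j, 0.25 + 0j])
U = np.array([0.1 + 0j, 0.2 + 0j, 0.0 + 0j])

explicit = step_explicit(*state, *U)
matrix = M @ state + U
print(np.allclose(explicit, matrix))  # True
```

Because each step is an affine map of the state, consecutive steps compose associatively, which is the property the scan relies on.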

In [None]:
# Matrix form visualization
print("="*70)
print("TransformerCSSM as a 3x3 State Transition Matrix")
print("="*70)
print()
print("  ┌                              ┐   ┌   ┐     ┌     ┐")
print("  │  λ_Q      w·K      α·K       │   │ Q │     │ U_Q │")
print("  │  w·K      λ_K       0        │ × │ K │  +  │ U_K │")
print("  │   γ        γ       λ_A       │   │ A │     │ U_A │")
print("  └                              ┘   └   ┘     └     ┘")
print()
print("Where:")
print("  • λ_Q, λ_K, λ_A = decay rates (memory)")
print("  • w = Q↔K coupling weight (symmetric!)")
print("  • α = attention feedback strength (A → Q)")
print("  • γ = attention accumulation rate")
print("  • K (inside the matrix) = spatial convolution kernel via FFT, not the Key state")
print()
print("Key insight: Q↔K coupling is SYMMETRIC (same w in both directions)")
print("This is like Q and K 'talking to each other' through the same channel")
print("="*70)

### Sequential vs Parallel: The Best of Both Worlds

**Sequential (traditional RNN):**
```
for t in range(T):
    Q[t+1] = λ_Q·Q[t] + w·(𝒦 * K[t]) + α·(𝒦 * A[t]) + U_Q
    K[t+1] = w·(𝒦 * Q[t]) + λ_K·K[t] + U_K
    A[t+1] = γ·(Q[t] + K[t]) + λ_A·A[t] + U_A
```
Time: **O(T)** sequential steps

**Parallel (associative scan):**
```
# Compute all timesteps simultaneously using tree reduction
states = associative_scan(combine_fn, inputs)
```
Time: **O(log T)** parallel steps

For T=8 timesteps: sequential needs 8 steps, parallel needs only 3!

---
# Part III: Hands-On Analysis
---

Now let's load a trained TransformerCSSM and see how it solves Pathfinder.

## 7. Loading the Model

We'll load a TransformerCSSM trained on Pathfinder-14 (88.50% accuracy) and analyze:
1. What the learned spatial kernel looks like
2. How the model makes decisions over time
3. Which mechanisms matter most

In [None]:
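import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Load checkpoint (uses CHECKPOINT_PATH from setup cell).
# pickle can execute arbitrary code, so only load checkpoints from a trusted source.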
with open(CHECKPOINT_PATH, 'rb') as f:
    ckpt = pickle.load(f)

params = ckpt['params']
print(f"Loaded checkpoint from epoch {ckpt.get('epoch', 'unknown')}")
print(f"CSSM params: {list(params['cssm_0'].keys())}")

In [None]:
# Extract TransformerCSSM parameters
cssm_params = params['cssm_0']

# Single spatial kernel (NOT two like HGRUBi)
K_kernel = jnp.array(cssm_params['kernel'])  # (C, k, k)
print(f"Spatial kernel shape: {K_kernel.shape}")

# Extract gate values - handle Dense layer dict structure
def get_gate(name):
    """Summarize a gate as one scalar in (0, 1), handling Dense layer dict format.

    This is a coarse summary for display only: sigmoid(mean kernel + mean bias).
    """
    gate_data = cssm_params[name]
    # TransformerCSSM stores gates as Dense layers: {'kernel': ..., 'bias': ...}
    if isinstance(gate_data, dict):
        kernel = gate_data['kernel']
        bias = gate_data.get('bias', 0)
        val = kernel.mean() + (bias.mean() if hasattr(bias, 'mean') else bias)
    else:
        val = gate_data.mean()
    return float(jax.nn.sigmoid(val))

gates = {
    'lambda_Q': get_gate('decay_Q'),
    'lambda_K': get_gate('decay_K'),
    'lambda_A': get_gate('decay_A'),
    'w': get_gate('w_qk'),
    'alpha': get_gate('alpha'),
    'gamma': get_gate('gamma'),
}

print("\nExtracted gate values:")
for name, val in gates.items():
    print(f"  {name:>10}: {val:.4f}")

In [None]:
# Visualize the single kernel, one panel per channel (64 for the KQV_64 model)
K_np = np.array(K_kernel)
n_channels = K_np.shape[0]
n_cols = 8
n_rows = int(np.ceil(n_channels / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 2 * n_rows), squeeze=False)

vmax = np.abs(K_np).max()

for idx, ax in enumerate(axes.ravel()):
    if idx < n_channels:
        ax.imshow(K_np[idx], cmap='RdBu_r', vmin=-vmax, vmax=vmax)
        ax.set_title(f'K[{idx}]', fontsize=8)
    ax.axis('off')

plt.suptitle(f'TransformerCSSM: Single Spatial Kernel ({n_channels} channels, shared for Q↔K and A→Q)', fontsize=14)
plt.tight_layout()
plt.show()

# Mean kernel
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(K_np.mean(axis=0), cmap='RdBu_r')
ax.set_title('Mean Kernel (averaged over 64 channels)', fontsize=12)
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)
plt.tight_layout()
plt.show()

In [None]:
# Load model for gradient computation
from src.models.simple_cssm import SimpleCSSM

model = SimpleCSSM(
    num_classes=2,
    embed_dim=64,  # 64-dim model (KQV_64)
    depth=1,
    cssm_type='transformer',  # TransformerCSSM!
    kernel_size=15,
    pos_embed='spatiotemporal',
    seq_len=8,
)
print("TransformerCSSM model constructed (embed_dim=64); weights come from `params`.")

In [None]:
# Load sample Pathfinder images
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')

# For Colab, we'll use a sample image URL; locally, load from TFRecords
if IN_COLAB:
    # Download sample images for Colab
    import urllib.request
    SAMPLE_URL = "https://connectomics.clps.brown.edu/tf_records/pathfinder_samples.npz"
    try:
        urllib.request.urlretrieve(SAMPLE_URL, "pathfinder_samples.npz")
        data = np.load("pathfinder_samples.npz")
        pos_img, neg_img = data['pos'], data['neg']
    except Exception as e:
        # Fallback: random noise so the rest of the notebook still runs (not real Pathfinder data)
        print(f"Could not load samples ({e}); using random images for demo")
        pos_img = np.random.rand(224, 224, 3).astype(np.float32)
        neg_img = np.random.rand(224, 224, 3).astype(np.float32)
else:
    TFRECORD_DIR = '/home/dlinsley/pathfinder_tfrecord/difficulty_14/val'
    
    def parse_example(example):
        features = tf.io.parse_single_example(example, {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64),
        })
        image = tf.io.decode_raw(features['image'], tf.float32)
        image = tf.reshape(image, [224, 224, 3])
        return image, features['label']
    
    val_files = sorted(tf.io.gfile.glob(f'{TFRECORD_DIR}/*.tfrecord'))
    ds = tf.data.TFRecordDataset(val_files[:1]).map(parse_example)
    
    pos_img, neg_img = None, None
    for img, label in ds:
        if label.numpy() == 1 and pos_img is None:
            pos_img = img.numpy()
        elif label.numpy() == 0 and neg_img is None:
            neg_img = img.numpy()
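        if pos_img is not None and neg_img is not None:
            break

    if pos_img is None or neg_img is None:
        raise RuntimeError(f"No positive/negative pair found in {val_files[:1]}")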

# Visualize the examples
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].imshow(pos_img)
axes[0].set_title('CONNECTED (Positive)\nThe two dots ARE linked', fontsize=12)
axes[0].axis('off')

axes[1].imshow(neg_img)
axes[1].set_title('DISCONNECTED (Negative)\nThe two dots are NOT linked', fontsize=12)
axes[1].axis('off')

plt.suptitle('Pathfinder-14 Examples: Can You Trace the Contour?', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Image shape: {pos_img.shape}")

In [None]:
# Verify predictions
def forward_single(img):
    x = jnp.array(img)[None, ...]
    x_temporal = jnp.repeat(x[:, None, ...], 8, axis=1)
    logits = model.apply({'params': params}, x_temporal, training=False)
    return logits[0]

pos_logits = forward_single(pos_img)
neg_logits = forward_single(neg_img)

print("TransformerCSSM Predictions:")
print(f"  Positive: {pos_logits} → {'Connected' if pos_logits.argmax() == 1 else 'Disconnected'}")
print(f"  Negative: {neg_logits} → {'Connected' if neg_logits.argmax() == 1 else 'Disconnected'}")

## 8. Temporal Gradient Attribution

How does the model's decision depend on each timestep? We compute gradients of the output with respect to the input at each of the 8 timesteps.

In [None]:
def compute_temporal_gradients(img, target_class):
    x = jnp.array(img)[None, ...]
    x_temporal = jnp.repeat(x[:, None, ...], 8, axis=1)
    
    def forward_fn(x_t):
        logits = model.apply({'params': params}, x_t, training=False)
        return logits[0, target_class]
    
    grads = jax.grad(forward_fn)(x_temporal)
    grad_magnitude = jnp.abs(grads).sum(axis=(0, 2, 3, 4))
    spatial_grads = jnp.abs(grads[0]).sum(axis=-1)
    return grad_magnitude, spatial_grads

pos_grad_mag, pos_spatial = compute_temporal_gradients(pos_img, 1)
neg_grad_mag, neg_spatial = compute_temporal_gradients(neg_img, 0)

print("Gradient magnitude per timestep:")
print(f"  Positive: {np.array(pos_grad_mag).round(2)}")
print(f"  Negative: {np.array(neg_grad_mag).round(2)}")

In [None]:
# Plot temporal importance
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

timesteps = np.arange(8)
width = 0.35
axes[0].bar(timesteps - width/2, np.array(pos_grad_mag), width, 
           label='Positive', color='green', alpha=0.7)
axes[0].bar(timesteps + width/2, np.array(neg_grad_mag), width,
           label='Negative', color='red', alpha=0.7)
axes[0].set_xlabel('Timestep', fontsize=12)
axes[0].set_ylabel('Gradient Magnitude', fontsize=12)
axes[0].set_title('TransformerCSSM: Temporal Importance', fontsize=14)
axes[0].legend()
axes[0].set_xticks(timesteps)

# Cumulative
pos_cumsum = np.cumsum(np.array(pos_grad_mag))
neg_cumsum = np.cumsum(np.array(neg_grad_mag))
axes[1].plot(timesteps, pos_cumsum / pos_cumsum[-1], 'g-o', label='Positive', linewidth=2)
axes[1].plot(timesteps, neg_cumsum / neg_cumsum[-1], 'r-o', label='Negative', linewidth=2)
axes[1].set_xlabel('Timestep', fontsize=12)
axes[1].set_ylabel('Cumulative Importance', fontsize=12)
axes[1].set_title('Information Integration Over Time', fontsize=14)
axes[1].legend()
axes[1].set_xticks(timesteps)
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Mechanism Attribution: Which Components Matter?

Which parts of the TransformerCSSM are most important for the decision? We compute gradients with respect to each learned parameter.

In [None]:
def parameter_gradients(img, target_class):
    x = jnp.array(img)[None, ...]
    x_temporal = jnp.repeat(x[:, None, ...], 8, axis=1)
    
    def loss_fn(p):
        logits = model.apply({'params': p}, x_temporal, training=False)
        return logits[0, target_class]
    
    return jax.grad(loss_fn)(params)

pos_param_grads = parameter_gradients(pos_img, 1)
neg_param_grads = parameter_gradients(neg_img, 0)

In [None]:
# TransformerCSSM-specific gates
mechanisms = [
    ('w_qk', 'w (Q↔K coupling)', 'Symmetric Q-K interaction'),
    ('alpha', 'α (A→Q feedback)', 'Attention feeds back to Q'),
    ('gamma', 'γ (Q,K→A accum)', 'Attention accumulation rate'),
    ('decay_Q', 'λ_Q (Q memory)', 'Query state persistence'),
    ('decay_K', 'λ_K (K memory)', 'Key state persistence'),
    ('decay_A', 'λ_A (A memory)', 'Attention memory persistence'),
]

print("="*80)
print("TransformerCSSM MECHANISM ATTRIBUTION")
print("="*80)
print(f"{'Mechanism':<25} {'Description':<30} {'Pos':>10} {'Neg':>10}")
print("-"*80)

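def mean_abs_grad(grad_tree):
    """Mean |gradient| over every leaf; gates may be Dense-layer dicts (kernel/bias)."""
    leaves = [np.abs(np.asarray(g)).ravel() for g in jax.tree_util.tree_leaves(grad_tree)]
    return float(np.concatenate(leaves).mean())

for key, name, desc in mechanisms:
    if key in pos_param_grads['cssm_0']:
        pos_mag = mean_abs_grad(pos_param_grads['cssm_0'][key])
        neg_mag = mean_abs_grad(neg_param_grads['cssm_0'][key])
        print(f"{name:<25} {desc:<30} {pos_mag:>10.6f} {neg_mag:>10.6f}")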

In [None]:
# Kernel gradient
K_grad_pos = np.array(pos_param_grads['cssm_0']['kernel'])
K_grad_neg = np.array(neg_param_grads['cssm_0']['kernel'])

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

im0 = axes[0].imshow(K_grad_pos.mean(axis=0), cmap='RdBu_r')
axes[0].set_title('Kernel grad (Positive)', fontsize=12)
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0], fraction=0.046)

im1 = axes[1].imshow(K_grad_neg.mean(axis=0), cmap='RdBu_r')
axes[1].set_title('Kernel grad (Negative)', fontsize=12)
axes[1].axis('off')
plt.colorbar(im1, ax=axes[1], fraction=0.046)

diff = K_grad_pos.mean(axis=0) - K_grad_neg.mean(axis=0)
im2 = axes[2].imshow(diff, cmap='RdBu_r')
axes[2].set_title('Difference (Pos - Neg)', fontsize=12)
axes[2].axis('off')
plt.colorbar(im2, ax=axes[2], fraction=0.046)

plt.suptitle('Kernel Gradients: How Should K Change for Each Decision?', fontsize=14)
plt.tight_layout()
plt.show()

---
# Part IV: Summary
---

## What We've Learned

### The Problem
- **Transformers** compute attention over all positions in one shot, but can't iterate or backtrack without stacking more layers
- **RNNs** can iterate but are sequential (slow on GPUs)
- **Tasks like Pathfinder** require iterative, growing-receptive-field reasoning

### The Solution: TransformerCSSM

| Component | Inspiration | Benefit |
|-----------|-------------|---------|
| Q-K interaction | Transformers | Expressive attention-like computation |
| Growing receptive field | hGRU | Local → global reasoning over time |
| Associative scan | State Space Models | O(log T) parallel computation |
| Attention accumulator (A) | Novel | Memory of Q-K correlations |

### The Equations (Recap)

$$Q_{t+1} = \lambda_Q Q_t + w (\mathcal{K} * K_t) + \alpha (\mathcal{K} * A_t) + U_Q$$
$$K_{t+1} = w (\mathcal{K} * Q_t) + \lambda_K K_t + U_K$$
$$A_{t+1} = \gamma (Q_t + K_t) + \lambda_A A_t + U_A$$

### Results

| Model | Architecture | Pathfinder-14 Accuracy |
|-------|--------------|------------------------|
| Standard Transformer | 12 layers, attention | ~75% |
| hGRU (sequential RNN) | Bilinear E-I dynamics | ~90% |
| **TransformerCSSM** | Q-K-A with parallel scan | **88.50%** |

TransformerCSSM achieves near-hGRU performance while being **parallelizable**—the best of both worlds.

In [None]:
print("="*70)
print("KEY TAKEAWAYS")
print("="*70)
print()
print("1. CSSM = RNN that runs in parallel via associative scan")
print()
print("2. Basic CSSM is limited (single state variable)")
print()
print("3. TransformerCSSM adds Q-K-A dynamics inspired by:")
print("   • Transformer attention (Q-K interaction)")
print("   • hGRU (bilinear terms, growing receptive fields)")
print()
print("4. The model achieves 88.50% on Pathfinder-14")
print("   (comparable to sequential hGRU, but parallelizable!)")
print()
print("5. Key mechanisms:")
print("   • Symmetric Q↔K coupling through spatial kernel")
print("   • A accumulates Q-K history (attention memory)")
print("   • A feeds back into Q (attention modulation)")
print()
print("="*70)