# Gated Sparse Attention (GSA) — Step-by-Step Explanation

**Reference:** [alfredcs/Gated-Sparse-Attention](https://github.com/alfredcs/Gated-Sparse-Attention)

This notebook walks through the entire GSA forward pass using a **tiny example** so you can inspect every tensor at every step.

## Architecture Overview

```
Input x [B, T, D]
  │
  ├──► q_proj ──► reshape [B,T,H,d] ──► RoPE ───────────────┐
  ├──► k_proj ──► reshape [B,T,H,d] ──► RoPE ───────────────┤
  ├──► v_proj ──► reshape [B,T,H,d] ──► ValueGate (G2) ─────┤
  │                                                         │
  ├──► GatedLightningIndexer ──► scores [B,T,T]             │
  │                                  │                      │
  │                                  ▼                      ▼
  │                        AdaptiveTopKSelector ──► Sparse Attention
  │                       (indices, mask [B,T,k])           │
  │                                                         ▼
  │                                               attn_output [B,T,H,d]
  │                                                         │
  └──► OutputGate (G1) ───────────────────────────────────► ×
                                                            │
                                                            ▼
                                                         o_proj ──► Output [B,T,D]
```

### Key Idea: Sparse Attention via Learned Indexing

Instead of attending to ALL tokens (O(T²) attention), GSA:
1. Uses a **lightweight indexer** to score all token pairs → still O(T²), but in a tiny d_indexer-dimensional space
2. Selects the **top-k most important** tokens per query → the attention itself drops to O(T×k)
3. Performs **standard softmax attention** only over the selected subset
4. Uses **dual gating** (value + output gates) for fine-grained control

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(42)
torch.set_grad_enabled(False)

# Tiny example dimensions
B = 1        # batch size
T = 8        # sequence length (small enough to see full T×T matrices)
D = 32       # model dimension
H = 2        # attention heads
d_head = D // H  # 16 per head

# Indexer dimensions
d_indexer = 8      # indexer dimension (ref uses 64, we use 8 for visibility)
n_idx_heads = 2    # indexer heads (ref uses 4)
k_select = 4       # number of tokens to attend to (ref uses adaptive ~2048)

print(f"Model: D={D}, H={H}, d_head={d_head}")
print(f"Indexer: d_indexer={d_indexer}, n_idx_heads={n_idx_heads}")
print(f"Sparse selection: k={k_select} out of T={T} tokens")
print(f"\nAttended fraction: {k_select}/{T} = {k_select/T:.1%} of tokens per query")
print("For real models at T=4096, k=2048: 50% of tokens attended (50% sparsity)")
print("For real models at T=8192, k=2048: 25% of tokens attended (75% sparsity)")

## Step 1: Input and QKV Projections

Standard attention projections. Unlike GDN, GSA uses symmetric dimensions (Q, K, V all have dim D).

In [None]:
# Input
x = torch.randn(B, T, D)
print(f"Input x shape: {x.shape}")

# Standard QKV projections
q_proj = nn.Linear(D, D, bias=False)
k_proj = nn.Linear(D, D, bias=False)
v_proj = nn.Linear(D, D, bias=False)

q = q_proj(x).view(B, T, H, d_head)  # [1, 8, 2, 16]
k = k_proj(x).view(B, T, H, d_head)  # [1, 8, 2, 16]
v = v_proj(x).view(B, T, H, d_head)  # [1, 8, 2, 16]

print(f"\nQ shape: {q.shape}")
print(f"K shape: {k.shape}")
print(f"V shape: {v.shape}")
print(f"\nAll projections have the same dimension D={D} (symmetric)")

## Step 2: Value Gate (G2)

The **Value Gate** modulates values *before* attention:

```
v_gated = v * sigmoid(W_gv @ x + b_gv)
```

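This lets the model suppress (gate near 0) or pass through (gate near 1) each value channel based on the input context. Because the gate lies in (0, 1), it can only shrink values, never amplify them.
The bias is initialized to 0.5, so at initialization the gates are centered near `sigmoid(0.5) ≈ 0.62`, with the random `W_gv @ x` term spreading them around that value. Values are therefore dampened by default.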
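The gating formula above is a single elementwise product, and the Output Gate in Step 7 reuses the same pattern. The following is a minimal sketch of it as a reusable module. `SigmoidGate` is a hypothetical name, not the reference repo's class. The demo zeroes the weight so every gate equals `sigmoid(0.5)` exactly, and its variables carry a `_demo` suffix so the notebook's `x` and `v` are not overwritten.

```python
import math
import torch
import torch.nn as nn

class SigmoidGate(nn.Module):
    """Hypothetical elementwise gate: out = inp * sigmoid(W x + b).

    Illustrative only; the name and interface are assumptions, not the reference API.
    """

    def __init__(self, d_model: int, bias_init: float = 0.5):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model, bias=True)
        nn.init.constant_(self.proj.bias, bias_init)  # same bias init as the cell below

    def forward(self, inp: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # The gate is computed from x but applied to inp (the value vectors for G2)
        return inp * torch.sigmoid(self.proj(x))

gate_demo = SigmoidGate(d_model=4)
with torch.no_grad():
    gate_demo.proj.weight.zero_()  # isolate the bias: every gate becomes sigmoid(0.5)
    x_demo = torch.randn(2, 3, 4)
    v_demo = torch.ones(2, 3, 4)
    gated_demo = gate_demo(v_demo, x_demo)

print(gated_demo[0, 0])          # every entry equals sigmoid(0.5)
print(1 / (1 + math.exp(-0.5)))  # sigmoid(0.5), about 0.6225
```

With trained weights, the `W x` term moves each gate away from this baseline per token and per channel.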

In [None]:
# Value Gate (G2)
value_gate_proj = nn.Linear(D, D, bias=True)
nn.init.constant_(value_gate_proj.bias, 0.5)  # ref: bias_init=0.5

# Compute gate
gate_logits = value_gate_proj(x)  # [B, T, D]
value_gate = torch.sigmoid(gate_logits).view(B, T, H, d_head)

print(f"Value gate shape: {value_gate.shape}")
print(f"\nGate values (head 0, first 4 tokens, first 4 dims):")
print(value_gate[0, :4, 0, :4].numpy().round(3))

print(f"\nMean gate value: {value_gate.mean():.3f} (initialized near sigmoid(0.5) ≈ 0.622)")

# Apply gate
v_before_gate = v.clone()
v = v * value_gate

print(f"\nV norms before gate: {v_before_gate[0, :, 0].norm(dim=-1).numpy().round(3)}")
print(f"V norms after  gate: {v[0, :, 0].norm(dim=-1).numpy().round(3)}")
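print("\nAt init the gate shrinks each value channel (about 0.6x on average); training can push gates toward 0 (suppress) or 1 (pass through)")
print("The gate rescales values only; attention weights still come from Q·K")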

## Step 3: Rotary Position Embeddings (RoPE)

Same as GDN — standard RoPE with `rotate_half`. Applied to Q and K (not V).
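RoPE rotates each (i, i + d/2) channel pair by an angle proportional to the token position. As a result, the dot product between a rotated query and a rotated key depends only on their relative offset. The sketch below checks this property numerically for one vector pair. `rope_rotate` is a hypothetical helper that recomputes the same frequency layout as the cell below, in float64 for a tight comparison. It uses a local `torch.Generator` so the notebook's random seed is not disturbed.

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Apply rotate_half-style RoPE to vector(s) x (even last dim) at a single position pos."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = pos * inv_freq                       # one angle per frequency pair
    emb = torch.cat((angles, angles), dim=-1)     # duplicated layout, as in the notebook cell
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    rotated_half = torch.cat((-x2, x1), dim=-1)   # rotate_half(x)
    return x * emb.cos() + rotated_half * emb.sin()

gen = torch.Generator().manual_seed(0)            # local RNG: the global seed is untouched
q_vec = torch.randn(16, generator=gen, dtype=torch.float64)
k_vec = torch.randn(16, generator=gen, dtype=torch.float64)

# Same relative offset (2) at different absolute positions gives the same score
score_3_1 = torch.dot(rope_rotate(q_vec, 3), rope_rotate(k_vec, 1)).item()
score_7_5 = torch.dot(rope_rotate(q_vec, 7), rope_rotate(k_vec, 5)).item()
print(score_3_1, score_7_5)

# Each channel pair is rotated, so vector norms are preserved
print(q_vec.norm().item(), rope_rotate(q_vec, 5).norm().item())
```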

In [None]:
# Build RoPE
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
t = torch.arange(T, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().unsqueeze(0).unsqueeze(2)  # [1, T, 1, d_head]
sin = emb.sin().unsqueeze(0).unsqueeze(2)

def rotate_half(x):
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

q = q * cos + rotate_half(q) * sin
k = k * cos + rotate_half(k) * sin

print(f"RoPE applied to Q and K (not V).")
print(f"Q shape after RoPE: {q.shape}")
print(f"K shape after RoPE: {k.shape}")
print(f"V shape (unchanged): {v.shape}")

## Step 4: Gated Lightning Indexer

This is the **core innovation of GSA**. The indexer computes importance scores for every (query, key) pair using a *cheap, separate* set of projections.

### Formula

```
I(t, s) = Σ_h  σ(w[t,h]) × σ(q_I[t,h] · k_I[s] / √d_idx + b[h])     for s ≤ t  (masked to -inf for s > t)
```

Where:
- `q_I`: indexer queries [B, T, n_idx_heads, d_indexer], separate from the attention Q
- `k_I`: indexer keys [B, T, d_indexer], shared across indexer heads
- `w`: importance logits [B, T, n_idx_heads], computed from the query position t (σ is applied in the formula)
- `b`: learnable bias per indexer head
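The sketch below makes the sum over indexer heads explicit. It transcribes the formula literally with Python loops for a single sequence, then checks the result against a vectorized einsum version that mirrors the notebook cells below, including the causal mask applied there. Both helper names are hypothetical, and the demo tensors are random.

```python
import math
import torch

def indexer_scores_loop(q_I, k_I, w_logits, b):
    """Literal I(t, s) for one sequence; -inf for s > t.
    q_I: [T, n_h, d], k_I: [T, d], w_logits: [T, n_h], b: [n_h]."""
    def sig(z: float) -> float:
        return 1.0 / (1.0 + math.exp(-z))

    T, n_h, d = q_I.shape
    out = torch.full((T, T), float('-inf'))
    for t in range(T):
        for s in range(t + 1):                  # causal: keys s <= t only
            total = 0.0
            for h in range(n_h):                # sum over indexer heads
                dot = torch.dot(q_I[t, h], k_I[s]).item() / math.sqrt(d)
                total += sig(w_logits[t, h].item()) * sig(dot + b[h].item())
            out[t, s] = total
    return out

def indexer_scores_vectorized(q_I, k_I, w_logits, b):
    """Same computation with einsum and broadcasting."""
    T, n_h, d = q_I.shape
    raw = torch.einsum('thd,sd->hts', q_I, k_I) / math.sqrt(d)   # [n_h, T, T]
    gated = torch.sigmoid(raw + b.view(-1, 1, 1))                 # per-head bias, then sigmoid
    weights = torch.sigmoid(w_logits).T.unsqueeze(-1)             # [n_h, T, 1]: depends on query t
    scores = (gated * weights).sum(dim=0)                         # sum over heads -> [T, T]
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    return scores.masked_fill(causal, float('-inf'))

gen = torch.Generator().manual_seed(0)   # local RNG so the notebook's global seed is untouched
q_I_demo = torch.randn(5, 2, 4, generator=gen)
k_I_demo = torch.randn(5, 4, generator=gen)
w_demo = torch.randn(5, 2, generator=gen)
b_demo = torch.randn(2, generator=gen)

loop_out = indexer_scores_loop(q_I_demo, k_I_demo, w_demo, b_demo)
vec_out = indexer_scores_vectorized(q_I_demo, k_I_demo, w_demo, b_demo)
print(torch.allclose(loop_out, vec_out, atol=1e-6))
```

Each head term is a product of two sigmoids, so every unmasked score lies in (0, n_idx_heads).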

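**Why is this efficient?** The indexer still scores all T² pairs. However, its total width is `n_idx_heads × d_indexer` (4 × 64 = 256 in the reference), and a single score matrix is shared by every attention head. Full attention's QKᵀ has total width `H × d_head` (with `d_head = 128`). The O(T²) term is therefore cheaper by roughly a factor of `(H × d_head) / (n_idx_heads × d_indexer)`.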

In [None]:
# Indexer projections
idx_q_proj = nn.Linear(D, n_idx_heads * d_indexer, bias=False)  # -> [B, T, n_idx_heads * d_idx]
idx_k_proj = nn.Linear(D, d_indexer, bias=False)                # -> [B, T, d_idx] (shared)
idx_w_proj = nn.Linear(D, n_idx_heads, bias=True)               # -> [B, T, n_idx_heads]
idx_bias = nn.Parameter(torch.zeros(n_idx_heads))                # [n_idx_heads]

# Xavier init with specific gains (ref)
nn.init.xavier_uniform_(idx_q_proj.weight, gain=1.0)
nn.init.xavier_uniform_(idx_k_proj.weight, gain=1.0)
nn.init.xavier_uniform_(idx_w_proj.weight, gain=0.1)  # Small gain for weights

print(f"Indexer projections:")
print(f"  q_I: [{D}] → [{n_idx_heads} × {d_indexer}] = {n_idx_heads * d_indexer}")
print(f"  k_I: [{D}] → [{d_indexer}] (shared across indexer heads)")
print(f"  w:   [{D}] → [{n_idx_heads}] (importance weights)")
print(f"  b:   [{n_idx_heads}] (learnable bias)")
print(f"\nFLOP cost: ~{2 * T * T * d_indexer * n_idx_heads:,} for T²×d_idx scoring")
print(f"vs attention: ~{2 * T * T * d_head * H:,} for T²×d_head×H")
print(f"Indexer scoring costs {d_indexer * n_idx_heads / (d_head * H):.1%} of the full-attention QK^T product")

In [None]:
# Compute indexer scores step by step
scale_idx = 1.0 / math.sqrt(d_indexer)

# Project to indexer space
q_idx = idx_q_proj(x).view(B, T, n_idx_heads, d_indexer)  # [1, 8, 2, 8]
k_idx = idx_k_proj(x)  # [1, 8, 8] — shared across heads!

print(f"Indexer Q shape: {q_idx.shape} = [B, T, n_idx_heads, d_indexer]")
print(f"Indexer K shape: {k_idx.shape} = [B, T, d_indexer] (shared!)")

# Raw dot products per indexer head: [B, n_heads, T_q, T_kv]
raw_scores = torch.einsum('bqhd,bkd->bhqk', q_idx.float(), k_idx.float()) * scale_idx
print(f"\nRaw score matrix shape: {raw_scores.shape} = [B, n_idx_heads, T, T]")
print(f"\nRaw scores (head 0):")
print(raw_scores[0, 0].numpy().round(2))

In [None]:
# Sigmoid gating with learnable bias
bias_exp = idx_bias.float().view(1, -1, 1, 1)
gated_scores = torch.sigmoid(raw_scores + bias_exp)

print("Sigmoid gating: σ(raw_score + bias)")
print(f"Bias values: {idx_bias.data.numpy().round(3)}")
print(f"\nGated scores (head 0) — values in (0, 1):")
print(gated_scores[0, 0].numpy().round(3))

# Importance weights: sigmoid(w_proj(x))
w = torch.sigmoid(idx_w_proj(x).float())  # [B, T, n_idx_heads]
w_exp = w.permute(0, 2, 1).unsqueeze(-1)  # [B, n_heads, T, 1]

print(f"\nImportance weights (per query position, per head):")
print(w[0].numpy().round(3))

# Weighted sum across indexer heads → final scores
weighted = gated_scores * w_exp
final_scores = weighted.sum(dim=1)  # [B, T, T]

# Apply causal mask
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
final_scores = final_scores.masked_fill(causal_mask.unsqueeze(0), float('-inf'))

print(f"\nFinal indexer scores [B, T, T] (with causal mask):")
print(final_scores[0].numpy().round(3))

## Step 5: Adaptive Top-K Selection

From the T×T score matrix, we select the **top-k** most important tokens for each query.

The reference uses **variance-based adaptive k**: higher variance in scores → more focused attention → fewer tokens needed.

For our tiny example, we'll use a fixed k=4 out of 8 tokens.

In [None]:
# Top-K selection
k_effective = min(k_select, T)  # 4

# Replace -inf with large negative for topk
scores_for_topk = final_scores.masked_fill(final_scores == float('-inf'), -1e9)

# Select top-k indices per query position
topk_values, indices = torch.topk(scores_for_topk, k_effective, dim=-1)

print(f"Top-{k_effective} selection from {T} tokens per query:")
print(f"\nIndices shape: {indices.shape} = [B, T, k_select]")
print(f"\nSelected token indices per query position:")
for t_pos in range(T):
    idx = indices[0, t_pos].numpy()
    vals = topk_values[0, t_pos].numpy().round(3)
    print(f"  Query t={t_pos}: indices={idx}, scores={vals}")

# Build validity mask
# Under the causal mask, query t has only t + 1 valid keys, so early queries get some masked slots
gathered_scores = torch.gather(final_scores, -1, indices)
mask = gathered_scores != float('-inf')

print(f"\nValidity mask (True=valid token):")
for t_pos in range(T):
    print(f"  Query t={t_pos}: {mask[0, t_pos].numpy()}")
print(f"\nNote: query t has only t + 1 causal keys, so queries t < {k_effective - 1} have masked slots among their {k_effective}")

## Step 6: Sparse Attention with Gather

Now we perform **standard attention** but only on the selected k tokens.

### The Gather Operation

The key trick is `torch.gather` to fetch only the selected K/V vectors:

```python
# indices: [B, T, k]        (which key positions each query attends to)
# K, V:    [B, T_kv, H, d]  (d = d_head)

# 1. Expand: [B, T_kv, H, d] → [B, T, T_kv, H, d]  (a broadcast view per query; no data is copied)
# 2. Gather: pick k entries along the T_kv axis → [B, T, k, H, d]
# 3. Attend: softmax(Q · K_gathered^T / √d) · V_gathered, with invalid (causally masked) slots set to -inf
```

The indices are **shared across attention heads** — same tokens selected for all heads.
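`expand` returns a broadcast view, so the expand-then-`gather` approach only materializes the `[B, T, k, H, d]` result. Advanced indexing with a broadcast batch index produces the same tensor with less bookkeeping. The sketch below checks that both approaches agree, assuming the indices are already in range (as `topk` guarantees). `gather_two_ways` is a hypothetical helper, and the inputs are random.

```python
import torch

def gather_two_ways(keys: torch.Tensor, indices: torch.Tensor):
    """keys: [B, T_kv, H, d]; indices: [B, T, k]. Returns two [B, T, k, H, d] tensors."""
    B, T_kv, H, d = keys.shape
    T, k = indices.shape[1], indices.shape[2]

    # (a) Notebook approach: expand to [B, T, T_kv, H, d] (a view), then gather on dim 2
    idx_exp = indices[..., None, None].expand(B, T, k, H, d)
    via_gather = torch.gather(keys.unsqueeze(1).expand(B, T, T_kv, H, d), 2, idx_exp)

    # (b) Advanced indexing: a [B, 1, 1] batch index broadcasts against indices [B, T, k]
    batch_idx = torch.arange(B)[:, None, None]
    via_indexing = keys[batch_idx, indices]
    return via_gather, via_indexing

gen = torch.Generator().manual_seed(0)
keys_demo = torch.randn(2, 6, 3, 4, generator=gen)             # [B, T_kv, H, d]
indices_demo = torch.randint(0, 6, (2, 6, 3), generator=gen)   # stand-in for top-k indices [B, T, k]
g_out, i_out = gather_two_ways(keys_demo, indices_demo)
print(g_out.shape, i_out.shape)
print(torch.equal(g_out, i_out))
```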

In [None]:
# Sparse Attention: step by step
scale_attn = 1.0 / math.sqrt(d_head)

# Clamp indices to the valid range (a safeguard; topk already returns indices in [0, T))
idx = indices.clamp(0, T - 1).long()  # [B, T, k_select]

print("Step 6a: Gather K and V using selected indices")
print(f"  K shape: {k.shape} = [B, T_kv, H, d_head]")
print(f"  indices shape: {idx.shape} = [B, T, k_select]")

# Expand indices for gather: [B, T, k] → [B, T, k, H, d_head]
idx_exp = idx.unsqueeze(-1).unsqueeze(-1).expand(B, T, k_effective, H, d_head)

# Expand K, V: [B, T_kv, H, d_head] → [B, T, T_kv, H, d_head] (a broadcast view; no data is copied)
k_expanded = k.unsqueeze(1).expand(B, T, T, H, d_head)
v_expanded = v.unsqueeze(1).expand(B, T, T, H, d_head)

# Gather the selected key/value positions along dim 2 (the T_kv axis)
k_gathered = torch.gather(k_expanded, 2, idx_exp)  # [B, T, k, H, d_head]
v_gathered = torch.gather(v_expanded, 2, idx_exp)  # [B, T, k, H, d_head]

print(f"  K gathered: {k_gathered.shape} = [B, T, k_select, H, d_head]")
print(f"  V gathered: {v_gathered.shape} = [B, T, k_select, H, d_head]")

# Permute for attention: [B, T, H, k, d_head]
k_gathered = k_gathered.permute(0, 1, 3, 2, 4)
v_gathered = v_gathered.permute(0, 1, 3, 2, 4)
print(f"  Permuted for attention: {k_gathered.shape} = [B, T, H, k_select, d_head]")

In [None]:
# Compute attention scores
print("Step 6b: Compute attention scores on selected tokens")

# Q: [B, T, H, d_head] · K_gathered: [B, T, H, k, d_head] → [B, T, H, k]
attn_scores = torch.einsum('bqhd,bqhkd->bqhk', q, k_gathered) * scale_attn
print(f"  Attention scores shape: {attn_scores.shape} = [B, T, H, k_select]")

# Apply validity mask: [B, T, k] → [B, T, 1, k], broadcast across heads
mask_exp = mask.unsqueeze(2)
attn_scores = attn_scores.masked_fill(~mask_exp, float('-inf'))

# Softmax over the k_select dimension; zero masked slots and guard against all-masked rows
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.masked_fill(~mask_exp, 0.0)
attn_weights = attn_weights.nan_to_num(0.0)

print(f"\nAttention weights (query=3, head=0) sum to 1.0 over the selected tokens:")
w_q3 = attn_weights[0, 3, 0].numpy().round(3)
idx_at_3 = indices[0, 3].numpy()
for j in range(k_effective):
    print(f"  Token {idx_at_3[j]}: weight={w_q3[j]:.3f}  {'(masked)' if w_q3[j] == 0 else ''}")
print(f"  Sum: {w_q3.sum():.3f}")

# Weighted sum: [B, T, H, k] × [B, T, H, k, d_head] → [B, T, H, d_head]
attn_output = torch.einsum('bqhk,bqhkd->bqhd', attn_weights, v_gathered)
print(f"\nAttention output shape: {attn_output.shape} = [B, T, H, d_head]")

## Step 7: Output Gate (G1)

After sparse attention, the output is modulated by the **Output Gate**:

```
o_gated = attn_output * sigmoid(W_go @ x + b_go)
```

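This provides a second level of control. G2 is computed from `x` at each key position and scales what that position's value contributes. G1 is computed from `x` at the query position and scales each channel of that query's attention output, so the model can learn to suppress the output for certain positions.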

In [None]:
# Output Gate (G1)
output_gate_proj = nn.Linear(D, D, bias=True)
nn.init.constant_(output_gate_proj.bias, 0.5)

output_gate = torch.sigmoid(output_gate_proj(x)).view(B, T, H, d_head)

print(f"Output gate shape: {output_gate.shape}")
print(f"Mean gate value: {output_gate.mean():.3f}")

# Apply gate
attn_before_gate = attn_output.clone()
attn_output = attn_output * output_gate

print(f"\nOutput norms before gate: {attn_before_gate[0, :, 0].norm(dim=-1).numpy().round(3)}")
print(f"Output norms after  gate: {attn_output[0, :, 0].norm(dim=-1).numpy().round(3)}")

print(f"\nDual gating summary:")
print(f"  G2 (ValueGate):  v = v * σ(W_gv·x + b)  — controls what info values carry")
print(f"  G1 (OutputGate): o = o * σ(W_go·x + b)   — controls what info to output")

## Step 8: Output Projection

Reshape heads back together and project to model dimension.

In [None]:
# Reshape and project
attn_output = attn_output.reshape(B, T, D)
o_proj = nn.Linear(D, D, bias=False)
final_output = o_proj(attn_output)

print(f"Final output shape: {final_output.shape} = [B, T, D]")
print("\nOutput shape matches the input, so GSA can replace a standard attention layer without changing the surrounding block.")

## Comparing Sparse vs Full Attention

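Let's check how closely sparse attention approximates full attention. Every projection here, including the indexer, is randomly initialized. We should therefore not expect the selected tokens to match full attention's highest-weight tokens yet.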
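One property can be checked exactly. If every key position is selected (k = T) and the causal mask invalidates slots with s > t, gather-based sparse attention must reproduce full causal attention up to floating-point rounding. The sketch below verifies this with hypothetical helper names and random float64 tensors. Everything runs inside functions, so the notebook's `q`, `k`, `v`, `indices` and `mask` stay untouched. Any gap in the notebook's comparison therefore comes from which tokens the indexer selects, not from the gather mechanics.

```python
import math
import torch

def sparse_attention_sketch(q, k, v, indices, mask):
    """Gather-based sparse attention as in Step 6. q, k, v: [B, T, H, d]; indices, mask: [B, T, kk]."""
    B, d = q.shape[0], q.shape[-1]
    batch_idx = torch.arange(B)[:, None, None]
    k_sel = k[batch_idx, indices].permute(0, 1, 3, 2, 4)   # [B, T, H, kk, d]
    v_sel = v[batch_idx, indices].permute(0, 1, 3, 2, 4)
    scores = torch.einsum('bqhd,bqhkd->bqhk', q, k_sel) / math.sqrt(d)
    scores = scores.masked_fill(~mask.unsqueeze(2), float('-inf'))  # mask broadcast over heads
    return torch.einsum('bqhk,bqhkd->bqhd', torch.softmax(scores, dim=-1), v_sel)

def full_causal_sketch(q, k, v):
    T, d = q.shape[1], q.shape[-1]
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(d)
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    weights = torch.softmax(scores.masked_fill(causal, float('-inf')), dim=-1)
    return torch.einsum('bhqk,bkhd->bqhd', weights, v)

def check_k_equals_T(B=1, T=6, H=2, d=8, seed=0):
    gen = torch.Generator().manual_seed(seed)
    q, k, v = (torch.randn(B, T, H, d, generator=gen, dtype=torch.float64) for _ in range(3))
    # Every query "selects" all T positions; the mask keeps only causal ones (s <= t)
    indices = torch.arange(T).expand(B, T, T)
    mask = indices <= torch.arange(T).view(1, T, 1)
    return sparse_attention_sketch(q, k, v, indices, mask), full_causal_sketch(q, k, v)

sparse_demo, full_demo = check_k_equals_T()
print(torch.allclose(sparse_demo, full_demo))
```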

In [None]:
# Full causal attention for comparison
def full_causal_attention(q, k, v, scale):
    """Standard full causal attention. q, k, v: [B, T, H, d_head]; returns (output, weights)."""
    T_q = q.shape[1]  # read the length from the input instead of relying on the global T
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k) * scale
    causal = torch.triu(torch.ones(T_q, T_q, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal.unsqueeze(0).unsqueeze(0), float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum('bhqk,bkhd->bqhd', weights, v), weights

full_output, full_weights = full_causal_attention(q, k, v, scale_attn)

print(f"Full attention output shape: {full_output.shape}")
print(f"Sparse attention output shape: {attn_before_gate.shape}")

# Compare
diff = (full_output - attn_before_gate).abs()
print(f"\nMean absolute difference: {diff.mean():.4f}")
print(f"Max absolute difference: {diff.max():.4f}")

# Show which tokens full attention focuses on vs sparse
print(f"\nFull attention weights (query=5, head=0):")
fw = full_weights[0, 0, 5].numpy().round(3)
for j in range(T):
    selected = '  ◄ SELECTED' if j in indices[0, 5].numpy() else ''
    print(f"  Token {j}: weight={fw[j]:.3f}{selected}")
print("\nA trained indexer should learn to select the tokens with the highest attention weight; this one is randomly initialized, so its selections are not yet aligned with them.")

## Adaptive Top-K: How It Works

The reference uses **variance-based** adaptive k:

- High variance in indexer scores → attention is very focused → fewer tokens suffice
- Low variance → attention is spread out → more tokens needed

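```
k_adaptive = clamp(floor(k_base × avg_variance / position_variance), k_min, k_max)
```

Here `position_variance` is the variance of one query's valid (causal) indexer scores, and `avg_variance` is the mean of these variances over all query positions.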
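To see the formula on concrete numbers, the sketch below applies it to two hand-made query rows over six keys: one peaked and one nearly flat. It uses `k_base=4`, `k_min=2` and `k_max=6`, as in the notebook's adaptive-k cell. `adaptive_k_sketch` is a hypothetical helper that repeats that cell's computation. The scores are illustrative, not indexer outputs.

```python
import torch

def adaptive_k_sketch(scores, k_base, k_min, k_max, eps=1e-6):
    """Variance-based adaptive k. scores: [B, T, T_kv] with -inf at masked positions.
    Returns (k per query [B, T], per-query variance [B, T])."""
    valid = scores != float('-inf')
    count = valid.sum(dim=-1).clamp(min=1).float()
    s = scores.masked_fill(~valid, 0.0)
    mean = s.sum(dim=-1) / count
    variance = ((s - mean.unsqueeze(-1)) * valid).pow(2).sum(dim=-1) / count
    avg_var = variance.mean().clamp(min=eps)          # average over all query positions
    k = (k_base * avg_var / variance.clamp(min=eps)).floor()
    return k.clamp(min=k_min, max=k_max).long(), variance

# Two hand-made query rows over 6 keys: one peaked, one nearly flat
demo_scores = torch.tensor([[[1.9, 0.1, 0.1, 0.1, 0.1, 0.1],
                             [0.6, 0.5, 0.5, 0.4, 0.5, 0.5]]])
k_demo, var_demo = adaptive_k_sketch(demo_scores, k_base=4, k_min=2, k_max=6)
print(var_demo)   # peaked row 0.45, flat row about 0.0033
print(k_demo)     # tensor([[2, 6]])
```

The peaked row's variance (0.45) is well above the average (about 0.227), so it gets floor(4 × 0.227 / 0.45) = 2 tokens. The flat row's tiny variance would ask for hundreds of tokens and is capped at `k_max`.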

In [None]:
# Demonstrate adaptive k computation
k_base = 4  # base number of tokens
k_min = 2
k_max = 6

# Use the indexer scores
valid_mask = final_scores != float('-inf')
valid_count = valid_mask.sum(dim=-1, keepdim=True).clamp(min=1).float()
scores_valid = final_scores.masked_fill(~valid_mask, 0.0)

# Per-position variance
mean = scores_valid.sum(dim=-1, keepdim=True) / valid_count
diff = (scores_valid - mean * valid_mask.float()).masked_fill(~valid_mask, 0.0)
variance = diff.pow(2).sum(dim=-1) / valid_count.squeeze(-1)

avg_var = variance.mean().clamp(min=1e-6)

# k_adaptive = k_base × avg_var / pos_var
k_adaptive = (k_base * avg_var / variance.clamp(min=1e-6)).floor().clamp(min=k_min, max=k_max).long()

print(f"Per-position variance of indexer scores:")
print(f"  {variance[0].numpy().round(4)}")
print(f"\nAverage variance: {avg_var.item():.4f}")
print(f"\nAdaptive k per query position:")
print(f"  {k_adaptive[0].numpy()}")
print(f"\nInterpretation:")
for t_pos in range(T):
    var = variance[0, t_pos].item()
    k_val = k_adaptive[0, t_pos].item()
    focus = 'focused' if var > avg_var.item() else 'spread'
    print(f"  t={t_pos}: var={var:.4f} ({focus}) → k={k_val}")

## Visualizing the Sparse Attention Pattern

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# 1. Full attention weights
fw_np = full_weights[0, 0].numpy()
im1 = axes[0].imshow(fw_np, cmap='Blues', aspect='auto')
axes[0].set_title('Full Causal Attention\n(Head 0)')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')
plt.colorbar(im1, ax=axes[0])

# 2. Indexer scores (before topk)
idx_scores = final_scores[0].clone()
idx_scores[idx_scores == float('-inf')] = 0
im2 = axes[1].imshow(idx_scores.numpy(), cmap='Oranges', aspect='auto')
axes[1].set_title('Indexer Scores\n(Causal Masked)')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')
plt.colorbar(im2, ax=axes[1])

# 3. Sparse selection pattern
sparse_pattern = torch.zeros(T, T)
for t_pos in range(T):
    for j in range(k_effective):
        if mask[0, t_pos, j]:
            sparse_pattern[t_pos, indices[0, t_pos, j]] = 1.0

im3 = axes[2].imshow(sparse_pattern.numpy(), cmap='Greens', aspect='auto')
axes[2].set_title(f'Sparse Selection Pattern\n(k={k_effective} tokens selected)')
axes[2].set_xlabel('Key Position')
axes[2].set_ylabel('Query Position')
plt.colorbar(im3, ax=axes[2])

plt.suptitle('GSA: Full Attention vs Sparse Selection', fontsize=14)
plt.tight_layout()
plt.show()

# Sparsity stats
total_entries = T * (T + 1) / 2  # causal triangle
selected_entries = sparse_pattern.sum().item()
print(f"\nSelected: {selected_entries:.0f}/{total_entries:.0f} causal entries = {selected_entries/total_entries:.1%} (sparsity {1 - selected_entries/total_entries:.1%})")

## Summary: Complete GSA Forward Pass

| Step | Operation | Shape Transform | Purpose |
|------|-----------|----------------|----------|
| 1 | q/k/v_proj | [B,T,D] → [B,T,H,d] | Standard attention projections |
| 2 | ValueGate (G2) | v = v * σ(W·x+b) | Control value information flow |
| 3 | RoPE | Same shape | Positional encoding on Q, K |
| 4 | Indexer scoring | [B,T,D] → [B,T,T] | Cheap O(T²) importance scores |
| 5 | Top-K selection | [B,T,T] → [B,T,k] indices | Select important tokens |
| 6 | Sparse attention | [B,T,H,d] × [B,T,k,H,d] → [B,T,H,d] | Attend only to selected tokens |
| 7 | OutputGate (G1) | o = o * σ(W·x+b) | Control output information flow |
| 8 | o_proj | [B,T,D] → [B,T,D] | Back to model dim |

### Complexity Comparison

| | Full Attention | GSA |
|---|---|---|
| **Indexer** | N/A | O(T² × d_indexer × n_idx_heads) |
| **Attention** | O(T² × d_head × H) | O(T × k × d_head × H) |
| **Total** | O(T² × D) | O(T² × d_idx × n_idx + T × k × D) |
| **Memory** | O(T² × H) | O(T²) indexer scores + O(T × k × H) |
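The asymptotic table hides constants, so plugging in sizes is a quick way to judge the trade-off. The sketch below counts FLOPs for the score and weighted-sum matmuls only, at 2 FLOPs per multiply-add. It treats full attention as dense and ignores projections, softmax and the top-k selection itself. `d_indexer=64`, `n_idx_heads=4`, `d_head=128` and `k=2048` follow values quoted in this notebook; `n_heads=32` is an assumption for illustration.

```python
def attention_flops(T, k, n_heads, d_head, d_indexer, n_idx_heads):
    """Rough FLOPs (2 per multiply-add) for score and weighted-sum matmuls only.

    Ignores projections, softmax, top-k selection and causal savings (full attention counted dense).
    """
    full = 2 * (2 * T * T * n_heads * d_head)      # QK^T + AV over all T x T pairs
    indexer = 2 * T * T * n_idx_heads * d_indexer   # one shared T x T score matrix
    sparse = 2 * (2 * T * k * n_heads * d_head)     # QK^T + AV over k selected keys per query
    return {"full": full, "indexer": indexer, "sparse": sparse, "gsa": indexer + sparse}

# n_heads=32 is an assumed value; the other sizes follow this notebook's comments
costs = attention_flops(T=8192, k=2048, n_heads=32, d_head=128, d_indexer=64, n_idx_heads=4)
for name, flops in costs.items():
    print(f"{name:>8}: {flops:.3e} FLOPs")
print(f"GSA / full: {costs['gsa'] / costs['full']:.1%}")
```

With these numbers, GSA's matmul FLOPs come to 9/32 (about 28%) of dense full attention, and the indexer accounts for one ninth of GSA's total. Counting causal full attention at half its dense cost would narrow the gap.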

### Why It Works
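1. **Indexer is cheap**: one T² score matrix of total width n_idx_heads × d_indexer (4 × 64 in the reference) is shared by all heads, versus H × d_head for full attention's QKᵀ, so the T² term shrinks by about (H × d_head) / (n_idx_heads × d_indexer)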
2. **Adaptive k** — attend to more tokens when attention is spread, fewer when focused
3. **Dual gating** — G2 controls what info enters, G1 controls what info exits
4. **Shared indices** — same tokens selected for all heads → efficient gather

### GDN vs GSA: Key Differences

| | GDN | GSA |
|---|---|---|
| **Approach** | Recurrent (state matrix) | Sparse attention (token selection) |
| **Complexity** | O(T × d_k × d_v) | O(T² × d_idx + T × k × D) |
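| **Memory** | O(d_k × d_v) constant | O(T²) indexer scores (as computed here) + O(T × k) |
| **Parallelism** | Recurrent per token (chunkwise-parallel kernels are used for training) | Parallelizable (gather + matmul) |
| **Best for** | Very long sequences | Moderate-length sequences where each query needs a few, possibly distant, tokens |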