# Gated DeltaNet (GDN) — Step-by-Step Explanation

**Reference:** [NVlabs/GatedDeltaNet](https://github.com/NVlabs/GatedDeltaNet)

This notebook walks through the entire Gated DeltaNet forward pass using a **tiny example** so you can inspect every tensor at every step.

## Architecture Overview

```
Input x [B, T, D]
  │
  ├──► q_proj ──► q_conv1d ──► reshape ──► RoPE ──► L2 norm ──► scale ──┐
  ├──► k_proj ──► k_conv1d ──► reshape ──► RoPE ──► L2 norm ──────────►│
  ├──► v_proj ──► v_conv1d ──► reshape ────────────────────────────────►│ Recurrent
  ├──► b_proj ──► sigmoid (write gate β) ──────────────────────────────►│ Delta Rule
  ├──► gk_proj ──► Mamba gate (-exp(A_log) * softplus(gk + dt_bias)) ─►│
  │                                                                     │
  │                                              output o ◄─────────────┘
  │                                                │
  ├──► g_proj ────────────────────────────────► FusedRMSNormSwishGate
  │                                                │
  │                                           o_proj ──► Output
```

### Key Idea: The Delta Rule

GDN maintains a **recurrent state matrix** `S` of shape `[d_k, d_v]` per head. At each time step:

1. **Decay**: `S = S * exp(g)` forgets old information (`g < 0`, so `exp(g) ∈ (0, 1)`)
2. **Delta correction**: `v_new = (v - k @ S) * β` computes what is *new* relative to what the memory already returns for `k`
3. **Update**: `S = S + k ⊗ v_new` writes the new information (rank-1 outer product)
4. **Read**: `o = q @ S` queries the memory

This is a **linear-time** recurrence (O(T)) compared to O(T²) for standard attention.
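
The four steps above can be sketched for a single head as follows. This is a minimal illustration rather than the reference kernel: `gdn_step` is a hypothetical helper name, and the recall is written as `k @ S` so that the shapes agree with `S` being `[d_k, d_v]`. With an empty state, a unit-norm key, `β = 1`, and no decay (`g = 0`), reading back with the same key should return the stored value.

```python
import torch

def gdn_step(S, q, k, v, beta, g):
    """One gated delta-rule step for a single head (illustrative helper).

    S: [d_k, d_v] state, q and k: [d_k], v: [d_v], beta and g: scalar tensors (g <= 0).
    """
    S = S * torch.exp(g)              # 1. decay
    v_new = (v - k @ S) * beta        # 2. delta correction
    S = S + torch.outer(k, v_new)     # 3. rank-1 write
    o = q @ S                         # 4. read
    return S, o

torch.manual_seed(0)
d_k, d_v = 4, 3
k = torch.nn.functional.normalize(torch.randn(d_k), dim=-1)
v = torch.randn(d_v)
S0 = torch.zeros(d_k, d_v)

# Write (k, v) with beta = 1 and g = 0 (no decay), then read with q = k.
S1, o = gdn_step(S0, q=k, k=k, v=v, beta=torch.tensor(1.0), g=torch.tensor(0.0))
print("read-back equals v:", torch.allclose(o, v, atol=1e-5))
```

Because `k` has unit norm, `k @ (k ⊗ v) = (k · k) v = v`, which is why the read-back recovers the value.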

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(42)
torch.set_grad_enabled(False)

# Tiny example dimensions
B = 1       # batch size
T = 6       # sequence length (small enough to inspect)
D = 32      # model dimension
H = 2       # number of heads
expand_k = 0.75  # key expansion factor (ref default)
expand_v = 1.5   # value expansion factor (ref default)

key_dim = int(D * expand_k)    # 24
value_dim = int(D * expand_v)  # 48
d_k = key_dim // H             # 12 (per-head key dim)
d_v = value_dim // H           # 24 (per-head value dim)

print(f"Model dim D={D}, Heads H={H}")
print(f"Key dim: {key_dim} (expand_k={expand_k}), per head: {d_k}")
print(f"Value dim: {value_dim} (expand_v={expand_v}), per head: {d_v}")
print(f"State matrix S per head: [{d_k} x {d_v}] = {d_k * d_v} entries")
print(f"\nNote: key_dim != value_dim. This asymmetry is a key design choice.")

## Step 1: Input and Projections

The input `x` is projected into **Q, K, V** with different output dimensions:
- Q, K → `key_dim` (D × expand_k = 24)
- V → `value_dim` (D × expand_v = 48)

Additionally, we project **two scalar gates per head**:
- `β` (write gate) via `b_proj` → sigmoid → controls how much new info to write
- `gk` (decay gate) via `gk_proj` → Mamba-style transform → controls forgetting rate

In [None]:
# Create a simple input
x = torch.randn(B, T, D)
print(f"Input x shape: {x.shape}")
print(f"Input x (first 2 tokens, first 8 dims):\n{x[0, :2, :8]}")

# Linear projections
q_proj = nn.Linear(D, key_dim, bias=False)
k_proj = nn.Linear(D, key_dim, bias=False)
v_proj = nn.Linear(D, value_dim, bias=False)
b_proj = nn.Linear(D, H, bias=True)     # write gate
gk_proj = nn.Linear(D, H, bias=False)   # decay gate (no bias when use_mamba_gate=True)
g_proj = nn.Linear(D, value_dim, bias=False)  # output gate projection

q = q_proj(x)   # [B, T, key_dim=24]
k = k_proj(x)   # [B, T, key_dim=24]
v = v_proj(x)   # [B, T, value_dim=48]

print(f"\nQ shape: {q.shape}  (key_dim={key_dim})")
print(f"K shape: {k.shape}  (key_dim={key_dim})")
print(f"V shape: {v.shape}  (value_dim={value_dim})")
print(f"\nNote: Q,K have dim {key_dim} but V has dim {value_dim} (asymmetric!)")

## Step 2: Short Convolutions

Before the recurrence, Q, K, V pass through **causal depthwise 1D convolutions** (kernel size 4) followed by a SiLU activation.

This allows local context mixing before the global recurrence. Think of it as a "local receptive field" preprocessing step.
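
To make the causality of this step concrete, here is a small self-contained check. It assumes the same padding trick used in this notebook (pad by `K - 1`, then keep the first `T` outputs); the sizes and the helper name `causal_short_conv` are illustrative. Perturbing only later tokens should leave earlier outputs untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C, T, K = 3, 8, 4          # channels, sequence length, kernel size (illustrative)
conv = nn.Conv1d(C, C, kernel_size=K, padding=K - 1, groups=C)  # depthwise

def causal_short_conv(x):
    """x: [B, T, C] -> [B, T, C]; keeps only the first T outputs of the padded conv."""
    y = conv(x.transpose(1, 2))[..., : x.shape[1]]
    return F.silu(y).transpose(1, 2)

with torch.no_grad():
    x = torch.randn(1, T, C)
    x_future = x.clone()
    x_future[:, 5:] += 10.0            # perturb tokens 5, 6, 7 only
    y, y_future = causal_short_conv(x), causal_short_conv(x_future)

past_unchanged = torch.allclose(y[:, :5], y_future[:, :5])
future_changed = not torch.allclose(y[:, 5:], y_future[:, 5:])
print("past unchanged:", past_unchanged, "| future changed:", future_changed)
```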

In [None]:
# Short Convolution (simplified — depthwise conv1d + SiLU)
conv_size = 4

q_conv = nn.Conv1d(key_dim, key_dim, kernel_size=conv_size, padding=conv_size-1, groups=key_dim)
k_conv = nn.Conv1d(key_dim, key_dim, kernel_size=conv_size, padding=conv_size-1, groups=key_dim)
v_conv = nn.Conv1d(value_dim, value_dim, kernel_size=conv_size, padding=conv_size-1, groups=value_dim)

def short_conv(conv, x_in, conv_size):
    """Apply a causal depthwise conv followed by SiLU."""
    out = conv(x_in.transpose(1, 2))        # [B, C, T + conv_size - 1] (padded on both sides)
    out = out[:, :, :-(conv_size - 1)]       # drop trailing outputs so step t sees only inputs <= t
    return F.silu(out).transpose(1, 2)       # [B, T, C]

q = short_conv(q_conv, q, conv_size)
k = short_conv(k_conv, k, conv_size)
v = short_conv(v_conv, v, conv_size)

print(f"After short conv + SiLU:")
print(f"  Q shape: {q.shape}")
print(f"  K shape: {k.shape}")
print(f"  V shape: {v.shape}")
print(f"\nConv is causal (only looks at current + past {conv_size-1} positions)")
print(f"SiLU activation: x * sigmoid(x) — smooth non-linearity")

## Step 3: Reshape to Multi-Head Format

Reshape Q, K, V into multi-head format: `[B, T, H, d_per_head]`

In [None]:
q = q.view(B, T, H, d_k)   # [1, 6, 2, 12]
k = k.view(B, T, H, d_k)   # [1, 6, 2, 12]
v = v.view(B, T, H, d_v)   # [1, 6, 2, 24]

print(f"Multi-head shapes:")
print(f"  Q: {q.shape}  — H={H} heads, d_k={d_k} per head")
print(f"  K: {k.shape}  — H={H} heads, d_k={d_k} per head")
print(f"  V: {v.shape}  — H={H} heads, d_v={d_v} per head")
print(f"\nQ,K heads have dimension {d_k} but V heads have dimension {d_v}")

## Step 4: Rotary Position Embeddings (RoPE)

RoPE encodes position information by **rotating** Q and K vectors in 2D subspaces.

Formula: `x_rotated = x * cos(θ) + rotate_half(x) * sin(θ)`

Here `rotate_half` splits the vector into two halves `[x1, x2]` and returns `[-x2, x1]`, so dimension `i` is paired with dimension `i + d_k/2`.
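
The practical consequence of this rotation is that the score between a rotated query and a rotated key depends only on their position offset. The sketch below checks that numerically for single vectors; `rope` is an illustrative helper that applies the formula above at one position, not a function from the reference code.

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rope(x, pos, base=10000.0):
    """Rotate a single vector x (even length d) to position pos."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
    ang = torch.cat((pos * inv_freq, pos * inv_freq), dim=-1)
    return x * ang.cos() + rotate_half(x) * ang.sin()

torch.manual_seed(0)
q, k = torch.randn(12), torch.randn(12)

# Same offset (i - j = 3) at two different absolute positions.
s1 = rope(q, 5.0) @ rope(k, 2.0)
s2 = rope(q, 9.0) @ rope(k, 6.0)
# A different offset (i - j = 1) for comparison.
s3 = rope(q, 5.0) @ rope(k, 4.0)

print("same offset, same score:", torch.allclose(s1, s2, atol=1e-4))
print("different offset, same score:", torch.allclose(s1, s3, atol=1e-4))
```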

In [None]:
# Build RoPE frequencies
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))
print(f"inv_freq (one per 2D rotation plane): {inv_freq}")
print(f"Number of rotation planes: {len(inv_freq)} = d_k/2 = {d_k}//2")

# Position-dependent angles
t = torch.arange(T, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # [T, d_k]
cos_cached = emb.cos().unsqueeze(0).unsqueeze(2)  # [1, T, 1, d_k]
sin_cached = emb.sin().unsqueeze(0).unsqueeze(2)

print(f"\nAngles for each position (first 3 planes):")
for pos in range(min(4, T)):
    angles = freqs[pos, :3]
    print(f"  pos={pos}: {angles.numpy()} radians")

# Apply RoPE
def rotate_half(x):
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

q_before_rope = q.clone()
q = q * cos_cached + rotate_half(q) * sin_cached
k = k * cos_cached + rotate_half(k) * sin_cached

# Position 0 has rotation angle 0, so RoPE leaves token 0 unchanged; show token 1 instead.
print(f"\nQ before RoPE (head 0, token 1, first 6 dims): {q_before_rope[0, 1, 0, :6].numpy().round(3)}")
print(f"Q after  RoPE (head 0, token 1, first 6 dims): {q[0, 1, 0, :6].numpy().round(3)}")
print(f"\nRoPE makes dot(q_i, k_j) depend on relative position (i-j)!")

## Step 5: L2 Normalization

Q and K are L2-normalized so that their dot products are bounded in [-1, 1].

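This stabilizes the recurrence: with a unit-norm key and `β ∈ (0, 1)`, each write shrinks the old state along `k` instead of amplifying it, whereas an unnormalized key can make the state matrix `S` grow without bound.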
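
One way to see why the normalization matters: ignoring the decay, the write step can be rewritten as `S ← (I − β k kᵀ) S + β k vᵀ`, so the old state is multiplied by `I − β k kᵀ`. The sketch below is an illustration, not part of the reference code; it compares the spectral norm of that matrix for a unit-norm key and for a deliberately large unnormalized key.

```python
import torch

torch.manual_seed(0)
d_k, beta = 12, 0.9

def transition_norm(k, beta):
    """Spectral norm of I - beta * k k^T, the factor applied to the old state."""
    M = torch.eye(k.shape[0]) - beta * torch.outer(k, k)
    return torch.linalg.matrix_norm(M, ord=2).item()

k_raw = 3.0 * torch.randn(d_k)                          # unnormalized key with a large norm
k_unit = torch.nn.functional.normalize(k_raw, dim=-1)   # same direction, unit norm

norm_unit = transition_norm(k_unit, beta)   # eigenvalues: 1 and 1 - beta
norm_raw = transition_norm(k_raw, beta)     # eigenvalues: 1 and 1 - beta * ||k||^2
print(f"unit-norm key: {norm_unit:.3f}, raw key: {norm_raw:.3f}")
```

With a unit-norm key the factor never exceeds 1 in norm, so a write cannot amplify what is already stored; with the raw key the eigenvalue `1 − β‖k‖²` is large in magnitude and repeated writes can blow up the state.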

In [None]:
q_before_norm = q.clone()
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)

print(f"Before L2 norm, Q norms for head 0 (one per token):")
print(f"  {q_before_norm[0, :, 0, :].norm(dim=-1).numpy().round(3)}")

print(f"\nAfter L2 norm — all norms = 1.0:")
print(f"  {q[0, :, 0, :].norm(dim=-1).numpy().round(3)}")

print(f"\nThis bounds dot(q, k) ∈ [-1, 1] — critical for stable recurrence.")

## Step 6: Gate Computation

Two gates control the recurrence:

### Write Gate β (beta)
- `β = sigmoid(b_proj(x))` — scalar per head ∈ (0, 1)
- Controls how strongly new information is written to the state

### Decay Gate g (Mamba-style)
- `g = -exp(A_log) * softplus(gk_proj(x) + dt_bias)` is always **negative** (log-space decay)
- `exp(g)` ∈ (0, 1) controls how much old state is retained
- `exp(A_log)` is positive by construction, so the leading minus sign makes `g` negative (forgetting)
- `softplus(...)` keeps the rate magnitude positive
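
The `dt_bias` initialization used with this gate is an inverse softplus: it is chosen so that `softplus(dt_bias)` equals a sampled step size `dt`. The following sketch, with illustrative values for `A_log` and a zero stand-in for `gk_proj(x)`, checks that identity and the sign of `g`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H = 4                                    # number of heads (illustrative)
dt_min, dt_max = 0.001, 0.1

# Sample dt log-uniformly in [dt_min, dt_max], then invert softplus.
dt = torch.exp(torch.rand(H) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
dt_bias = dt + torch.log(-torch.expm1(-dt))      # equals log(exp(dt) - 1)

A_log = torch.log(torch.tensor([1.0, 4.0, 8.0, 16.0]))   # example values only
gk_raw = torch.zeros(H)                          # stand-in for gk_proj(x)

g = -A_log.exp() * F.softplus(gk_raw + dt_bias)  # log-space decay
retention = g.exp()
print("softplus(dt_bias) == dt:", torch.allclose(F.softplus(dt_bias), dt, atol=1e-6))
print("g:", g)
print("retention exp(g):", retention)
```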

In [None]:
# Write gate β
beta_raw = b_proj(x)  # [B, T, H]
beta = beta_raw.float().sigmoid().transpose(1, 2)  # [B, H, T]

print("Write gate β (sigmoid output per head, per token):")
print(f"  Shape: {beta.shape} = [B, H, T]")
for h in range(H):
    print(f"  Head {h}: {beta[0, h].numpy().round(3)}")

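# Decay gate g (Mamba-style)
# A_log is set by hand for this H=2 example; dt_bias is the inverse softplus of a
# log-uniform sample dt in [dt_min, dt_max], so softplus(dt_bias) == dt.
A_log = nn.Parameter(torch.log(torch.tensor([4.0, 8.0])))  # one value per head (H = 2)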
dt_min, dt_max = 0.001, 0.1
dt = torch.exp(torch.rand(H) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))

gk_raw = gk_proj(x).float()  # [B, T, H]
A = A_log.float().exp()
gk = -A * F.softplus(gk_raw + dt_bias)  # Always negative!
gk = gk.transpose(1, 2)  # [B, H, T]

print(f"\nDecay gate g (log-space, always negative):")
for h in range(H):
    print(f"  Head {h}: {gk[0, h].detach().numpy().round(4)}")
print(f"\nexp(g) = retention ratio:")
for h in range(H):
    print(f"  Head {h}: {gk[0, h].exp().detach().numpy().round(4)}")
print(f"\nValues near 1.0 = keep memory; near 0.0 = forget everything")

## Step 7: Prepare for Recurrence

Transpose to `[B, H, T, d]` format and apply query scaling by `d_k^(-0.5)`.

In [None]:
# Transpose: [B, T, H, d] -> [B, H, T, d]
q = q.transpose(1, 2)  # [B, H, T, d_k]
k = k.transpose(1, 2)  # [B, H, T, d_k]
v = v.transpose(1, 2)  # [B, H, T, d_v]

# Scale queries
scale = d_k ** -0.5
q = q * scale

print(f"Shapes for recurrence:")
print(f"  Q: {q.shape} (scaled by {scale:.4f} = 1/√{d_k})")
print(f"  K: {k.shape}")
print(f"  V: {v.shape}")
print(f"  β: {beta.shape}")
print(f"  g: {gk.shape}")

## Step 8: The Delta Rule Recurrence (Core Algorithm)

This is the **heart of GDN**. We process tokens one by one, maintaining a state matrix `S`.

For each time step `i`:

```
1. DECAY:   S = S * exp(g_i)                    # Forget old info
2. RECALL:  recalled = k_i @ S                  # What does memory say about k_i?
3. DELTA:   v_new = (v_i - recalled) * β_i      # What's genuinely new?
4. WRITE:   S = S + k_i ⊗ v_new                 # Store the new info (rank-1 update)
5. READ:    o_i = q_i @ S                        # Query the memory
```

The **delta correction** (steps 2 and 3) is what makes this different from a simple linear RNN:
- It checks what the memory *already knows* about key `k_i`
- Only writes the *difference* (delta) between the new value and the recalled value
- Re-writing a pair the memory already stores therefore changes little, and writing a new value for an existing key replaces the old value instead of adding to it
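
The difference from a purely additive write can be seen by storing two values under the same key. This is a small sketch with `β = 1` and no decay, using illustrative names; the additive variant stands in for a linear-attention style update without the correction.

```python
import torch

torch.manual_seed(0)
d_k, d_v = 4, 3
k = torch.nn.functional.normalize(torch.randn(d_k), dim=-1)
v1, v2 = torch.randn(d_v), torch.randn(d_v)

# Additive write: S += k ⊗ v, with no correction.
S_add = torch.zeros(d_k, d_v)
for v in (v1, v2):
    S_add = S_add + torch.outer(k, v)

# Delta-rule write with beta = 1 and no decay: S += k ⊗ (v - k @ S).
S_delta = torch.zeros(d_k, d_v)
for v in (v1, v2):
    S_delta = S_delta + torch.outer(k, v - k @ S_delta)

print("additive recall == v1 + v2:", torch.allclose(k @ S_add, v1 + v2, atol=1e-5))
print("delta recall    == v2     :", torch.allclose(k @ S_delta, v2, atol=1e-5))
```

The additive memory returns the sum of everything written under `k`, while the delta rule returns the most recent value.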

In [None]:
# Delta Rule Recurrence — step by step for ONE head
head = 0  # Let's trace head 0

# Extract single head tensors
q_h = q[0, head].float()    # [T, d_k]
k_h = k[0, head].float()    # [T, d_k]
v_h = v[0, head].float()    # [T, d_v]
beta_h = beta[0, head].float()  # [T]
g_h = gk[0, head].float()      # [T]

# Initialize state matrix
S = torch.zeros(d_k, d_v)  # The memory!
o_h = torch.zeros(T, d_v)  # Output accumulator

print(f"State matrix S shape: [{d_k}, {d_v}] = {d_k*d_v} entries")
print(f"Processing {T} tokens sequentially...\n")
print("=" * 80)

for i in range(T):
    _k = k_h[i]        # [d_k] — current key
    _q = q_h[i]        # [d_k] — current query  
    _v = v_h[i].clone() # [d_v] — current value
    _beta = beta_h[i]   # scalar — write gate
    _g = g_h[i]         # scalar — decay gate (log-space)
    
    print(f"\n--- Token {i} ---")
    print(f"  β (write gate) = {_beta.item():.4f}")
    print(f"  g (decay, log)  = {_g.item():.4f}, exp(g) = {_g.exp().item():.4f}")
    
    # 1. DECAY: Forget old information
    S_before_decay = S.clone()
    decay = _g.exp()  # scalar in (0, 1)
    S = S * decay
    print(f"  1. DECAY: S *= {decay.item():.4f}")
    print(f"     S Frobenius norm: {S_before_decay.norm():.4f} → {S.norm():.4f}")
    
    # 2. RECALL: What does memory already know about this key?
    recalled = (S * _k[..., None]).sum(dim=-2)  # [d_v]
    print(f"  2. RECALL: || recalled || = {recalled.norm():.4f}")
    
    # 3. DELTA: Only write what's genuinely new
    delta = _v - recalled
    v_new = delta * _beta
    print(f"  3. DELTA: || v_original || = {_v.norm():.4f}")
    print(f"            || v - recalled || = {delta.norm():.4f}")
    print(f"            || v_new (delta * β) || = {v_new.norm():.4f}")
    
    # 4. WRITE: Rank-1 update to state
    S_before_write = S.clone()
    S = S + _k.unsqueeze(-1) * v_new.unsqueeze(-2)  # outer product
    print(f"  4. WRITE: S += k ⊗ v_new (rank-1 update)")
    print(f"     S Frobenius norm: {S_before_write.norm():.4f} → {S.norm():.4f}")
    
    # 5. READ: Query the memory
    o_h[i] = torch.einsum('d,dm->m', _q, S)  # [d_v]
    print(f"  5. READ: || output || = {o_h[i].norm():.4f}")

print("\n" + "=" * 80)
print(f"\nFinal state S norm: {S.norm():.4f}")
print(f"Output shape: {o_h.shape}")

## Step 9: Full Recurrence (All Heads)

Now let's run the complete `recurrent_gated_delta_rule_ref` on all heads simultaneously.
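
Folding the recall and write steps together gives a single affine update of the state, `S_t = (I − β_t k_t k_tᵀ)(exp(g_t) S_{t−1}) + β_t k_t v_tᵀ`. The following check, written for one head with random inputs and illustrative variable names, suggests that this matrix form reproduces the step-by-step loop.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_k, d_v = 5, 4, 3
k = F.normalize(torch.randn(T, d_k), dim=-1)
v = torch.randn(T, d_v)
beta = torch.rand(T)                 # write gate in [0, 1)
g = -torch.rand(T)                   # log-space decay, non-positive

S_loop = torch.zeros(d_k, d_v)       # decay -> recall -> delta -> write
S_mat = torch.zeros(d_k, d_v)        # single affine update per step
for t in range(T):
    S_loop = S_loop * g[t].exp()
    v_new = (v[t] - k[t] @ S_loop) * beta[t]
    S_loop = S_loop + torch.outer(k[t], v_new)

    A_t = torch.eye(d_k) - beta[t] * torch.outer(k[t], k[t])
    S_mat = A_t @ (g[t].exp() * S_mat) + beta[t] * torch.outer(k[t], v[t])

print("matrix form matches loop:", torch.allclose(S_loop, S_mat, atol=1e-5))
```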

In [None]:
def recurrent_gated_delta_rule_ref(q, k, v, beta, g):
    """Reference recurrence from NVlabs/GatedDeltaNet"""
    q, k, v, beta, g = map(lambda x: x.to(torch.float32), [q, k, v, beta, g])
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    
    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i].clone()
        S = S.clone() * g[:, :, i].exp()[..., None, None]
        beta_i = beta[:, :, i]
        _v = _v - (S.clone() * _k[..., None]).sum(-2)
        _v = _v * beta_i[..., None]
        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    
    return o

# Run on all heads
output = recurrent_gated_delta_rule_ref(q, k, v, beta, gk.detach())
print(f"Full recurrence output shape: {output.shape} = [B, H, T, d_v]")

# Verify head 0 matches our step-by-step
diff = (output[0, head] - o_h).abs().max()
print(f"\nMax difference between step-by-step and batched (head {head}): {diff:.6e}")
print(f"Match: {'YES ✓' if diff < 1e-5 else 'NO ✗'}")

## Step 10: Output Normalization with Swish Gate

The output goes through `FusedRMSNormSwishGate`, which normalizes the recurrence output and multiplies it by a Swish-activated gate:

```
output = RMSNorm(recurrence_output) * SiLU(gate)
```

Where:
- `RMSNorm(x) = x / sqrt(mean(x²) + ε) * weight`
- `SiLU(x) = x * sigmoid(x)` (also called Swish), applied to the gate
- `gate = g_proj(x)` (learned multiplicative gate)

In [None]:
# Transpose back: [B, H, T, d_v] -> [B, T, H, d_v]
o = output.transpose(1, 2)

# Compute the output gate from the original input
gate = g_proj(x).view(B, T, H, d_v)

print(f"Recurrence output o: {o.shape}")
print(f"Gate: {gate.shape}")

# RMSNorm parameters
norm_weight = nn.Parameter(torch.ones(d_v))
eps = 1e-6

# Process per-token, per-head
o_flat = o.reshape(B * T * H, d_v)
gate_flat = gate.reshape(B * T * H, d_v)

# RMSNorm over the per-head value dimension
mean_sq = o_flat.pow(2).mean(-1, keepdim=True)
o_normed = o_flat * torch.rsqrt(mean_sq + eps) * norm_weight

# Swish gate: RMSNorm(o) * SiLU(gate)
o_gated = o_normed * F.silu(gate_flat)

print(f"\nMean square of o before RMSNorm: {mean_sq.mean():.4f}")
print(f"After RMSNorm and Swish gate: shape = {o_gated.shape}")

# Reshape back
o_final = o_gated.view(B, T, value_dim)
print(f"\nReshaped output: {o_final.shape} = [B, T, value_dim={value_dim}]")
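
As a quick sanity check on the normalization step alone, the sketch below applies the same RMSNorm formula to random rows and confirms that, with a unit `weight`, each row ends up with a mean square close to 1 regardless of its input scale. The helper name `rms_norm` is illustrative and not taken from the reference code.

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last dimension: x / sqrt(mean(x^2) + eps) * weight."""
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
rows, d_v = 12, 24                       # e.g. B*T*H rows of per-head outputs
o = 5.0 * torch.randn(rows, d_v)         # arbitrary input scale on purpose
o_normed = rms_norm(o, torch.ones(d_v))

mean_sq = o_normed.pow(2).mean(-1)       # one value per row
print(mean_sq)
```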

## Step 11: Output Projection

Finally, project back from `value_dim` to the model dimension `D` (`hidden_size`).

In [None]:
o_proj = nn.Linear(value_dim, D, bias=False)
final_output = o_proj(o_final)

print(f"Output projection: [{value_dim}] → [{D}]")
print(f"Final output shape: {final_output.shape} = [B, T, D]")
print(f"\nThis matches the input shape, so a GDN layer can stand in wherever an attention layer would go.")

## Summary: Complete GDN Forward Pass

| Step | Operation | Shape Transform | Purpose |
|------|-----------|----------------|----------|
| 1 | q/k/v_proj | [B,T,D] → [B,T,key_dim/value_dim] | Project to QKV |
| 2 | Short Conv + SiLU | Same shape | Local context mixing |
| 3 | Reshape | → [B,T,H,d_k/d_v] | Multi-head format |
| 4 | RoPE | Same shape | Positional encoding |
| 5 | L2 Norm | Same shape | Stabilize recurrence |
| 6 | Gates (β, g) | [B,T,D] → [B,H,T] | Control write/decay |
| 7 | Transpose + scale Q | → [B,H,T,d] | Head-first layout; scale queries by d_k^(-0.5) |
| 8-9 | Delta Rule | [B,H,T,d_k]×State → [B,H,T,d_v] | **Core recurrence** (one head, then all heads) |
| 10 | RMSNorm+SwishGate | → [B,T,value_dim] | Output normalization |
| 11 | o_proj | → [B,T,D] | Back to model dim |

### Complexity
- **Time**: O(T × d_k × d_v) per head — **linear in sequence length**
- **Space**: O(d_k × d_v) per head for the state matrix — **constant in sequence length**
- **vs Standard Attention**: O(T² × d) time, and O(T²) space when the attention matrix is materialized
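
To put rough numbers on the space comparison, the sketch below counts stored values per head for the sizes used in this notebook. It is simple arithmetic under the assumption that attention materializes its full score matrix; the function names are illustrative.

```python
# Per-head memory, counted in stored values (illustrative sizes from this notebook).
d_k, d_v = 12, 24

def gdn_state_entries(T):
    """The recurrent state is a single [d_k, d_v] matrix, whatever T is."""
    return d_k * d_v

def attention_score_entries(T):
    """A fully materialized attention matrix holds T x T scores."""
    return T * T

for T in (6, 1024, 65536):
    print(f"T={T:>6}: GDN state = {gdn_state_entries(T):>4}, "
          f"attention scores = {attention_score_entries(T):,}")
```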

### Why It Works
1. **Delta correction** prevents redundant writes → better memory utilization
2. **Gated decay** allows selective forgetting → handles long-range dependencies
3. **L2 normalization** bounds state growth → stable training
4. **Asymmetric key/value dims** (expand_k=0.75, expand_v=1.5) → compact keys, rich values

In [None]:
# Visualize the state matrix evolution
import matplotlib.pyplot as plt

# Re-run recurrence tracking state norms
q_h = q[0, 0].float()
k_h = k[0, 0].float()
v_h = v[0, 0].float()
beta_h = beta[0, 0].float()
g_h = gk[0, 0].detach().float()

S = torch.zeros(d_k, d_v)
state_norms = []
decay_factors = []
delta_norms = []

for i in range(T):
    _k, _q, _v = k_h[i], q_h[i], v_h[i].clone()
    _beta, _g = beta_h[i], g_h[i]
    
    decay = _g.exp()
    S = S * decay
    recalled = (S * _k[..., None]).sum(dim=-2)
    delta = _v - recalled
    v_new = delta * _beta
    S = S + _k.unsqueeze(-1) * v_new.unsqueeze(-2)
    
    state_norms.append(S.norm().item())
    decay_factors.append(decay.item())
    delta_norms.append(delta.norm().item())

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].bar(range(T), state_norms, color='steelblue')
axes[0].set_xlabel('Token Position')
axes[0].set_ylabel('Frobenius Norm')
axes[0].set_title('State Matrix ||S|| Over Time')

axes[1].bar(range(T), decay_factors, color='coral')
axes[1].set_xlabel('Token Position')
axes[1].set_ylabel('exp(g)')
axes[1].set_title('Decay Factor (Retention)')
axes[1].set_ylim(0, 1.1)

axes[2].bar(range(T), delta_norms, color='mediumseagreen')
axes[2].set_xlabel('Token Position')
axes[2].set_ylabel('||v - recalled||')
axes[2].set_title('Delta Norm (New Information)')

plt.suptitle('GDN State Evolution — Head 0', fontsize=14)
plt.tight_layout()
plt.show()