# Benchmark: Gated DeltaNet vs Gated Sparse Attention

Comparison of throughput and memory usage.

**Reference Implementations:**

- GDN: [NVlabs/GatedDeltaNet](https://github.com/NVlabs/GatedDeltaNet)
- GSA: [alfredcs/Gated-Sparse-Attention](https://github.com/alfredcs/Gated-Sparse-Attention)

In [None]:

# Imports & Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import matplotlib.pyplot as plt
import numpy as np
import logging
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass

# Logging: INFO-level root configuration plus a named logger for this notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Benchmark")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Record whether Triton is importable; when it is not, the PyTorch fallback
# kernels defined in this notebook are the only implementations available.
try:
    import triton
except ImportError:
    HAS_TRITON = False
    print("Triton not found. Using PyTorch fallbacks.")
else:
    HAS_TRITON = True

# Inference-only benchmark: switch autograd off globally.
torch.set_grad_enabled(False)


In [None]:
# ============================================================================
# Kernel Fallbacks
# Aligned with:
#   GSA: https://github.com/alfredcs/Gated-Sparse-Attention/blob/main/gsa/kernels/
#   GDN: https://github.com/NVlabs/GatedDeltaNet/blob/main/lit_gpt/gated_delta_rule_ops/
# ============================================================================

def pytorch_gated_indexer(
    q: torch.Tensor, k: torch.Tensor, w: torch.Tensor, b: torch.Tensor,
    scale: float = 1.0, causal: bool = True, q_offset: int = 0
) -> torch.Tensor:
    """
    PyTorch fallback for GatedLightningIndexer scoring.
    Matches: gsa/kernels/triton_indexer.py -> triton_gated_indexer()

    Formula per ref: score[q,k] = sum_h( sigmoid(w[q,h]) * sigmoid(dot(q_I[q,h], k_I[k]) * scale + b[h]) )

    Args:
        q: [batch, seq_q, n_heads, d_idx] - indexer queries
        k: [batch, seq_kv, d_idx] - indexer keys (shared across heads)
        w: [batch, seq_q, n_heads] - query-dependent importance weights
        b: [n_heads] - learnable bias per indexer head
        scale: multiplier applied to the raw dot products (typically 1/sqrt(d_idx))
        causal: when True, keys located after the query position are masked
        q_offset: absolute position of the first query row; query row i is
            treated as position ``q_offset + i`` when building the causal mask
    Returns:
        scores: [batch, seq_q, seq_kv] in float32; masked entries are -inf
    """
    seq_q, seq_kv = q.shape[1], k.shape[1]

    # Score in float32 regardless of the input precision.
    q = q.float()
    k = k.float()

    # QK dot product per indexer head: [batch, n_heads, seq_q, seq_kv]
    raw_scores = torch.einsum('bqhd,bkd->bhqk', q, k) * scale

    # Sigmoid gating with the per-head bias broadcast over batch and positions.
    gated_scores = torch.sigmoid(raw_scores + b.float().view(1, -1, 1, 1))

    # Query-dependent importance: [batch, n_heads, seq_q, 1]
    w_sigmoid = torch.sigmoid(w.float()).permute(0, 2, 1).unsqueeze(-1)

    # Weighted sum across indexer heads: [batch, seq_q, seq_kv]
    final_scores = (gated_scores * w_sigmoid).sum(dim=1)

    if causal:
        # A key is hidden when its position is strictly greater than the
        # (offset) query position. One comparison covers every offset, so the
        # mask is built exactly once.
        query_positions = q_offset + torch.arange(seq_q, device=q.device)
        key_positions = torch.arange(seq_kv, device=q.device)
        causal_mask = key_positions.unsqueeze(0) > query_positions.unsqueeze(1)
        final_scores = final_scores.masked_fill(causal_mask.unsqueeze(0), float('-inf'))

    return final_scores

def pytorch_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Root-mean-square normalization over the last dimension.

    When ``residual`` is given it is added to ``x`` first. The (summed) input
    is divided by the RMS of its last dimension (``eps`` is added to the mean
    square before the reciprocal square root) and then multiplied by ``weight``.
    """
    hidden = x if residual is None else x + residual
    mean_square = hidden.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return hidden * inv_rms * weight

def pytorch_sparse_attention(q, k, v, indices, mask, scale):
    """
    Sparse attention over a per-query selection of key/value positions
    (PyTorch fallback, mirroring gsa/kernels/triton_sparse_attn.py ->
    _pytorch_sparse_attention()).

    K and V are gathered along the sequence axis through a broadcast 5-D view,
    the same way the reference does it.

    Args:
        q: [B, T, H, D] query tensor
        k: [B, T_kv, H, D] key tensor
        v: [B, T_kv, H, D] value tensor
        indices: [B, T, k_selected] selected key positions (shared across heads);
            values outside [0, T_kv - 1] are clamped into range
        mask: [B, T, k_selected] boolean mask (True=valid)
        scale: attention scale factor
    Returns:
        output: [B, T, H, D]; a query whose selection is fully masked yields zeros
    """
    B, T, H, D = q.shape
    T_kv = k.shape[1]
    n_sel = indices.shape[-1]

    # [B, T, n_sel] -> [B, T, n_sel, H, D]: one gather index per output element.
    gather_index = (
        indices.clamp(0, T_kv - 1).long()
        .unsqueeze(-1).unsqueeze(-1)
        .expand(B, T, n_sel, H, D)
    )

    def gather_along_seq(x):
        # [B, T_kv, H, D] -> broadcast view [B, T, T_kv, H, D], pick the selected
        # positions, then move heads in front of the selection axis:
        # [B, T, H, n_sel, D].
        per_query = x.unsqueeze(1).expand(B, T, T_kv, H, D)
        return torch.gather(per_query, 2, gather_index).permute(0, 1, 3, 2, 4)

    k_sel = gather_along_seq(k)
    v_sel = gather_along_seq(v)

    # Invalid slots, broadcast over heads: [B, T, 1, n_sel]
    invalid = ~mask.unsqueeze(2)

    # Scaled dot-product logits: [B, T, H, n_sel]
    logits = torch.einsum('bqhd,bqhkd->bqhk', q, k_sel) * scale
    logits = logits.masked_fill(invalid, float('-inf'))

    # Softmax over the selection axis; a fully masked row produces NaN, which
    # is zeroed here so it contributes nothing to the output.
    probs = torch.softmax(logits, dim=-1)
    probs = probs.masked_fill(invalid, 0.0).nan_to_num(0.0)

    # Weighted sum of the selected values: [B, T, H, D]
    return torch.einsum('bqhk,bqhkd->bqhd', probs, v_sel)

def recurrent_gated_delta_rule_ref(q, k, v, beta, g):
    """
    Reference recurrent delta rule from NVlabs/GatedDeltaNet.
    Source: lit_gpt/gated_delta_rule_ops/chunk.py -> recurrent_gated_delta_rule_ref()

    This is the pure recurrence (no chunking) used for correctness verification.
    All inputs are promoted to float32 and the result is returned in float32.

    Per timestep t, with state S of shape [d_k, d_v]:
        S   <- S * exp(g_t)                    (decay)
        u_t <- (v_t - S^T k_t) * beta_t        (delta-rule correction)
        S   <- S + k_t u_t^T                   (rank-1 update)
        o_t <- S^T q_t                         (readout)

    Args:
        q: [B, H, T, d_k] queries (pre-scaled by d_k^-0.5)
        k: [B, H, T, d_k] keys
        v: [B, H, T, d_v] values
        beta: [B, H, T] write gate (sigmoid output)
        g: [B, H, T] decay gate (log-space, will be exp'd)
    Returns:
        o: [B, H, T, d_v] output (float32)
    """
    q, k, v, beta, g = (t.to(torch.float32) for t in (q, k, v, beta, g))
    batch, heads, seq_len, d_k = q.shape
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(batch, heads, d_k, d_v, dtype=v.dtype, device=v.device)

    # exp() is elementwise, so taking it once up front gives the same values
    # as exponentiating each step's slice inside the loop.
    decay = g.exp()

    # Every update below is out-of-place, so no defensive copies of S or the
    # value slice are needed.
    for i in range(seq_len):
        k_i = k[:, :, i]
        # Decay state
        S = S * decay[:, :, i, None, None]
        # Delta rule: subtract current memory readout from value, then gate
        readout = (S * k_i[..., None]).sum(-2)
        delta = (v[:, :, i] - readout) * beta[:, :, i, None]
        # Update state: rank-1 outer product
        S = S + k_i.unsqueeze(-1) * delta.unsqueeze(-2)
        # Output: query the state
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', q[:, :, i], S)

    return o


In [None]:
# ============================================================================
# Core Components
# Aligned with:
#   GDN: lit_gpt/rmsnorm.py, lit_gpt/rotary.py, lit_gpt/gated_delta_net.py
#   GSA: gsa/attention/rope.py
# ============================================================================

class RMSNorm(nn.Module):
    """RMS normalization layer with a learnable per-channel scale.

    Thin module wrapper around ``pytorch_rmsnorm`` (aligned with
    lit_gpt/rmsnorm.py); the scale is initialized to ones.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the last dimension, then apply the learned scale.
        return pytorch_rmsnorm(x, weight=self.weight, eps=self.eps)

class FusedRMSNormSwishGate(nn.Module):
    """RMSNorm followed by a swish (SiLU) output gate.

    Unfused PyTorch stand-in for the ``g_norm_swish_gate`` module used in
    lit_gpt/gated_delta_net.py. The swish is applied to the *gate* and the
    result multiplies the normalized input:

        out = RMSNorm(x) * g * sigmoid(g)
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.norm = RMSNorm(dim, eps)

    def forward(self, x, g):
        """Normalize ``x`` and modulate it with ``silu(g)``.

        Args:
            x: tensor whose last dimension is normalized.
            g: gate tensor broadcastable against ``x``.
        """
        x_norm = self.norm(x)
        # Swish goes on the gate, not on the normalized activations.
        return x_norm * F.silu(g)

def _rotate_half(x):
    """Matches gsa/attention/rope.py -> _rotate_half()"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Apply rotary position embeddings to a query/key pair.

    Mirrors gsa/attention/rope.py -> apply_rotary_pos_emb(). Each tensor is
    transformed as ``x * cos + rotate_half(x) * sin``; ``cos`` and ``sin``
    must broadcast against ``q`` and ``k``.

    Returns:
        (q_embed, k_embed), in that order.
    """
    def rotate(x):
        return (x * cos) + (_rotate_half(x) * sin)

    return rotate(q), rotate(k)

class RotaryEmbedding(nn.Module):
    """
    Standard rotary position embedding (RoPE) cos/sin tables.
    Aligned with gsa/attention/rope.py -> RotaryEmbedding.

    The tables are precomputed for ``max_position_embeddings`` positions and
    regrown on demand when a longer sequence is requested. No frequency
    scaling (e.g. NTK) is implemented here.

    Args:
        dim: rotary dimension (size of the last axis of the returned tables).
        max_position_embeddings: number of positions to precompute.
        base: base of the geometric frequency progression.
    """
    def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Pre-compute cache
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len, device=None):
        """(Re)build the cos/sin tables for ``seq_len`` positions.

        Args:
            seq_len: number of positions to cover.
            device: where to build the tables; defaults to the device that
                currently holds ``inv_freq``.
        """
        if device is None:
            device = self.inv_freq.device
        # Recompute the frequencies in float32 on the target device instead of
        # reusing the buffer: after ``module.to(bfloat16)`` or ``module.to(cuda)``
        # the buffer may be low precision or live on another device than the
        # positions, which would degrade the table or raise a device mismatch.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        inv_freq = 1.0 / (self.base ** exponents)
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        # Duplicate the frequencies so the table spans the full rotary dim,
        # matching the half-split layout used by _rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len, device, dtype=None):
        """Return ``(cos, sin)`` tables of shape [seq_len, dim].

        Args:
            seq_len: number of leading positions to return.
            device: device the returned tables should live on.
            dtype: optional dtype to cast the returned tables to.
        """
        if seq_len > self.cos_cached.shape[0]:
            self._set_cos_sin_cache(seq_len, device=device)
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]
        if device is not None:
            # No-op when the cache already lives on the requested device.
            cos = cos.to(device)
            sin = sin.to(device)
        if dtype is not None:
            cos = cos.to(dtype)
            sin = sin.to(dtype)
        return cos, sin

def l2_norm_fn(x):
    """Scale ``x`` to unit Euclidean length along its last dimension.

    Counterpart of l2_norm_fn in lit_gpt/gated_delta_net.py; delegates to
    ``F.normalize`` with its default epsilon.
    """
    unit = F.normalize(x, p=2, dim=-1)
    return unit

class ShortConvolution(nn.Module):
    """
    Causal depthwise 1D convolution with optional SiLU activation.
    Aligned with lit_gpt/gated_delta_net.py -> ShortConvolution.

    Args:
        dim: number of channels (each channel has its own filter).
        conv_size: kernel width; 1 gives a per-channel pointwise scale + bias.
        activation: ``'silu'`` applies SiLU to the output; any other value
            leaves the output unactivated.
    """
    def __init__(self, dim, conv_size=4, activation='silu'):
        super().__init__()
        self.conv_size = conv_size
        # padding=conv_size-1 pads both ends; the surplus on the right is
        # trimmed in forward() so that position t only sees inputs <= t.
        self.conv = nn.Conv1d(dim, dim, kernel_size=conv_size, padding=conv_size - 1, groups=dim)
        self.activation = nn.SiLU() if activation == 'silu' else nn.Identity()

    def forward(self, x, mask=None, cache=None):
        """Convolve ``x`` of shape [B, T, dim] along T; output is [B, T, dim].

        ``mask`` and ``cache`` are accepted for interface compatibility and
        are not used by this fallback.
        """
        seq_len = x.shape[1]
        x = self.conv(x.transpose(1, 2))
        # Keep the first T outputs. Slicing by length (rather than by
        # ``:-(conv_size - 1)``) stays correct for conv_size == 1, where a
        # ``:-0`` slice would return an empty tensor.
        x = x[:, :, :seq_len]
        return self.activation(x.transpose(1, 2))


In [None]:
# ============================================================================
# Gated DeltaNet
# Aligned with: https://github.com/NVlabs/GatedDeltaNet/blob/main/lit_gpt/gated_delta_net.py
#
# Key design choices from reference:
#   - expand_k=0.75, expand_v=1.5 (key_dim != value_dim)
#   - Mamba-style gating: gk = -A.exp() * softplus(gk + dt_bias)
#   - L2 normalization on Q, K
#   - q scaled by d_k^(-0.5) in forward(), just before the recurrence is called
#   - Delta rule: S = S*g.exp() + k*(v*beta - (S*k).sum(-2)*beta) [rank-1 update]
#   - FusedRMSNormSwishGate on output with gate projection
# ============================================================================

class GatedDeltaNet(nn.Module):
    """
    Gated DeltaNet token-mixing layer (pure PyTorch, recurrent fallback).

    Pipeline: linear projections -> causal short convolutions -> RoPE and L2
    normalization on Q/K -> write gate (beta) and log-space decay gate (gk)
    -> gated delta-rule recurrence -> per-head RMSNorm with swish gate ->
    output projection.

    Args:
        hidden_size: model width of the input and output.
        num_heads: number of heads; must divide both the key and value widths.
        head_dim: accepted for interface compatibility; not used (head sizes
            are derived from ``expand_k`` / ``expand_v``).
        expand_k: key/query width as a multiple of ``hidden_size``.
        expand_v: value/gate width as a multiple of ``hidden_size``.
        max_seq_len: number of positions for which RoPE tables are precomputed.
        conv_size: kernel width of the short convolutions.
        qk_norm: stored on the module; L2 normalization is applied to Q and K
            regardless of this value.
        use_mamba_gate: when True the decay gate is
            ``-exp(A_log) * softplus(gk + dt_bias)``; otherwise the raw
            ``gk_proj`` output (with bias) is used as the log-decay.

    Raises:
        ValueError: if the key or value width is not divisible by ``num_heads``.
    """
    def __init__(self, hidden_size, num_heads, head_dim=None,
                 expand_k=0.75, expand_v=1.5,
                 max_seq_len=262144, conv_size=4, qk_norm='l2',
                 use_mamba_gate=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.qk_norm = qk_norm

        # Ref: key_dim = int(hidden_size * expand_k), value_dim = int(hidden_size * expand_v)
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        # Fail early with a clear message instead of a shape error in forward().
        if self.key_dim % num_heads != 0:
            raise ValueError(
                f"key_dim ({self.key_dim}) must be divisible by num_heads ({num_heads})"
            )
        if self.value_dim % num_heads != 0:
            raise ValueError(
                f"value_dim ({self.value_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        # Projections: q/k -> key_dim; v and the output gate g -> value_dim
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        # Gate projections (ref: b_proj with bias=True, gk_proj with bias=False when mamba_gate)
        self.b_proj = nn.Linear(hidden_size, num_heads, bias=True)
        self.gk_proj = nn.Linear(hidden_size, num_heads, bias=not use_mamba_gate)

        # Short convolutions (ref: q,k use key_dim; v uses value_dim)
        self.q_conv1d = ShortConvolution(self.key_dim, conv_size=conv_size, activation='silu')
        self.k_conv1d = ShortConvolution(self.key_dim, conv_size=conv_size, activation='silu')
        self.v_conv1d = ShortConvolution(self.value_dim, conv_size=conv_size, activation='silu')

        # Mamba-style gate parameters (ref: A_log, D, dt_bias initialization)
        self.use_mamba_gate = use_mamba_gate
        if use_mamba_gate:
            A = torch.empty(num_heads, dtype=torch.float32).uniform_(0, 16)
            self.A_log = nn.Parameter(torch.log(A))
            # D is created here but is not read by forward().
            self.D = nn.Parameter(torch.ones(num_heads))
            # dt is sampled log-uniformly in [dt_min, dt_max]; dt_bias stores
            # the inverse softplus of dt, so softplus(dt_bias) == dt at init.
            dt_min, dt_max = 0.001, 0.1
            dt = torch.exp(
                torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
            )
            dt = torch.clamp(dt, min=1e-4)
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)

        # RoPE
        self.rotary_emb = RotaryEmbedding(self.head_qk_dim, max_seq_len)

        # Output norm (ref: FusedRMSNormSwishGate on head_v_dim)
        self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim)

    def forward(self, x, attention_mask=None):
        """Run the layer on ``x`` of shape [B, T, hidden_size].

        ``attention_mask`` is accepted for interface compatibility and is not
        used. Returns a tensor of shape [B, T, hidden_size] in the dtype of
        the output projection weights.
        """
        B, T, C = x.shape
        device = x.device

        # Step 1: Projections (ref: q_proj, k_proj, v_proj)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Step 2: Short convolutions (ref: q_conv1d, k_conv1d, v_conv1d with SiLU)
        q = self.q_conv1d(q)
        k = self.k_conv1d(k)
        v = self.v_conv1d(v)

        # Step 3: Reshape to heads
        q = q.view(B, T, self.num_heads, self.head_qk_dim)  # [B, T, H, d_k]
        k = k.view(B, T, self.num_heads, self.head_qk_dim)  # [B, T, H, d_k]
        v = v.view(B, T, self.num_heads, self.head_v_dim)   # [B, T, H, d_v]

        # Step 4: RoPE on Q, K
        cos, sin = self.rotary_emb(T, device, x.dtype)
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, T, 1, d_k]
        sin = sin.unsqueeze(0).unsqueeze(2)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Step 5: L2 normalize Q, K (ref: qk_norm == 'l2')
        q = l2_norm_fn(q)
        k = l2_norm_fn(k)

        # Step 6: Compute gates (in float32)
        # beta (write gate): sigmoid (ref: b_proj -> sigmoid -> transpose)
        beta = self.b_proj(x).float().sigmoid()  # [B, T, H]
        beta = beta.transpose(1, 2)  # [B, H, T]

        # gk (decay gate): -A.exp() * softplus(gk + dt_bias) (ref: Mamba-style)
        gk = self.gk_proj(x).float()  # [B, T, H]
        if self.use_mamba_gate:
            gk = -self.A_log.float().exp() * F.softplus(gk + self.dt_bias)
        gk = gk.transpose(1, 2)  # [B, H, T]

        # Step 7: Rearrange for recurrence [B, T, H, d] -> [B, H, T, d]
        q = q.transpose(1, 2)  # [B, H, T, d_k]
        k = k.transpose(1, 2)  # [B, H, T, d_k]
        v = v.transpose(1, 2)  # [B, H, T, d_v]

        # Step 8: Scale queries (ref: q = q * (d_k ** -0.5))
        q = q * (self.head_qk_dim ** -0.5)

        # Step 9: Delta rule recurrence (ref: recurrent_gated_delta_rule_ref)
        o = recurrent_gated_delta_rule_ref(q, k, v, beta, gk)

        # Step 10: Transpose back and apply output norm with gate
        # ref: o = rearrange(o, 'b h l d -> b l h d')
        o = o.transpose(1, 2)  # [B, T, H, d_v]
        g = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)

        # ref: g_norm_swish_gate(o, g) then reshape
        o_flat = o.reshape(B * T * self.num_heads, self.head_v_dim)
        g_flat = g.reshape(B * T * self.num_heads, self.head_v_dim)
        o_normed = self.g_norm_swish_gate(o_flat, g_flat)
        o = o_normed.view(B, T, self.value_dim)

        # The recurrence computes in float32, so ``o`` is float32 even when the
        # module runs in bf16/fp16. nn.Linear requires input and weight dtypes
        # to match, so cast back before the output projection.
        o = o.to(self.o_proj.weight.dtype)

        # Step 11: Output projection
        return self.o_proj(o)


In [None]:
# Reference: https://github.com/NVlabs/GatedDeltaNet

In [None]:
# ============================================================================
# Benchmarking Infrastructure
# ============================================================================

def benchmark_run(model_cls, name, configs, device, verbose=True,
                  num_heads=16, warmup_iters=5, timed_iters=10):
    """Time forward passes of ``model_cls`` on CUDA for each configuration.

    For every ``(B, T, D)`` in ``configs`` a fresh bf16 model is built, run
    ``warmup_iters`` times, then timed over ``timed_iters`` forward passes
    with CUDA events. Peak memory is read from the CUDA allocator and covers
    model construction, warmup and the timed passes.

    Args:
        model_cls: callable building the model. For ``name == "GatedDeltaNet"``
            it is called as ``model_cls(D, num_heads)``, otherwise as
            ``model_cls(D, num_heads, k_max=min(4096, T))``.
        name: label used in the printed header and for constructor dispatch.
        configs: iterable of ``(B, T, D)`` tuples.
        device: CUDA device to run on (timing and memory stats use torch.cuda).
        verbose: when False, the progress table is not printed.
        num_heads: head count passed to the model constructor.
        warmup_iters: untimed forward passes run before measuring.
        timed_iters: forward passes averaged for the reported time.

    Returns:
        A list with one dict per successful configuration, with keys
        ``B``, ``T``, ``D``, ``time_ms``, ``throughput`` (tokens/s) and
        ``memory_mb``. Configurations that raise ``RuntimeError`` (including
        out-of-memory) are reported in the table and omitted from the list.
    """
    results = []
    log = print if verbose else (lambda *args, **kwargs: None)

    log(f"Benchmarking {name}...")
    log(f"{'B':<4} {'T':<6} {'D':<6} | {'Time (ms)':<10} | {'Tokens/s':<10} | {'Mem (MB)':<10}")
    log("-" * 60)

    for (B, T, D) in configs:
        # Pre-bind so the cleanup in ``finally`` is valid even when model
        # construction itself fails.
        model = x = None
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Instantiate model with reference-aligned constructors
        try:
            if name == "GatedDeltaNet":
                # GDN uses expand_k/expand_v, not explicit head_dim
                model = model_cls(D, num_heads).to(device).to(torch.bfloat16)
            else:
                # GSA with reference defaults: d_indexer=64, k_base=2048, k_min=256, k_max=min(4096,T)
                model = model_cls(D, num_heads, k_max=min(4096, T)).to(device).to(torch.bfloat16)

            model.eval()
            x = torch.randn(B, T, D, device=device, dtype=torch.bfloat16)

            # Warmup (gradients already globally disabled via torch.set_grad_enabled(False))
            for _ in range(warmup_iters):
                _ = model(x)

            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            for _ in range(timed_iters):
                _ = model(x)
            end_event.record()
            # elapsed_time() is only valid once both events have completed.
            torch.cuda.synchronize()

            elapsed_ms = start_event.elapsed_time(end_event) / float(timed_iters)
            tokens_per_sec = (B * T) / (elapsed_ms / 1000.0)
            mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

            log(f"{B:<4} {T:<6} {D:<6} | {elapsed_ms:<10.2f} | {tokens_per_sec:<10.0f} | {mem_mb:<10.0f}")

            results.append({
                "B": B, "T": T, "D": D,
                "time_ms": elapsed_ms,
                "throughput": tokens_per_sec,
                "memory_mb": mem_mb
            })

        except RuntimeError as e:
            if "out of memory" in str(e):
                log(f"{B:<4} {T:<6} {D:<6} | {'OOM':<10} | {'-':<10} | {'-':<10}")
            else:
                log(f"{B:<4} {T:<6} {D:<6} | {'ERROR':<10} | {'-':<10} | {str(e)[:30]}")
        finally:
            # Release the model and input on every path. Without this, a
            # failed configuration keeps its tensors alive while the next
            # model is built, inflating that run's peak-memory figure.
            model = x = None
            torch.cuda.empty_cache()

    return results

def plot_results(results_delta, results_gsa):
    """Plot throughput and peak memory against sequence length for both models.

    Args:
        results_delta: result dicts for GatedDeltaNet, each with the keys
            ``"T"``, ``"throughput"`` and ``"memory_mb"``.
        results_gsa: result dicts for GSA with the same keys.

    The x-axis is the union of the sequence lengths present in either list.
    A length with no measurement for a model (e.g. that configuration ran out
    of memory) is plotted as NaN, which leaves a gap instead of a misleading
    zero. When several results share a sequence length, the first one is used.
    """
    def series(results, key, lengths):
        # First result per sequence length wins; missing lengths become NaN.
        first_by_length = {}
        for r in results:
            first_by_length.setdefault(r["T"], r[key])
        return [first_by_length.get(t, float("nan")) for t in lengths]

    ts = sorted({r["T"] for r in results_delta} | {r["T"] for r in results_gsa})

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    panels = (
        (axes[0], "throughput", "Tokens / Sec", "Throughput vs Sequence Length"),
        (axes[1], "memory_mb", "Peak Memory (MB)", "Memory vs Sequence Length"),
    )
    for ax, key, ylabel, title in panels:
        ax.plot(ts, series(results_delta, key, ts), marker='o', label='GatedDeltaNet (NVlabs)')
        ax.plot(ts, series(results_gsa, key, ts), marker='x', label='GSA (alfredcs)')
        ax.set_xlabel("Sequence Length")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()


In [None]:

# Run Benchmarks
# Each configuration is a (B, T, D) tuple: batch size, sequence length, model width.
configs = [
    (1, seq_len, 2048)
    for seq_len in (
        1024,
        2048,
        4096,
        8192,
        # Add 16384 for larger runs if GPU permits (PyTorch fallback is slow!)
    )
]

# Timing and memory statistics rely on torch.cuda, so the run is skipped
# entirely on machines without a CUDA device.
# NOTE(review): GatedSparseAttention is not defined in the visible part of this
# notebook — confirm it is defined in an earlier cell before running.
if not torch.cuda.is_available():
    print("Skipping benchmark execution (requires CUDA for timing/memory).")
else:
    results_delta = benchmark_run(GatedDeltaNet, "GatedDeltaNet", configs, device)
    print("\n")
    results_gsa = benchmark_run(GatedSparseAttention, "GatedSparseAttention", configs, device)

    plot_results(results_delta, results_gsa)
