# Benchmark: Gated DeltaNet vs Gated Sparse Attention
Compares forward-pass throughput (tokens/s) and peak GPU memory of the two layers' PyTorch fallback implementations.

In [None]:

# Imports & Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import matplotlib.pyplot as plt
import numpy as np
import logging
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Benchmark")

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Probe whether Triton is importable. The flag is informational only; the
# kernels in this notebook call their PyTorch fallbacks either way.
HAS_TRITON = False
try:
    import triton  # noqa: F401  (imported only to probe availability)
except ImportError:
    print("Triton not found. Using PyTorch fallbacks.")
else:
    HAS_TRITON = True

# Inference-only benchmark: switch autograd off for the whole session.
torch.set_grad_enabled(False)


In [None]:
# ============================================================================
# Kernel Fallbacks (Extracted from src/kernels/)
# ============================================================================

def pytorch_gated_indexer(
    q: torch.Tensor, k: torch.Tensor, w: torch.Tensor, b: torch.Tensor,
    scale: float = 1.0, causal: bool = True, q_offset: int = 0
) -> torch.Tensor:
    """Score every key for every query with the multi-head gated indexer (PyTorch fallback).

    For batch ``n``, query ``i`` and key ``j``::

        score[n, i, j] = sum_h sigmoid(w[n, i, h]) * sigmoid(scale * <q[n, i, h], k[n, j]> + b[h])

    All arithmetic is done in float32.

    Args:
        q: (batch, seq_q, n_heads, d_idx) indexer queries.
        k: (batch, seq_kv, d_idx) indexer keys, shared by all indexer heads.
        w: (batch, seq_q, n_heads) per-query head-weight logits.
        b: (n_heads,) gate bias added before the sigmoid.
        scale: multiplier applied to the raw dot products.
        causal: if True, key ``j`` is set to ``-inf`` for query ``i`` whenever
            ``j > q_offset + i``.
        q_offset: absolute position of the first query (used for chunked calls).

    Returns:
        (batch, seq_q, seq_kv) float32 scores.
    """
    seq_q, seq_kv = q.shape[1], k.shape[1]
    q32, k32 = q.float(), k.float()

    # Per-head gate: [batch, n_heads, seq_q, seq_kv].
    logits = torch.einsum('bqhd,bkd->bhqk', q32, k32) * scale
    gates = torch.sigmoid(logits + b.float().view(1, -1, 1, 1))

    # Per-query head importance: [batch, n_heads, seq_q, 1].
    head_weights = torch.sigmoid(w.float()).transpose(1, 2)[..., None]
    scores = (gates * head_weights).sum(dim=1)

    if not causal:
        return scores

    q_pos = torch.arange(seq_q, device=q32.device) + q_offset
    k_pos = torch.arange(seq_kv, device=q32.device)
    future = k_pos[None, :] > q_pos[:, None]
    return scores.masked_fill(future[None], float('-inf'))

def pytorch_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
    """RMS-normalise ``x`` over its last dimension and scale by ``weight``.

    ``y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``

    If ``residual`` is given it is added to ``x`` first, and the sum is normalised.
    The mean-square and rsqrt are computed in at least float32, so bf16/fp16
    inputs do not lose precision in the statistics. float64 inputs stay in
    float64. The normalised tensor is cast back to the input dtype before
    multiplying by ``weight``, so the output dtype is the same as before.
    """
    if residual is not None:
        x = x + residual
    # Upcast low-precision inputs for the reduction (float32 minimum).
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    x_c = x.to(compute_dtype)
    inv_rms = torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + eps)
    return (x_c * inv_rms).to(x.dtype) * weight

def _auto_chunk_size(batch_size: int, seq_kv: int, target_bytes: int = 512 * 1024 * 1024) -> int:
    bytes_per_row = batch_size * seq_kv * 4
    if bytes_per_row == 0: return 1
    max_C = target_bytes // bytes_per_row
    C = 1
    while C * 2 <= max_C: C *= 2
    return max(1, min(C, seq_kv))

def _compute_scores(q_chunk, k, w_chunk, b, scale, causal, q_offset, use_triton):
    """Score one chunk of queries against all keys with the gated indexer.

    ``use_triton`` is accepted for signature compatibility but ignored. This
    benchmark always dispatches to ``pytorch_gated_indexer`` so it runs
    without Triton.
    """
    del use_triton  # intentionally unused; see docstring
    return pytorch_gated_indexer(
        q_chunk, k, w_chunk, b,
        scale=scale, causal=causal, q_offset=q_offset,
    )

def fused_indexer_topk(
    q: torch.Tensor, k: torch.Tensor, w: torch.Tensor, b: torch.Tensor,
    scale: float, causal: bool = True, k_base: int = 512, k_min: int = 32, k_max: int = 4096,
    variance_ema: Optional[torch.Tensor] = None, is_training: bool = False, sink_size: int = 4
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score keys with the gated indexer and select top-k candidates per query (PyTorch, chunked).

    Queries are processed in chunks whose size comes from ``_auto_chunk_size``.
    For each chunk the scores are produced by ``_compute_scores`` (the
    ``q_start`` offset is passed through for causal masking).

    Args:
        q: (batch, seq_q, n_heads, d_idx) indexer queries.
        k: indexer keys; ``k.shape[1]`` is the key length ``seq_kv``.
        w: per-query head-weight logits, sliced along dim 1 together with ``q``.
        b: indexer gate bias (passed through to the scorer).
        scale: multiplier for raw indexer dot products.
        causal: mask keys that come after each query.
        k_base, k_min, k_max: adaptive budget
            ``k_t = floor(k_base * var_t / avg_var)`` clamped to ``[k_min, k_max]``,
            then to the number of keys the query may see.
        variance_ema: reference variance ``avg_var``. If None, the mean of
            ``var_t`` over this call is used.
        is_training: accepted for API compatibility; unused here.
        sink_size: the first ``sink_size`` keys are ranked first for every query
            that can see them (only applied when ``seq_kv > sink_size``).

    Returns:
        var_t: (batch, seq_q) float32 variance of each query's non-masked scores.
        k_t: (batch, seq_q) int64 per-query budget, never larger than the
            number of non-masked keys for that query.
        top_indices: (batch, seq_q, k_limit) int32 key indices sorted by score,
            descending, with ``k_limit = min(seq_kv, max(k_max, sink_size))``.
            Positions ``>= k_t`` may point at masked keys and must be ignored.
    """
    batch_size, seq_q, n_heads, d_idx = q.shape
    seq_kv = k.shape[1]
    device = q.device

    # NOTE(review): the budget ignores n_heads, although the PyTorch indexer
    # builds [batch, n_heads, chunk, seq_kv] float32 intermediates, so actual
    # peak memory can be a multiple of target_bytes.
    C = _auto_chunk_size(batch_size, seq_kv)

    k_limit = min(seq_kv, max(k_max, sink_size))

    var_t = torch.empty(batch_size, seq_q, device=device, dtype=torch.float32)
    valid_counts = torch.empty(batch_size, seq_q, device=device, dtype=torch.long)
    top_indices = torch.empty(batch_size, seq_q, k_limit, device=device, dtype=torch.int32)

    for q_start in range(0, seq_q, C):
        q_end = min(q_start + C, seq_q)
        q_chunk = q[:, q_start:q_end]
        w_chunk = w[:, q_start:q_end]

        scores = _compute_scores(q_chunk, k, w_chunk, b, scale, causal, q_start, False)

        # Variance over valid (non-masked) entries only.
        valid_mask = scores != float('-inf')                 # [batch, chunk, seq_kv]
        count = valid_mask.sum(dim=-1)                       # [batch, chunk]
        denom = count.clamp(min=1)
        scores_zeroed = scores.masked_fill(~valid_mask, 0.0)
        mean_valid = scores_zeroed.sum(dim=-1, keepdim=True) / denom.unsqueeze(-1)
        diff = (scores_zeroed - mean_valid).masked_fill(~valid_mask, 0.0)
        # Shape [batch, chunk] -- no squeeze, so chunks of length 1 also work.
        var_t[:, q_start:q_end] = diff.pow(2).sum(dim=-1) / denom
        valid_counts[:, q_start:q_end] = count

        # Attention sinks: promote the first keys to +inf, but only where the
        # query may already see them (never un-mask a future key).
        if seq_kv > sink_size:
            scores = scores.clone()
            sink_view = scores[:, :, :sink_size]
            sink_view.masked_fill_(valid_mask[:, :, :sink_size], float('inf'))

        _, chunk_idx = scores.topk(k_limit, dim=-1)
        top_indices[:, q_start:q_end, :] = chunk_idx.to(torch.int32)
        del scores, chunk_idx, valid_mask, scores_zeroed, diff

    # Adaptive k_t computation
    if variance_ema is not None:
        avg_V = variance_ema.clamp(min=1e-6)
    else:
        avg_V = var_t.mean().clamp(min=1e-6)
    k_t = (k_base * var_t / avg_V).floor().clamp(min=k_min, max=k_max).long()
    # top_indices is sorted descending, so the first `valid_counts` entries are
    # exactly the visible keys and any tail entries are -inf (masked/future).
    # Capping k_t keeps callers that take the first k_t entries causal.
    k_t = torch.minimum(k_t, valid_counts)

    return var_t, k_t, top_indices

def pytorch_sparse_attention(q, k, v, sparse_idx, sparse_mask, scale):
    """PyTorch fallback for attention over a per-query subset of keys.

    Each (batch, head, query) attends only to the key positions listed in
    ``sparse_idx`` whose ``sparse_mask`` entry is non-zero.

    Args:
        q: [B, T, H, D] query tensor.
        k: [B, S, H, D] key tensor (S is usually T).
        v: [B, S, H, D] value tensor.
        sparse_idx: [B, H, T, K] key positions per query, clamped into [0, S - 1].
        sparse_mask: [B, H, T, K] mask (non-zero = attend, 0 = ignore).
        scale: attention scale factor.

    Returns:
        output: [B, T, H, D] in the dtype of ``v``. Rows whose mask is all
        zero produce zeros.
    """
    B, T, H, D = q.shape
    S = k.shape[1]

    # Head-first layout: [B, H, T, D] / [B, H, S, D].
    q_h = q.transpose(1, 2)
    k_h = k.transpose(1, 2)
    v_h = v.transpose(1, 2)

    idx = sparse_idx.clamp(0, S - 1).long()  # [B, H, T, K]

    # Direct advanced indexing: k_sel[b, h, t, j] = k_h[b, h, idx[b, h, t, j]].
    # (Gathering from a K-times expanded copy maps index i to row i // K, and
    # also materialises [B, H, T*K, D] copies and int64 indices.)
    batch_ix = torch.arange(B, device=q.device).view(B, 1, 1, 1)
    head_ix = torch.arange(H, device=q.device).view(1, H, 1, 1)
    k_sel = k_h[batch_ix, head_ix, idx]  # [B, H, T, K, D]
    v_sel = v_h[batch_ix, head_ix, idx]  # [B, H, T, K, D]

    # Scores and softmax in float32 for stability with bf16/fp16 inputs.
    scores = torch.einsum('bhtd,bhtkd->bhtk', q_h, k_sel).float() * scale
    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
    attn_weights = torch.softmax(scores, dim=-1)
    # Fully masked rows give NaN after softmax; make them contribute nothing.
    attn_weights = attn_weights.nan_to_num(0.0)

    out = torch.einsum('bhtk,bhtkd->bhtd', attn_weights.to(v_sel.dtype), v_sel)
    return out.transpose(1, 2).contiguous()


In [None]:
# ============================================================================
# Core Components (Rotary, Norm, Kronecker)
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature scale.

    The computation itself is delegated to ``pytorch_rmsnorm``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Scale starts at ones, so a fresh module is a pure RMS normaliser.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return pytorch_rmsnorm(x, weight=self.weight, eps=self.eps)

class FusedRMSNormSwishGate(nn.Module):
    """RMSNorm of ``x`` followed by a SiLU ("swish") output gate: ``RMSNorm(x) * silu(g)``.

    Not actually fused here; it is a plain PyTorch composition of ``RMSNorm``
    and ``F.silu``.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.norm = RMSNorm(dim, eps)

    def forward(self, x, g):
        # The swish activation belongs on the gate g, not on the normalised
        # x (previously computed as g * silu(norm(x))).
        return self.norm(x) * F.silu(g)

class RotaryEmbedding(nn.Module):
    """Rotary position embedding in the rotate-half layout.

    Dimension ``i`` is paired with dimension ``i + dim // 2``, and each pair is
    rotated by the angle ``position * inv_freq[i]``.

    When ``max_position_embeddings`` exceeds ``original_max_position_embeddings``,
    the frequency base is enlarged by ``(max / original_max) ** (dim / (dim - 2))``
    (NTK-aware scaling). ``scaling_factor`` and the ``mscale`` buffer are stored
    but not used by this simplified benchmark version.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 8192, base: int = 10000,
                 original_max_position_embeddings: int = 8192, scaling_factor: float = 32.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.scaling_factor = scaling_factor

        scaled_base = base
        if max_position_embeddings > original_max_position_embeddings:
            ext_ratio = max_position_embeddings / original_max_position_embeddings
            scaled_base = base * (ext_ratio ** (dim / (dim - 2)))
        # Kept as a Python float so frequencies can always be rebuilt in float32,
        # even after the module (and its buffers) are cast to bf16/fp16.
        self._scaled_base = float(scaled_base)

        # Registered for state_dict compatibility. _compute_cos_sin rebuilds the
        # frequencies from _scaled_base because this buffer is downcast by
        # module.to(torch.bfloat16), which would cause large angle errors at long
        # positions.
        self.register_buffer("inv_freq", self._float32_inv_freq("cpu"))
        self.register_buffer("mscale", torch.ones(dim // 2))  # Dummy mscale

    def _float32_inv_freq(self, device):
        """Return the inverse frequencies (one per even index below ``dim``) in float32 on ``device``."""
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self._scaled_base ** exponents)

    def _compute_cos_sin(self, seq_len: int, device, dtype=None):
        """Return ``(cos, sin)`` tables of shape ``(seq_len, 2 * len(inv_freq))``.

        Angles are computed in float32. Both tables are optionally cast to
        ``dtype``. The two halves of the last dimension are identical, matching
        the rotate-half pairing used by ``_apply_rotary``.
        """
        positions = torch.arange(seq_len, device=device).float()
        freqs = positions.unsqueeze(-1) * self._float32_inv_freq(device).unsqueeze(0)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos_out = emb.cos()
        sin_out = emb.sin()
        if dtype is not None:
            cos_out = cos_out.to(dtype)
            sin_out = sin_out.to(dtype)
        return cos_out, sin_out

    @staticmethod
    def _apply_rotary(x, cos, sin):
        """Rotate each (x[..., i], x[..., i + d/2]) pair and keep the same layout.

        The output is ``cat(x1 * cos - x2 * sin, x1 * sin + x2 * cos)``, so a
        zero angle (cos=1, sin=0) returns ``x`` unchanged. An interleaved output
        would silently permute dimensions and break that property.
        """
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        cos_half, sin_half = cos[..., :half], sin[..., :half]
        return torch.cat(
            (x1 * cos_half - x2 * sin_half, x1 * sin_half + x2 * cos_half),
            dim=-1,
        )

class ShortConvolution(nn.Module):
    """Causal depthwise 1-D convolution over the sequence axis, followed by an activation.

    Input and output are laid out as (batch, seq_len, dim). ``forward``
    transposes to (batch, dim, seq_len) for ``nn.Conv1d`` and back.
    """

    def __init__(self, dim, conv_size=4, activation='silu'):
        super().__init__()
        self.conv_size = conv_size
        # Padding conv_size - 1 on both sides and keeping only the first seq_len
        # outputs means output t sees inputs t - conv_size + 1 .. t (causal).
        self.conv = nn.Conv1d(dim, dim, kernel_size=conv_size, padding=conv_size - 1, groups=dim)
        self.activation = nn.SiLU() if activation == 'silu' else nn.Identity()

    def forward(self, x):
        seq_len = x.shape[1]
        y = self.conv(x.transpose(1, 2))
        # Slice by length rather than [:-(conv_size - 1)]. The latter becomes
        # [:-0] == [:0] (empty) when conv_size == 1.
        y = y[:, :, :seq_len]
        return self.activation(y.transpose(1, 2))


In [None]:

# ============================================================================
# Gated DeltaNet
# ============================================================================

class GatedDeltaNet(nn.Module):
    """Gated DeltaNet token mixer with a pure-PyTorch recurrent fallback.

    ``forward`` pipeline:
    q/k/v/g projections -> causal short convolutions on q/k/v -> per-head RoPE
    on q/k -> L2-normalised q/k -> gated delta-rule recurrence -> optional
    RMSNorm + SiLU output gate with g -> output projection.
    """

    def __init__(self, hidden_size, num_heads, head_dim, max_seq_len=262144,
                 rope_base=10000, rope_original_max=8192, rope_scaling_factor=32.0,
                 conv_size=4, use_output_norm=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.use_output_norm = use_output_norm

        key_dim = num_heads * head_dim
        value_dim = num_heads * head_dim

        self.q_proj = nn.Linear(hidden_size, key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, value_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, value_dim, bias=False)
        self.o_proj = nn.Linear(value_dim, hidden_size, bias=False)

        # Per-head write strength (beta) and decay-rate input (gk) projections.
        self.b_proj = nn.Linear(hidden_size, num_heads, bias=True)
        self.gk_proj = nn.Linear(hidden_size, num_heads, bias=True)

        self.q_conv1d = ShortConvolution(key_dim, conv_size=conv_size, activation='silu')
        self.k_conv1d = ShortConvolution(key_dim, conv_size=conv_size, activation='silu')
        self.v_conv1d = ShortConvolution(value_dim, conv_size=conv_size, activation='silu')

        # Decay: alpha = exp(-exp(A_log) * softplus(gk + dt_bias)), see forward.
        A_init = torch.empty(num_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A_init))
        # Per-head weight of the direct current-token term (q . k) * v.
        self.D = nn.Parameter(torch.ones(num_heads))
        self.dt_bias = nn.Parameter(torch.rand(num_heads) * 0.02 - 0.01)

        # Forward the RoPE hyper-parameters accepted by this constructor.
        self.rotary_emb = RotaryEmbedding(
            head_dim, max_seq_len, base=rope_base,
            original_max_position_embeddings=rope_original_max,
            scaling_factor=rope_scaling_factor,
        )

        if use_output_norm:
            self.o_norm = FusedRMSNormSwishGate(head_dim)

    def _delta_rule_python(self, q, k, v, alpha, beta, B, T, device, original_dtype):
        """Run the gated delta-rule recurrence token by token in float32.

        Args:
            q, k, v: (B, T, num_heads, head_dim).
            alpha: (B, T, num_heads, 1) per-token decay.
            beta: (B, T, num_heads, 1) per-token write strength.
            B, T: batch size and sequence length.
            device: device for the state and output buffers.
            original_dtype: dtype of the returned tensor.

        Returns:
            (B, T, num_heads, head_dim) in ``original_dtype`` (a transposed,
            non-contiguous view).

        The state ``S[b, h]`` is a (value, key) matrix, mapping a key-space
        vector to a value-space vector::

            o_t = S_{t-1} q_t + D * (q_t . k_t) * v_t
            S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T
        """
        H = self.num_heads
        q_h = q.transpose(1, 2)
        k_h = k.transpose(1, 2)
        v_h = v.transpose(1, 2)
        beta_h = beta.transpose(1, 2)
        alpha_h = alpha.transpose(1, 2)

        S = torch.zeros(B, H, self.head_dim, self.head_dim, device=device, dtype=torch.float32)
        outputs = torch.empty(B, H, T, self.head_dim, device=device, dtype=torch.float32)
        skip_weight = self.D.view(1, H, 1)

        for t in range(T):
            q_t = q_h[:, :, t, :].float()
            k_t = k_h[:, :, t, :].float()
            v_t = v_h[:, :, t, :].float()
            beta_t = beta_h[:, :, t, 0].float().unsqueeze(-1)    # [B, H, 1]
            alpha_t = alpha_h[:, :, t, 0].float().unsqueeze(-1)  # [B, H, 1]

            # Read-out S q (contracting over the key axis) with the state
            # *before* this token's write. The D term adds the current token.
            o_t = torch.einsum('bhed,bhd->bhe', S, q_t)
            o_t = o_t + skip_weight * (q_t * k_t).sum(dim=-1, keepdim=True) * v_t
            outputs[:, :, t, :] = o_t

            # Rank-one form of the gated delta update (O(d^2) per step instead
            # of the O(d^3) product S @ (I - beta k k^T)):
            #   alpha * S (I - beta k k^T) + beta v k^T
            #     == alpha * S + (beta * (v - alpha * S k)) k^T
            S_k = torch.einsum('bhed,bhd->bhe', S, k_t)
            delta = beta_t * (v_t - alpha_t * S_k)
            S = alpha_t.unsqueeze(-1) * S + torch.einsum('bhe,bhd->bhed', delta, k_t)

        return outputs.to(original_dtype).transpose(1, 2)

    def forward(self, x, attention_mask=None):
        """Apply the layer to ``x`` of shape (B, T, hidden_size); ``attention_mask`` is ignored."""
        B, T, C = x.shape
        device = x.device

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        q = self.q_conv1d(q)
        k = self.k_conv1d(k)
        v = self.v_conv1d(v)

        q = q.view(B, T, self.num_heads, self.head_dim)
        k = k.view(B, T, self.num_heads, self.head_dim)
        v = v.view(B, T, self.num_heads, self.head_dim)
        g = g.view(B, T, self.num_heads, self.head_dim)

        cos, sin = self.rotary_emb._compute_cos_sin(T, device, x.dtype)
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)
        q = self.rotary_emb._apply_rotary(q, cos, sin)
        k = self.rotary_emb._apply_rotary(k, cos, sin)

        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        beta = torch.sigmoid(self.b_proj(x)).unsqueeze(-1)
        gk = self.gk_proj(x)
        A = torch.exp(self.A_log)
        alpha = torch.exp(-A.view(1, 1, self.num_heads, 1) * F.softplus(gk + self.dt_bias).unsqueeze(-1))

        # Use Python fallback if no Triton/fla
        o = self._delta_rule_python(q, k, v, alpha, beta, B, T, device, x.dtype)

        if self.use_output_norm:
            o_flat = o.reshape(B * T * self.num_heads, self.head_dim)
            g_flat = g.reshape(B * T * self.num_heads, self.head_dim)
            o_normed = self.o_norm(o_flat, g_flat)
            o = o_normed.view(B, T, self.num_heads, self.head_dim)

        o = o.reshape(B, T, self.num_heads * self.head_dim)
        return self.o_proj(o)


In [None]:
# ============================================================================
# Gated Sparse Attention (GSA)
# ============================================================================

class GatedSparseAttention(nn.Module):
    """Gated sparse attention layer.

    A lightweight indexer selects up to ``k_t`` candidate keys per query (via
    ``fused_indexer_topk``). Multi-head attention then runs only over those
    keys, with sigmoid gates applied to the values and to the attention output.
    """

    def __init__(self, hidden_size, num_heads, max_seq_len=262144,
                 k_base=512, k_min=32, k_max=1024, indexer_heads=4):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Adaptive top-k budget bounds passed to fused_indexer_topk.
        self.k_base = k_base
        self.k_min = k_min
        self.k_max = k_max
        self.indexer_heads = indexer_heads

        # Indexer: `indexer_heads` query heads of width d_idx, one shared key
        # projection, per-head weights, and a per-head gate bias.
        self.d_idx = 32
        self.W_Iq = nn.Linear(hidden_size, indexer_heads * self.d_idx, bias=False)
        self.W_Ik = nn.Linear(hidden_size, self.d_idx, bias=False)
        self.W_Iw = nn.Linear(hidden_size, indexer_heads, bias=False)
        self.gate_bias = nn.Parameter(torch.zeros(indexer_heads))

        # Reference variance for the adaptive budget; updated only in training mode.
        self.register_buffer("variance_ema", torch.tensor(1.0))
        self.ema_momentum = 0.1

        self.W_q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # Value and output gates.
        self.W_gv = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_go = nn.Linear(hidden_size, hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len)

    def forward(self, x, attention_mask=None):
        """Apply the layer to ``x`` of shape (B, T, hidden_size); ``attention_mask`` is ignored."""
        B, T, C = x.shape
        device = x.device

        # Indexer
        q_I = self.W_Iq(x).view(B, T, self.indexer_heads, self.d_idx)
        k_I = self.W_Ik(x)
        w_raw = self.W_Iw(x)
        scale_idx = 1.0 / math.sqrt(self.d_idx)

        var_t, k_t, top_indices = fused_indexer_topk(
            q=q_I, k=k_I, w=w_raw, b=self.gate_bias, scale=scale_idx,
            causal=True, k_base=self.k_base, k_min=self.k_min, k_max=self.k_max,
            variance_ema=self.variance_ema, is_training=self.training
        )

        # Update EMA during training
        if self.training:
            with torch.no_grad():
                batch_var_mean = var_t.mean()
                self.variance_ema.mul_(1.0 - self.ema_momentum).add_(self.ema_momentum * batch_var_mean)

        k_limit = top_indices.size(-1)
        base_idx = top_indices.long()
        range_k = torch.arange(k_limit, device=device)
        # Keep the first k_t ranked candidates of each query ...
        keep_mask = range_k.view(1, 1, -1) < k_t.unsqueeze(-1)
        # ... but never a key after the query. Top-k over a causally masked row
        # can return -inf (future) positions whenever k_t exceeds the number of
        # visible keys, so guard causality on the indices themselves.
        query_pos = torch.arange(T, device=device).view(1, T, 1)
        keep_mask = keep_mask & (base_idx <= query_pos)

        # Attention
        q = self.W_q(x)
        k_attn = self.W_k(x)
        v = self.W_v(x)

        g_v = torch.sigmoid(self.W_gv(x))
        v = v * g_v

        q = q.view(B, T, self.num_heads, self.head_dim)
        k_attn = k_attn.view(B, T, self.num_heads, self.head_dim)
        v = v.view(B, T, self.num_heads, self.head_dim)

        cos, sin = self.rotary_emb._compute_cos_sin(T, device, x.dtype)
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)
        q = self.rotary_emb._apply_rotary(q, cos, sin)
        k_attn = self.rotary_emb._apply_rotary(k_attn, cos, sin)

        # Sparse attention: every head shares the indexer's key selection.
        sparse_idx = base_idx.unsqueeze(1).expand(B, self.num_heads, T, k_limit)
        sparse_mask = keep_mask.float().unsqueeze(1).expand(B, self.num_heads, T, k_limit)
        scale_attn = 1.0 / math.sqrt(self.head_dim)

        o_sparse = pytorch_sparse_attention(q, k_attn, v, sparse_idx, sparse_mask, scale_attn)
        o_sparse = o_sparse.contiguous().view(B, T, self.hidden_size)

        g_o = torch.sigmoid(self.W_go(x))
        return self.o_proj(o_sparse * g_o)


In [None]:
# ============================================================================
# Benchmarking Infrastructure
# ============================================================================

def benchmark_run(model_cls, name, configs, device, verbose=True):
    """Time inference forward passes of ``model_cls`` and record peak CUDA memory.

    For every ``(B, T, D)`` in ``configs``:
    - the model is built with 16 heads (``GatedDeltaNet`` additionally gets
      ``head_dim = D // 16``), moved to ``device``, cast to bfloat16 and put in
      eval mode;
    - it is warmed up with 5 forward passes;
    - 10 forward passes are then timed with CUDA events.

    Requires CUDA: the ``torch.cuda`` timing and memory APIs are called
    unconditionally.

    Args:
        model_cls: layer class to instantiate.
        name: display name. ``"GatedDeltaNet"`` selects the 3-argument constructor.
        configs: iterable of ``(B, T, D)`` tuples.
        device: device the model and inputs are placed on.
        verbose: print the progress table when True (default).

    Returns:
        One dict per successful config with keys ``B``, ``T``, ``D``,
        ``time_ms`` (mean per forward), ``throughput`` (tokens/s) and
        ``memory_mb`` (peak allocated). Configs that raise ``RuntimeError``
        (including OOM) are reported and skipped.
    """
    log = print if verbose else (lambda *args, **kwargs: None)
    results = []

    log(f"Benchmarking {name}...")
    log(f"{'B':<4} {'T':<6} {'D':<6} | {'Time (ms)':<10} | {'Tokens/s':<10} | {'Mem (MB)':<10}")
    log("-" * 60)

    for (B, T, D) in configs:
        model = x = None
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        num_heads = 16
        head_dim = D // num_heads

        try:
            if name == "GatedDeltaNet":
                model = model_cls(D, num_heads, head_dim).to(device).to(torch.bfloat16)
            else:
                model = model_cls(D, num_heads).to(device).to(torch.bfloat16)

            model.eval()
            x = torch.randn(B, T, D, device=device, dtype=torch.bfloat16)

            # Warmup (gradients already globally disabled via torch.set_grad_enabled(False))
            for _ in range(5):
                _ = model(x)

            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            for _ in range(10):
                _ = model(x)
            end_event.record()
            torch.cuda.synchronize()

            elapsed_ms = start_event.elapsed_time(end_event) / 10.0
            tokens_per_sec = (B * T) / (elapsed_ms / 1000.0)
            mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

            log(f"{B:<4} {T:<6} {D:<6} | {elapsed_ms:<10.2f} | {tokens_per_sec:<10.0f} | {mem_mb:<10.0f}")

            results.append({
                "B": B, "T": T, "D": D,
                "time_ms": elapsed_ms,
                "throughput": tokens_per_sec,
                "memory_mb": mem_mb
            })

        except RuntimeError as e:
            if "out of memory" in str(e):
                log(f"{B:<4} {T:<6} {D:<6} | {'OOM':<10} | {'-':<10} | {'-':<10}")
            else:
                log(f"{B:<4} {T:<6} {D:<6} | {'ERROR':<10} | {'-':<10} | {str(e)[:10]}")
        finally:
            # Release this config's model and input on every path (including
            # OOM), so they are not still allocated when the next config's peak
            # memory is measured.
            del model, x

    return results

def plot_results(results_delta, results_gsa):
    """Plot throughput and peak memory of both models against sequence length.

    Each argument is a list of result dicts with keys ``"T"``, ``"throughput"``
    and ``"memory_mb"`` (the rows produced by ``benchmark_run``).

    The x-axis is the sorted union of sequence lengths present in either list.
    A length with no result for a model (e.g. it ran out of memory) is plotted
    as NaN, which leaves a gap, rather than 0, which would read as "zero
    throughput / zero memory". If several rows share a sequence length (e.g.
    different B or D), the first one is used.
    """
    seq_lens = sorted({r["T"] for r in results_delta} | {r["T"] for r in results_gsa})

    def metric(results, key):
        # First row per sequence length wins (list order); NaN when absent.
        first = {}
        for row in results:
            first.setdefault(row["T"], row[key])
        return [first.get(t, float("nan")) for t in seq_lens]

    panels = (
        ("throughput", "Tokens / Sec", "Throughput vs Sequence Length"),
        ("memory_mb", "Peak Memory (MB)", "Memory vs Sequence Length"),
    )
    for key, ylabel, title in panels:
        plt.figure(figsize=(10, 5))
        plt.plot(seq_lens, metric(results_delta, key), marker='o', label='DeltaNet')
        plt.plot(seq_lens, metric(results_gsa, key), marker='x', label='GSA')
        plt.xlabel("Sequence Length")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.show()


In [None]:

# Run Benchmarks
# Sweep sequence length at fixed batch (B=1) and width (D=2048); entries are (B, T, D).
# Add 16384 to the sweep for larger runs if the GPU permits (PyTorch fallback is slow!).
configs = [(1, seq_len, 2048) for seq_len in (1024, 2048, 4096, 8192)]

if not torch.cuda.is_available():
    print("Skipping benchmark execution (requires CUDA for timing/memory).")
else:
    results_delta = benchmark_run(GatedDeltaNet, "GatedDeltaNet", configs, device)
    print("\n")
    results_gsa = benchmark_run(GatedSparseAttention, "GatedSparseAttention", configs, device)

    plot_results(results_delta, results_gsa)
