In [1]:
import torch

def lin_kv_block_memory_efficient(
    X, W_Q, W_QK, W_KK, W_VK, W_QV, W_KV, W_VV, eps=1e-6
):
    """
    Memory-efficient (non-parallelized) Lin-KV Block.
    - Single Python loop over t
    - O(d^2) extra state via prefix accumulators
    - No O(n^2) stacks kept

    Args:
      X:   (n, d)
      W_*: (d, d)

    Returns:
      Y: (n, d)
    """
    n, d = X.shape
    device, dtype = X.device, X.dtype

    # Projections
    Q  = X @ W_Q
    QK = X @ W_QK
    KK = X @ W_KK
    VK = X @ W_VK
    QV = X @ W_QV
    KV = X @ W_KV
    VV = X @ W_VV

    # Prefix accumulators
    S_K = torch.zeros(d, d, device=device, dtype=dtype)
    Z_K = torch.zeros(d,     device=device, dtype=dtype)
    S_V = torch.zeros(d, d, device=device, dtype=dtype)
    Z_V = torch.zeros(d,     device=device, dtype=dtype)

    Y = torch.empty(n, d, device=device, dtype=dtype)

    for t in range(n):
        # Update prefix sums with step t
        S_K = S_K + torch.outer(KK[t], VK[t])  # (d,d)
        Z_K = Z_K + KK[t]                       # (d,)
        S_V = S_V + torch.outer(KV[t], VV[t])  # (d,d)
        Z_V = Z_V + KV[t]                       # (d,)

        # Build effective K_(t) and V_(t) for prefix 1..t (no storage beyond this step)
        QK_prefix = QK[:t+1]                    # (t+1, d)
        QV_prefix = QV[:t+1]                    # (t+1, d)

        K_num = QK_prefix @ S_K                 # (t+1, d)
        V_num = QV_prefix @ S_V                 # (t+1, d)

        K_den = (QK_prefix @ Z_K).unsqueeze(-1) + eps  # (t+1,1)
        V_den = (QV_prefix @ Z_V).unsqueeze(-1) + eps  # (t+1,1)

        K_t = K_num / K_den                     # (t+1, d)
        V_t = V_num / V_den                     # (t+1, d)

        # Attention for position t over the prefix
        logits = (Q[t] @ K_t.T).unsqueeze(0)    # (1, t+1)
        attn = torch.softmax(logits, dim=-1)    # (1, t+1)

        # Output y_t
        Y[t:t+1] = attn @ V_t                   # (1, d)

    return Y


# ---------- Vectorized (parallelized) reference for sanity checks ----------
def lin_kv_block_vectorized(X, W_Q, W_QK, W_KK, W_VK, W_QV, W_KV, W_VV, eps=1e-6):
    """
    Parallelized / vectorized O(n^2) time & O(n^2) memory Lin-KV Block.
    Returns Y only (no quadratic stacks exposed).
    """
    n, d = X.shape
    device, dtype = X.device, X.dtype

    # Projections
    Q  = X @ W_Q
    QK = X @ W_QK
    KK = X @ W_KK
    VK = X @ W_VK
    QV = X @ W_QV
    KV = X @ W_KV
    VV = X @ W_VV

    # Prefix accumulators for all t
    outer_K = torch.einsum("nd,ne->nde", KK, VK)   # (n,d,d)
    S_pref_K = outer_K.cumsum(dim=0)               # (n,d,d)
    Z_pref_K = KK.cumsum(dim=0)                    # (n,d)

    outer_V = torch.einsum("nd,ne->nde", KV, VV)   # (n,d,d)
    S_pref_V = outer_V.cumsum(dim=0)               # (n,d,d)
    Z_pref_V = KV.cumsum(dim=0)                    # (n,d)

    # Build all effective K_(t) and V_(t)
    K_num_all = torch.einsum("id,tdk->tik", QK, S_pref_K)  # (n,n,d)
    K_den_all = torch.einsum("id,td->ti",  QK, Z_pref_K)   # (n,n)
    K_all     = K_num_all / (K_den_all[..., None] + eps)   # (n,n,d)

    V_num_all = torch.einsum("id,tdk->tik", QV, S_pref_V)  # (n,n,d)
    V_den_all = torch.einsum("id,td->ti",  QV, Z_pref_V)   # (n,n)
    V_all     = V_num_all / (V_den_all[..., None] + eps)   # (n,n,d)

    # Mask to keep i<=t
    tri = torch.tril(torch.ones(n, n, device=device, dtype=torch.bool))
    logits = torch.einsum("td,tid->ti", Q, K_all)          # (n,n)
    logits = logits.masked_fill(~tri, float("-inf"))

    A = torch.softmax(logits, dim=1)                       # (n,n)
    Y = torch.einsum("ti,tid->td", A, V_all)               # (n,d)
    return Y





In [3]:
# ---------- Quick sanity check ----------
# if __name__ == "__main__":
# torch.manual_seed(0)
n, d = 8, 4
X  = torch.randn(n, d)
Ws = [torch.randn(d, d) for _ in range(7)]

Y_mem = lin_kv_block_memory_efficient(X, *Ws)
Y_vec = lin_kv_block_vectorized(X, *Ws)

max_diff = (Y_mem - Y_vec).abs().max().item()
print("max |Y_mem - Y_vec|:", max_diff)

# Optional strict check
assert torch.allclose(Y_mem, Y_vec, atol=1e-5, rtol=1e-5), "Mismatch!"
print("✔ Outputs match within tolerance.")

max |Y_mem - Y_vec|: 4.291534423828125e-06
✔ Outputs match within tolerance.


# multihead version

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinKVBlockMultiHead(nn.Module):
    """
    Multi-head, naive parallelized Lin-KV block.
    Complexity: O(B * H * T^2 * d_h) time and memory (quadratic in sequence length).

    Args:
        d_model: embedding dimension (must equal num_heads * head_dim)
        num_heads: number of heads (H)
        head_dim: per-head dimension (d_h)
        eps: small constant for numerical stability in denominators
    """
    def __init__(self, d_model: int, num_heads: int, head_dim: int, eps: float = 1e-6):
        super().__init__()
        assert d_model == num_heads * head_dim, \
            "d_model must equal num_heads * head_dim for simple concat output."
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.eps = eps

        D = d_model
        H = num_heads * head_dim  # flattened heads

        # 7 projection matrices as learnable parameters (no bias, to match pseudocode)
        self.W_Q  = nn.Parameter(torch.empty(D, H))
        self.W_QK = nn.Parameter(torch.empty(D, H))
        self.W_KK = nn.Parameter(torch.empty(D, H))
        self.W_VK = nn.Parameter(torch.empty(D, H))
        self.W_QV = nn.Parameter(torch.empty(D, H))
        self.W_KV = nn.Parameter(torch.empty(D, H))
        self.W_VV = nn.Parameter(torch.empty(D, H))

        self.reset_parameters()

    def reset_parameters(self):
        for p in [self.W_Q, self.W_QK, self.W_KK, self.W_VK,
                  self.W_QV, self.W_KV, self.W_VV]:
            nn.init.xavier_uniform_(p)

    def _proj(self, X, W):
        # X: (B, T, D) @ (D, H) -> (B, T, H) -> (B, T, num_heads, head_dim)
        B, T, _ = X.shape
        out = X @ W                           # (B, T, H_flat)
        out = out.view(B, T, self.num_heads, self.head_dim)
        return out

    def forward(self, X):
        """
        X: (B, T, d_model) or (T, d_model)
        Returns:
            Y: (B, T, d_model) or (T, d_model) matching input rank.
        """
        squeeze_batch = False
        if X.dim() == 2:
            X = X.unsqueeze(0)
            squeeze_batch = True
        B, T, D = X.shape
        Hh = self.num_heads
        Dh = self.head_dim
        device = X.device

        # 1) Projections -> (B, T, Hh, Dh)
        Q  = self._proj(X, self.W_Q)
        QK = self._proj(X, self.W_QK)
        KK = self._proj(X, self.W_KK)
        VK = self._proj(X, self.W_VK)
        QV = self._proj(X, self.W_QV)
        KV = self._proj(X, self.W_KV)
        VV = self._proj(X, self.W_VV)

        # 2) Prefix accumulators for all t (vectorized over time)
        #    S_t^K = sum_{i<=t} KK_i ⊗ VK_i  ; Z_t^K = sum_{i<=t} KK_i
        #    S_t^V = sum_{i<=t} KV_i ⊗ VV_i  ; Z_t^V = sum_{i<=t} KV_i
        # Shapes:
        #   KK, VK: (B, T, Hh, Dh)
        #   outer_K: (B, T, Hh, Dh, Dh)
        outer_K = torch.einsum('bthd,bthe->bthde', KK, VK)
        S_pref_K = outer_K.cumsum(dim=1)             # (B, T, Hh, Dh, Dh)
        Z_pref_K = KK.cumsum(dim=1)                  # (B, T, Hh, Dh)

        outer_V = torch.einsum('bthd,bthe->bthde', KV, VV)
        S_pref_V = outer_V.cumsum(dim=1)             # (B, T, Hh, Dh, Dh)
        Z_pref_V = KV.cumsum(dim=1)                  # (B, T, Hh, Dh)

        # 3) Build all effective K_(t) and V_(t) for all prefixes in parallel
        # Indices:
        #   t: target time, i: prefix index (i <= t)
        # K_num[b,t,i,h,:] = QK[b,i,h,:] @ S_pref_K[b,t,h,:,:]
        K_num_all = torch.einsum('bihd,bthde->btihe', QK, S_pref_K)  # (B, T, T, Hh, Dh)
        K_den_all = torch.einsum('bihd,bthd->btih',  QK, Z_pref_K)   # (B, T, T, Hh)
        K_all = K_num_all / (K_den_all[..., None] + self.eps)        # (B, T, T, Hh, Dh)

        V_num_all = torch.einsum('bihd,bthde->btihe', QV, S_pref_V)  # (B, T, T, Hh, Dh)
        V_den_all = torch.einsum('bihd,bthd->btih',  QV, Z_pref_V)   # (B, T, T, Hh)
        V_all = V_num_all / (V_den_all[..., None] + self.eps)        # (B, T, T, Hh, Dh)

        # 4) Attention logits and outputs
        # logits[b,t,i,h] = Q[b,t,h,:] · K_all[b,t,i,h,:]
        logits = torch.einsum('bthd,btihd->btih', Q, K_all)          # (B, T, T, Hh)

        # Causal mask (i <= t)
        tri = torch.tril(torch.ones(T, T, device=device, dtype=torch.bool))
        logits = logits.masked_fill(~tri.view(1, T, T, 1), float('-inf'))

        A = F.softmax(logits, dim=2)                                 # (B, T, T, Hh)

        # y_head[b,t,h,:] = sum_i A[b,t,i,h] * V_all[b,t,i,h,:]
        Y_heads = torch.einsum('btih,btihd->bthd', A, V_all)         # (B, T, Hh, Dh)

        # Concat heads -> (B, T, d_model)
        Y = Y_heads.reshape(B, T, Hh * Dh)

        return Y.squeeze(0) if squeeze_batch else Y


In [5]:
# Single-head reference from earlier (returns Y only)
def lin_kv_block_vectorized_singlehead(X, W_Q, W_QK, W_KK, W_VK, W_QV, W_KV, W_VV, eps=1e-6):
    n, d = X.shape
    Q  = X @ W_Q
    QK = X @ W_QK
    KK = X @ W_KK
    VK = X @ W_VK
    QV = X @ W_QV
    KV = X @ W_KV
    VV = X @ W_VV

    outer_K = torch.einsum("nd,ne->nde", KK, VK)
    S_pref_K = outer_K.cumsum(dim=0)
    Z_pref_K = KK.cumsum(dim=0)

    outer_V = torch.einsum("nd,ne->nde", KV, VV)
    S_pref_V = outer_V.cumsum(dim=0)
    Z_pref_V = KV.cumsum(dim=0)

    K_num_all = torch.einsum("id,tdk->tik", QK, S_pref_K)
    K_den_all = torch.einsum("id,td->ti",  QK, Z_pref_K)
    K_all     = K_num_all / (K_den_all[..., None] + eps)

    V_num_all = torch.einsum("id,tdk->tik", QV, S_pref_V)
    V_den_all = torch.einsum("id,td->ti",  QV, Z_pref_V)
    V_all     = V_num_all / (V_den_all[..., None] + eps)

    tri = torch.tril(torch.ones(n, n, dtype=torch.bool, device=X.device))
    logits = torch.einsum("td,tid->ti", Q, K_all).masked_fill(~tri, float("-inf"))
    A = torch.softmax(logits, dim=1)
    Y = torch.einsum("ti,tid->td", A, V_all)
    return Y

# --- test ---
torch.manual_seed(0)
B, T, d = 2, 6, 8
X = torch.randn(B, T, d)

# Build multi-head module with 1 head of size d
m = LinKVBlockMultiHead(d_model=d, num_heads=1, head_dim=d)
with torch.no_grad():
    # Copy its weights out to feed the single-head function
    W_Q, W_QK, W_KK, W_VK, W_QV, W_KV, W_VV = (
        m.W_Q, m.W_QK, m.W_KK, m.W_VK, m.W_QV, m.W_KV, m.W_VV
    )

# Compare per batch item
out_m = m(X)                           # (B, T, d)
max_diffs = []
for b in range(B):
    y_ref = lin_kv_block_vectorized_singlehead(
        X[b], W_Q, W_QK, W_KK, W_VK, W_QV, W_KV, W_VV
    )
    max_diffs.append((out_m[b] - y_ref).abs().max().item())

print("max per-batch |diff|:", max_diffs)
assert all(md < 1e-5 for md in max_diffs)
print("✔ Multi-head (H=1) matches single-head reference.")


max per-batch |diff|: [2.384185791015625e-07, 1.1920928955078125e-07]
✔ Multi-head (H=1) matches single-head reference.


In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinKVBlockMultiHeadLinearMem(nn.Module):
    """
    Multi-head Lin-KV with O(B * H * T * d_h) memory using a FlashAttention-style
    streaming softmax over prefix tiles (still O(T^2) time).

    Args:
        d_model: embedding dimension (must equal num_heads * head_dim)
        num_heads: number of heads
        head_dim: per-head dim
        tile_size: prefix tile size
        eps: small constant for denominators
        accum_dtype: dtype used for softmax accumulators (default float32; set to float64 for tests)
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        head_dim: int,
        tile_size: int = 64,
        eps: float = 1e-6,
        accum_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        assert d_model == num_heads * head_dim, "d_model must equal num_heads * head_dim."
        self.d_model, self.num_heads, self.head_dim = d_model, num_heads, head_dim
        self.tile_size, self.eps, self.accum_dtype = tile_size, eps, accum_dtype

        Hflat = num_heads * head_dim
        self.W_Q  = nn.Parameter(torch.empty(d_model, Hflat))
        self.W_QK = nn.Parameter(torch.empty(d_model, Hflat))
        self.W_KK = nn.Parameter(torch.empty(d_model, Hflat))
        self.W_VK = nn.Parameter(torch.empty(d_model, Hflat))
        self.W_QV = nn.Parameter(torch.empty(d_model, Hflat))
        self.W_KV = nn.Parameter(torch.empty(d_model, Hflat))
        self.W_VV = nn.Parameter(torch.empty(d_model, Hflat))
        self.reset_parameters()

    def reset_parameters(self):
        for p in [self.W_Q, self.W_QK, self.W_KK, self.W_VK, self.W_QV, self.W_KV, self.W_VV]:
            nn.init.xavier_uniform_(p)

    def _proj(self, X, W):
        # X: (B,T,D) @ (D,Hflat) -> (B,T,H,dh)
        B, T, _ = X.shape
        out = X @ W
        return out.view(B, T, self.num_heads, self.head_dim)

    @torch.no_grad()
    def _promote_dtype(self, X: torch.Tensor) -> torch.Tensor:
        # compute in at least fp32 for stability
        return X.float() if X.dtype in (torch.float16, torch.bfloat16) else X

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        squeeze = False
        if X.dim() == 2:
            X = X.unsqueeze(0)
            squeeze = True

        X = self._promote_dtype(X)
        B, T, D = X.shape
        H, Dh = self.num_heads, self.head_dim
        dev, dtype = X.device, X.dtype
        ts = self.tile_size

        # 1) projections
        Q  = self._proj(X, self.W_Q)
        QK = self._proj(X, self.W_QK)
        KK = self._proj(X, self.W_KK)
        VK = self._proj(X, self.W_VK)
        QV = self._proj(X, self.W_QV)
        KV = self._proj(X, self.W_KV)
        VV = self._proj(X, self.W_VV)

        # 2) prefix states
        S_K = torch.zeros(B, H, Dh, Dh, device=dev, dtype=dtype)
        Z_K = torch.zeros(B, H, Dh,    device=dev, dtype=dtype)
        S_V = torch.zeros(B, H, Dh, Dh, device=dev, dtype=dtype)
        Z_V = torch.zeros(B, H, Dh,    device=dev, dtype=dtype)

        Y = torch.empty(B, T, H, Dh, device=dev, dtype=dtype)

        # helpers for (B,M,H,Dh) x (B,H,Dh,Dh)
        def bmh_bmm(A, Bmat):
            A2 = A.permute(0, 2, 1, 3).contiguous().view(B * H, A.size(1), Dh)
            B2 = Bmat.view(B * H, Dh, Dh)
            out = torch.bmm(A2, B2)
            return out.view(B, H, A.size(1), Dh).permute(0, 2, 1, 3).contiguous()

        def bmh_dot(A, Bvec):
            return torch.einsum('bmhd,bhd->bmh', A, Bvec)

        neg_inf = torch.finfo(self.accum_dtype).min

        for t in range(T):
            # update prefix with step t
            KK_t, VK_t = KK[:, t], VK[:, t]     # (B,H,Dh)
            KV_t, VV_t = KV[:, t], VV[:, t]
            S_K = S_K + torch.einsum('bhd,bhe->bhde', KK_t, VK_t)
            Z_K = Z_K + KK_t
            S_V = S_V + torch.einsum('bhd,bhe->bhde', KV_t, VV_t)
            Z_V = Z_V + KV_t

            q_t = Q[:, t]  # (B,H,Dh)

            # online softmax accumulators (use accum_dtype for stability)
            m = torch.full((B, H, 1), neg_inf, device=dev, dtype=self.accum_dtype)
            l = torch.zeros((B, H, 1), device=dev, dtype=self.accum_dtype)
            y = torch.zeros((B, H, Dh), device=dev, dtype=self.accum_dtype)

            i_end = t + 1
            for i0 in range(0, i_end, ts):
                i1 = min(i0 + ts, i_end)
                M = i1 - i0

                QK_blk = QK[:, i0:i1]  # (B,M,H,Dh)
                QV_blk = QV[:, i0:i1]

                # effective K_(t)[i] and V_(t)[i]
                K_num = bmh_bmm(QK_blk, S_K)                       # (B,M,H,Dh)
                K_den = bmh_dot(QK_blk, Z_K)[..., None] + self.eps # (B,M,H,1)
                K_blk = K_num / K_den

                V_num = bmh_bmm(QV_blk, S_V)
                V_den = bmh_dot(QV_blk, Z_V)[..., None] + self.eps
                V_blk = V_num / V_den

                # logits in accum dtype
                logits_blk = torch.einsum('bhd,bmhd->bhm', q_t, K_blk).to(self.accum_dtype)

                # --- canonical merge: compute block with its own max ---
                m_blk = logits_blk.max(dim=2, keepdim=True).values                 # (B,H,1)
                # exp(logits - m_blk)
                w_blk = torch.exp(logits_blk - m_blk)                              # (B,H,M)
                # l_blk, y_blk under m_blk
                l_blk = w_blk.sum(dim=2, keepdim=True)                             # (B,H,1)
                y_blk = torch.einsum('bhm,bmhd->bhd', w_blk, V_blk.to(self.accum_dtype))

                # merge with running stats
                m_new = torch.maximum(m, m_blk)                                    # (B,H,1)
                alpha = torch.exp(m - m_new)                                       # (B,H,1)
                beta  = torch.exp(m_blk - m_new)                                   # (B,H,1)

                l = l * alpha + l_blk * beta                                       # (B,H,1)
                y = y * alpha + y_blk * beta                                       # (B,H,Dh)
                m = m_new

            Y[:, t] = (y / l).to(dtype)  # back to compute dtype

        return (Y.reshape(B, T, H * Dh)).squeeze(0) if squeeze else Y.reshape(B, T, H * Dh)


In [69]:
# Naive quadratic-memory version from earlier (unchanged)
# class LinKVBlockMultiHeadNaive(...): ...

# torch.manual_seed(0)
B, T, H, Dh = 2, 24, 3, 8
D = H * Dh
X = torch.randn(B, T, D)

naive = LinKVBlockMultiHeadNaive(d_model=D, num_heads=H, head_dim=Dh)
flash = LinKVBlockMultiHeadLinearMem(d_model=D, num_heads=H, head_dim=Dh,
                                     tile_size=7, accum_dtype=torch.float64)

# Share weights
with torch.no_grad():
    for n_param, f_param in [
        (naive.W_Q, flash.W_Q), (naive.W_QK, flash.W_QK),
        (naive.W_KK, flash.W_KK), (naive.W_VK, flash.W_VK),
        (naive.W_QV, flash.W_QV), (naive.W_KV, flash.W_KV),
        (naive.W_VV, flash.W_VV),
    ]:
        f_param.copy_(n_param)

Y_naive = naive(X)
Y_flash = flash(X)

diff = (Y_naive - Y_flash).abs().max().item()
print("max |naive - linear-mem|:", diff)
assert torch.allclose(Y_naive, Y_flash, atol=1e-6, rtol=1e-6)
print("✔ match within tolerance")


max |naive - linear-mem|: 1.6328125


AssertionError: 

# more testtt

In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinKVBlockMultiHeadFlash(nn.Module):
    """
    Multi-head Lin-KV block using FlashAttention-style streaming softmax.
    Complexity: O(B * H * T^2 * d_h) time, O(B * (T * H * d_h + H * d_h^2)) memory (linear in T).
    Args:
        d_model: embedding dimension (must equal num_heads * head_dim)
        num_heads: number of heads (H)
        head_dim: per-head dimension (d_h)
        block_size: prefix block size for streaming scan over i<=t
        eps: small constant for numerical stability
        accum_dtype: dtype to use for prefix accumulators (None => use input dtype)
    """
    def __init__(self, d_model: int, num_heads: int, head_dim: int,
                 block_size: int = 128, eps: float = 1e-6, accum_dtype=None):
        super().__init__()
        assert d_model == num_heads * head_dim, \
            "d_model must equal num_heads * head_dim for simple concat output."
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.eps = eps
        self.accum_dtype = accum_dtype

        D = d_model
        H = num_heads * head_dim  # flattened heads

        # 7 projection matrices (no bias, to match your pseudocode)
        self.W_Q  = nn.Parameter(torch.empty(D, H))
        self.W_QK = nn.Parameter(torch.empty(D, H))
        self.W_KK = nn.Parameter(torch.empty(D, H))
        self.W_VK = nn.Parameter(torch.empty(D, H))
        self.W_QV = nn.Parameter(torch.empty(D, H))
        self.W_KV = nn.Parameter(torch.empty(D, H))
        self.W_VV = nn.Parameter(torch.empty(D, H))

        self.reset_parameters()

    def reset_parameters(self):
        for p in [self.W_Q, self.W_QK, self.W_KK, self.W_VK,
                  self.W_QV, self.W_KV, self.W_VV]:
            nn.init.xavier_uniform_(p)

    @torch.no_grad()
    def load_from(self, other_module):
        """Convenience: copy weights from your naive implementation."""
        with torch.no_grad():
            for name in ["W_Q", "W_QK", "W_KK", "W_VK", "W_QV", "W_KV", "W_VV"]:
                getattr(self, name).copy_(getattr(other_module, name))
        return self

    def _proj(self, X, W):
        # X: (B, T, D) @ (D, H_flat) -> (B, T, H_flat) -> (B, T, Hh, Dh)
        B, T, _ = X.shape
        Hh, Dh = self.num_heads, self.head_dim
        out = X @ W                    # (B, T, H_flat)
        out = out.view(B, T, Hh, Dh)   # (B, T, Hh, Dh)
        return out

    def forward(self, X):
        """
        X: (B, T, d_model) or (T, d_model)
        Returns: (B, T, d_model) or (T, d_model) matching input rank.
        """
        squeeze_batch = False
        if X.dim() == 2:
            X = X.unsqueeze(0)
            squeeze_batch = True

        B, T, D = X.shape
        Hh, Dh = self.num_heads, self.head_dim
        device = X.device
        dtype = X.dtype
        adtype = self.accum_dtype or dtype  # accumulators dtype

        # ---- 1) Projections (linear memory in T) ----
        Q  = self._proj(X, self.W_Q)   # (B, T, Hh, Dh)
        QK = self._proj(X, self.W_QK)
        KK = self._proj(X, self.W_KK)
        VK = self._proj(X, self.W_VK)
        QV = self._proj(X, self.W_QV)
        KV = self._proj(X, self.W_KV)
        VV = self._proj(X, self.W_VV)

        # ---- 2) Prefix accumulators (maintained online) ----
        # S^K, S^V: (B, Hh, Dh, Dh); Z^K, Z^V: (B, Hh, Dh)
        S_K = torch.zeros(B, Hh, Dh, Dh, device=device, dtype=adtype)
        Z_K = torch.zeros(B, Hh, Dh,     device=device, dtype=adtype)
        S_V = torch.zeros(B, Hh, Dh, Dh, device=device, dtype=adtype)
        Z_V = torch.zeros(B, Hh, Dh,     device=device, dtype=adtype)

        # Output buffer
        Y_heads = torch.empty(B, T, Hh, Dh, device=device, dtype=dtype)

        eps = torch.as_tensor(self.eps, device=device, dtype=adtype)
        tiny = torch.finfo(adtype).tiny

        # ---- 3) Sweep over time t, stream over prefix i in blocks ----
        for t in range(T):
            # Update prefix states with token t (per head)
            KK_t = KK[:, t].to(adtype)   # (B, Hh, Dh)
            VK_t = VK[:, t].to(adtype)
            KV_t = KV[:, t].to(adtype)
            VV_t = VV[:, t].to(adtype)

            # S_t^K += KK_t^T @ VK_t ; S_t^V += KV_t^T @ VV_t
            S_K = S_K + torch.einsum('bhd,bhe->bhde', KK_t, VK_t)
            Z_K = Z_K + KK_t
            S_V = S_V + torch.einsum('bhd,bhe->bhde', KV_t, VV_t)
            Z_V = Z_V + KV_t

            # Current query (for logits) at time t
            Q_t = Q[:, t]   # (B, Hh, Dh)

            # Streaming softmax accumulators per (B,Hh)
            # m: running max logits; l: sum of exp; y: weighted sum of v_i^(t)
            m = torch.full((B, Hh), -float('inf'), device=device, dtype=adtype)
            l = torch.zeros((B, Hh), device=device, dtype=adtype)
            y = torch.zeros((B, Hh, Dh), device=device, dtype=adtype)

            # Iterate prefix blocks i in [0..t]
            for i0 in range(0, t + 1, self.block_size):
                i1 = min(i0 + self.block_size, t + 1)
                Bi = i1 - i0

                QK_blk = QK[:, i0:i1].to(adtype)  # (B, Bi, Hh, Dh)
                QV_blk = QV[:, i0:i1].to(adtype)  # (B, Bi, Hh, Dh)

                # Numerators and denominators (per head)
                # u_i = QK_i @ S_K ; alpha_i = QK_i · Z_K
                U_blk = torch.einsum('bihd,bhde->bihe', QK_blk, S_K)            # (B, Bi, Hh, Dh)
                alpha_blk = torch.einsum('bihd,bhd->bih', QK_blk, Z_K) + eps    # (B, Bi, Hh)

                # s_i = (Q_t · u_i) / alpha_i
                s_blk = torch.einsum('bihe,bhe->bih', U_blk, Q_t.to(adtype)) / alpha_blk  # (B,Bi,Hh)

                # v_i^(t) = (QV_i @ S_V) / (QV_i · Z_V)
                W_blk = torch.einsum('bihd,bhde->bihe', QV_blk, S_V)            # (B, Bi, Hh, Dh)
                beta_blk = torch.einsum('bihd,bhd->bih', QV_blk, Z_V) + eps     # (B, Bi, Hh)
                V_blk = W_blk / beta_blk[..., None]                             # (B, Bi, Hh, Dh)

                # Flash-style streaming softmax update across the block
                s_blk_max = s_blk.max(dim=1).values   # (B, Hh)
                m_next = torch.maximum(m, s_blk_max)
                exp_scale = torch.exp((m - m_next).clamp_min(torch.log(tiny)))  # prevent underflow

                # sum_i exp(s_i - m_next)
                exp_scores = torch.exp(s_blk - m_next.unsqueeze(1))
                l = exp_scale * l + exp_scores.sum(dim=1)

                # y = exp_scale * y + sum_i exp(s_i - m_next) * v_i
                y = exp_scale.unsqueeze(-1) * y + torch.einsum('bih,bihe->bhe', exp_scores, V_blk)

                m = m_next

            # Normalize to get output for time t
            Y_heads[:, t] = (y / (l.unsqueeze(-1) + tiny)).to(dtype)

        # Concat heads -> (B, T, d_model)
        Y = Y_heads.reshape(B, T, Hh * Dh)
        return Y.squeeze(0) if squeeze_batch else Y


In [71]:
# ---- Sanity check script ----
import torch

# Reuse your naive implementation
# (Assuming your LinKVBlockMultiHead class is already defined in scope)

def sanity_check(seed=0, B=2, T=32, Hh=4, Dh=16, device=None, dtype=torch.float32):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(seed)

    D = Hh * Dh
    X = torch.randn(B, T, D, device=device, dtype=dtype)

    naive = LinKVBlockMultiHead(d_model=D, num_heads=Hh, head_dim=Dh).to(device, dtype)
    flash = LinKVBlockMultiHeadFlash(d_model=D, num_heads=Hh, head_dim=Dh,
                                     block_size=8, eps=1e-6).to(device, dtype)
    flash.load_from(naive)  # ensure identical weights

    Y_naive = naive(X)
    Y_flash = flash(X)

    max_err = (Y_naive - Y_flash).abs().max().item()
    rel_err = ( (Y_naive - Y_flash).abs() / (Y_naive.abs().clamp_min(1e-5)) ).max().item()

    print(f"device={device} dtype={dtype} B={B} T={T} H={Hh} Dh={Dh}")
    print(f"max abs error: {max_err:.3e}, max rel error: {rel_err:.3e}")

    # Be reasonably strict; adjust tol if using fp16/bfloat16
    assert max_err < 2e-4 or rel_err < 2e-3, "Flash and naive outputs differ too much!"

# if __name__ == "__main__":
    # A couple of sizes
sanity_check(seed=0, B=1, T=16, Hh=2, Dh=8)
sanity_check(seed=1, B=2, T=32, Hh=4, Dh=16)
    # Try on GPU/bfloat16 or float16 if you'd like:
    # sanity_check(seed=2, B=2, T=64, Hh=8, Dh=16, dtype=torch.bfloat16)


TypeError: log(): argument 'input' (position 1) must be Tensor, not float