In [146]:
import torch
from torch import Tensor


def doc_ids_to_cu_seqlen(doc_ids: Tensor) -> Tensor:
    """
    Convert document IDs to cumulative sequence lengths for use in attention mechanisms.

    Args:
        doc_ids (Tensor): A tensor of shape (batch_size, seq_len) containing document IDs.

    Returns:
        Tensor: A 1D tensor of cumulative sequence lengths.
    """
    batch_size, seq_len = doc_ids.shape
    # Find where groups change
    is_new_group = torch.cat(
        [
            torch.ones_like(doc_ids[:, :1], dtype=torch.bool),
            doc_ids[:, 1:] != doc_ids[:, :-1],
        ],
        dim=1,
    )

    # Get the indices where new groups start
    group_start_indices = torch.where(is_new_group)[1]
    return torch.cat(
        [group_start_indices, torch.tensor([seq_len], device=doc_ids.device)]
    )

In [147]:
# -*- coding: utf-8 -*-

import torch


@torch.jit.script
def normalize_output(q: torch.Tensor, k: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
    k = k.cumsum(1)
    z = (q * k).sum(-1, keepdim=True)
    return o / (z + 1e-10)

In [148]:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import torch

from fla.ops.simple_gla import chunk_simple_gla


@torch.compiler.disable
def chunk_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    normalize: bool = True,
    head_first: bool = False,
    cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        scale (Optional[float]):
            Scale factor for the linear attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
        normalize (bool):
            Whether to normalize the output. Default: `True`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
            This argument has been deprecated.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
    """

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
        )
    if not head_first:
        if q.shape[1] < q.shape[2]:
            raise DeprecationWarning(
                f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
                "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
                "when head_first=False was specified. "
                "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = chunk_simple_gla(
        q=q,
        k=k,
        v=v,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    if normalize:
        o = normalize_output(q * scale, k, o)
    return o, final_state

In [None]:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
"""

from __future__ import annotations

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules.feature_map import RebasedFeatureMap
from fla.ops.linear_attn import fused_chunk_linear_attn
from fla.ops.rebased import parallel_rebased


class ReBasedLinearAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        l_max: int = 2048,
        feature_dim: int = 16,
        num_key_value_heads: int = 16,
        num_heads: int = 16,
        use_gamma: bool | None = True,
        use_beta: bool | None = True,
        normalize: bool | None = True,
        causal: bool = True,
        eps: float = 1e-5,
        mode: str = "parallel",
        layer_idx: int | None = None,
        **kwargs,
    ) -> ReBasedLinearAttention:
        super().__init__()
        self.hidden_size = hidden_size
        self.l_max = l_max
        self.mode = mode
        assert self.mode in ["fused_chunk", "parallel", 'chunk']

        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.use_gamma = use_gamma
        self.use_beta = use_beta
        self.normalize = normalize
        self.causal = causal
        self.eps = eps
        self.mode = mode
        self.layer_idx = layer_idx

        self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        mode = self.mode
        q = rearrange(
            self.q_proj(hidden_states),
            "... (h d) -> ... h d",
            h=self.num_heads,
            d=self.feature_dim,
        )
        k = rearrange(
            self.k_proj(hidden_states),
            "... (h d) -> ... h d",
            h=self.num_heads,
            d=self.feature_dim,
        )
        v = rearrange(
            self.v_proj(hidden_states),
            "... (h d) -> ... h d",
            h=self.num_key_value_heads,
            d=self.head_dim,
        )
        cu_seqlens = kwargs.get("cu_seqlens", None)
        q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
        if mode == "fused_chunk":
            o,_ = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=True,
                scale=1,
                cu_seqlens=cu_seqlens,
            )
        elif mode == 'chunk':
            o,_ = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=True,
                scale=1,
                cu_seqlens=cu_seqlens,
            )
        elif mode == 'parallel':
            assert q.shape[-1] <= 128
            o = parallel_rebased(
                q=q,
                k=k,
                v=v,
                eps=self.eps,
                use_scale=True,
                use_normalize=True,
            )
        o = rearrange(o, "... h d -> ... (h d)")
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

In [150]:
s1 = 16
s2 = 18
b,s,d = 1,s1+s2,512
num_heads = 8
mode = "chunk"

In [None]:
rebased = ReBasedLinearAttention(
    hidden_size=d,
    num_heads=num_heads,
    mode=mode,
    num_key_value_heads=num_heads,
)

In [152]:
x = torch.randn(b,s,d)
doc_ids = []
index = 0
for seq_len in [s1, s2]:
    doc_ids.extend([index] * seq_len)
    index += 1
doc_ids = torch.tensor([doc_ids])
cu_seqlens = doc_ids_to_cu_seqlen(doc_ids)

In [153]:
x[:, :s1, :] = -1e9

In [154]:
rebased = rebased.cuda()
x = x.cuda()
cu_seqlens = cu_seqlens.cuda()

In [155]:
cu_seqlens

tensor([ 0, 16, 34], device='cuda:0')

In [156]:
out = rebased(x, cu_seqlens=cu_seqlens)

  return fn(*args, **kwargs)


In [157]:
out.shape

torch.Size([1, 34, 512])

In [158]:
out.mean().backward()

In [159]:
first = x[:, :s1, :]
second = x[:, s1:, :]

In [160]:
first_out = rebased(first)
second_out = rebased(second)

In [161]:
first_out, second_out

(tensor([[[ 1.6224e+08,  3.7914e+08,  4.8357e+08,  ...,  1.5805e+07,
            2.0560e+08, -5.6705e+08],
          [ 1.6224e+08,  3.7914e+08,  4.8357e+08,  ...,  1.5805e+07,
            2.0560e+08, -5.6705e+08],
          [ 1.6223e+08,  3.7915e+08,  4.8359e+08,  ...,  1.5810e+07,
            2.0562e+08, -5.6707e+08],
          ...,
          [ 1.6220e+08,  3.7928e+08,  4.8374e+08,  ...,  1.5856e+07,
            2.0574e+08, -5.6723e+08],
          [ 1.6219e+08,  3.7928e+08,  4.8374e+08,  ...,  1.5856e+07,
            2.0574e+08, -5.6723e+08],
          [ 1.6222e+08,  3.7921e+08,  4.8366e+08,  ...,  1.5830e+07,
            2.0567e+08, -5.6714e+08]]], device='cuda:0',
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[-0.1034, -0.5111, -0.2731,  ...,  0.0754,  0.0731,  0.2537],
          [ 0.0235,  0.1614, -0.2155,  ..., -0.0512, -0.1835,  0.3704],
          [ 0.0539, -0.0804, -0.2118,  ...,  0.0968, -0.0406,  0.4769],
          ...,
          [-0.0614,  0.1224, -0.2283,  ...,  0.2209, 

In [162]:
out[:, :s1, :], out[:, s1:, :]

(tensor([[[ 1.6224e+08,  3.7914e+08,  4.8357e+08,  ...,  1.5805e+07,
            2.0560e+08, -5.6705e+08],
          [ 1.6224e+08,  3.7914e+08,  4.8357e+08,  ...,  1.5805e+07,
            2.0560e+08, -5.6705e+08],
          [ 1.6223e+08,  3.7915e+08,  4.8359e+08,  ...,  1.5810e+07,
            2.0562e+08, -5.6707e+08],
          ...,
          [ 1.6220e+08,  3.7928e+08,  4.8374e+08,  ...,  1.5856e+07,
            2.0574e+08, -5.6723e+08],
          [ 1.6219e+08,  3.7928e+08,  4.8374e+08,  ...,  1.5856e+07,
            2.0574e+08, -5.6723e+08],
          [ 1.6222e+08,  3.7921e+08,  4.8366e+08,  ...,  1.5830e+07,
            2.0567e+08, -5.6714e+08]]], device='cuda:0',
        grad_fn=<SliceBackward0>),
 tensor([[[-0.2436, -0.3086, -0.0769,  ...,  0.0433,  0.1269,  0.2462],
          [-0.1647,  0.0130, -0.0547,  ..., -0.1089, -0.0193,  0.1714],
          [-0.0200,  0.0212, -0.0332,  ..., -0.0176,  0.0199,  0.1294],
          ...,
          [-0.0648,  0.0329, -0.0858,  ...,  0.0747, -0.02

In [163]:
torch.allclose(first_out, out[:, :s1, :])

True

In [164]:
torch.allclose(second_out, out[:, s1:, :])

False

In [165]:
# doesn't match but it doesn't carry information over either apparently
# as setting the first half to -1e9 seems to have no effect on the second half output. Maybe this is fine???