In [1]:
import torch
import torch.nn.functional as F

from ref.modeling import (
    AttnQKVPackFormat,
    AttnQKVLayout,
    OfflineSlidingWindowAttn, 
    OnlineSlidingWindowAttn,
)

In [3]:
# Three ways of computing softplus(x) = log(1 + exp(x)) on the same input.
a = torch.randn((100,))
b = (1 + a.exp()).log()  # naive formula
c = a.exp().log1p()      # log1p variant
d = F.softplus(a)        # built-in reference

# Display the conjunction of both comparisons. Previously the two calls were
# separated by `;`, so the first result was thrown away and the displayed
# value only reflected `b` vs `d`.
torch.allclose(b, c, equal_nan=True) and torch.allclose(b, d, equal_nan=True)

True

In [9]:
# rsqrt on the CPU versus the accelerator. Negative samples produce NaN on
# both sides; the comparison cell below passes equal_nan=True for that reason.
a = torch.randn((100, 10000), dtype=torch.float32)
b = torch.rsqrt(a)
# Use CUDA when it exists and fall back to the CPU otherwise, so the notebook
# still runs top to bottom on machines without a GPU.
c = torch.rsqrt(a.to("cuda" if torch.cuda.is_available() else "cpu"))

In [10]:
# Move `b` to wherever `c` lives instead of hard-coding CUDA, so the check
# does not fail on CPU-only machines.
torch.allclose(b.to(c.device), c, equal_nan=True)

True

In [2]:
def construct_offline_attn_args(
    b: int,
    sq: int,
    skv: int,
    hq: int,
    hkv: int,
    hd: int,
    qkv_pack_format: AttnQKVPackFormat,
    qkv_layout: AttnQKVLayout,
    seqlens_q = None,
    seqlens_kv = None,
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
    seed: int = 42,
):
    """Create seeded random q/k/v inputs in the requested layout and packing.

    q is drawn with shape (b, sq, hq, hd) and k/v with shape (b, skv, hkv, hd);
    they are then rearranged according to `qkv_layout` and concatenated
    according to `qkv_pack_format`.

    Args:
        b, sq, skv, hq, hkv, hd: sizes used for the shapes described above.
        qkv_pack_format: QKV packs q, k and v into one tensor along dim -2,
            Q_KV packs only k and v; any other value leaves them separate.
        qkv_layout: THD flattens the two leading dims into one, SBHD swaps
            the two leading dims, any other value keeps (b, s, h, d).
        seqlens_q, seqlens_kv: per-sequence lengths, required for THD and
            rejected for every other layout.
        dtype, device: passed to the tensor factories.
        seed: passed to `torch.manual_seed`, i.e. the global RNG is reseeded.

    Returns:
        Tuple `(q, k, v, cu_seqlens_q, cu_seqlens_kv)`. `k` and/or `v` are None
        when packed into another tensor; both `cu_seqlens_*` are None unless
        the layout is THD, in which case they are int32 prefix sums that start
        at 0.
    """
    torch.manual_seed(seed)
    q = torch.randn((b, sq, hq, hd), dtype=dtype, device=device)
    k = torch.randn((b, skv, hkv, hd), dtype=dtype, device=device)
    v = torch.randn((b, skv, hkv, hd), dtype=dtype, device=device)

    if qkv_layout == AttnQKVLayout.THD:
        assert seqlens_q is not None, "THD layout requires seqlens_q"
        assert seqlens_kv is not None, "THD layout requires seqlens_kv"

        # cumsum promotes integer inputs to int64 unless a dtype is given,
        # and concat would then promote the whole result; pin int32 explicitly.
        cu_seqlens_q, cu_seqlens_kv = [
            torch.concat([
                torch.zeros(1, dtype=torch.int32, device=device),
                torch.tensor(x, dtype=torch.int32, device=device).cumsum(
                    dim=0, dtype=torch.int32
                ),
            ], dim=0)
            for x in (seqlens_q, seqlens_kv)
        ]

        # (b, s, h, d) -> (b*s, h, d)
        q, k, v = [
            x.view(-1, *x.shape[-2:]).contiguous()
            for x in (q, k, v)
        ]
    else:
        assert seqlens_q is None, "only the THD layout accepts seqlens_q"
        assert seqlens_kv is None, "only the THD layout accepts seqlens_kv"
        cu_seqlens_q, cu_seqlens_kv = None, None

        if qkv_layout == AttnQKVLayout.SBHD:
            # (b, s, h, d) -> (s, b, h, d)
            q, k, v = [
                x.transpose(0, 1).contiguous()
                for x in (q, k, v)
            ]

    if qkv_pack_format == AttnQKVPackFormat.QKV:
        # Concatenating along dim -2 needs every other dim to match.
        assert sq == skv, "QKV pack format requires sq == skv"
        q = torch.concat((q, k, v), dim=-2)
        k, v = None, None
    elif qkv_pack_format == AttnQKVPackFormat.Q_KV:
        k = torch.concat((k, v), dim=-2)
        v = None

    return q, k, v, cu_seqlens_q, cu_seqlens_kv

In [3]:
def construct_online_attn_args(
    b: int,
    sq: int,
    skv: int,
    hq: int,
    hkv: int,
    hd: int,
    bq: int,
    bkv: int,
    bqi: int,
    bkvi: int,
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
    seed: int = 42,
):
    """Create seeded random inputs for a single (q-block, kv-block) step.

    Full-length q (b, sq, hq, hd) and k/v (b, skv, hkv, hd) are drawn first,
    followed by a random `global_o` shaped like q and a random float32
    `global_lse` of shape (b, hq, sq). q/k/v are then zero-padded along dim 1
    to a multiple of their block size and only block `bqi` / `bkvi` is kept.

    Returns:
        Tuple `(q_block, k_block, v_block, global_o, global_lse)`.
    """
    torch.manual_seed(seed)
    factory = dict(dtype=dtype, device=device)

    # The draw order below fixes the values obtained for a given seed.
    kv_shape = (b, skv, hkv, hd)
    q = torch.randn((b, sq, hq, hd), **factory)
    k = torch.randn(kv_shape, **factory)
    v = torch.randn(kv_shape, **factory)
    global_o = torch.randn_like(q)
    global_lse = torch.rand((b, hq, sq), dtype=torch.float32, device=device)

    def _take_block(x, block_size, block_idx):
        # Pad dim 1 up to ceil(len / block_size) blocks, then cut one block out.
        seqlen = x.shape[1]
        num_blocks = (seqlen + block_size - 1) // block_size
        padded = F.pad(
            x,
            pad=(0, 0, 0, 0, 0, num_blocks * block_size - seqlen),
            mode="constant",
            value=0,
        )
        start = block_idx * block_size
        return padded[:, start:start + block_size, :, :]

    q = _take_block(q, bq, bqi)
    k = _take_block(k, bkv, bkvi)
    v = _take_block(v, bkv, bkvi)

    return q, k, v, global_o, global_lse

In [4]:
# ## task1 - toy case1

# b = 1
# sq, skv = 6, 6
# hq, hkv = 1, 1
# hd = 4

# window_size = None
# causal = True

# softmax_dropout_rate = 0.1
# softmax_scale = None
# softmax_cap = None
# softmax_temp = 0.8
# softmax_clip_range = (-0.03, 1.03)

# qkv_pack_format = AttnQKVPackFormat.QKV
# qkv_layout = AttnQKVLayout.SBHD

# seqlens_q = None
# seqlens_kv = None

# group_size = 1
# init_range = (-1.1, 1.1)

# act_dtype=torch.bfloat16

In [5]:
# ### task1 - toy case2

# b = 1
# sq, skv = 7, 5
# hq, hkv = 2, 1
# hd = 4

# window_size = 1
# causal = False

# softmax_dropout_rate = 0.0
# softmax_scale = None
# softmax_cap = 10
# softmax_temp = 1.0
# softmax_clip_range = (-0.01, 1.01)

# qkv_pack_format = AttnQKVPackFormat.Q_KV
# qkv_layout = AttnQKVLayout.THD

# seqlens_q = [1, 2, 4]
# seqlens_kv = [2, 2, 1]

# group_size = 2
# init_range = (-1.2, 1.2)

# act_dtype=torch.float32

In [6]:
# ### task2 - toy case1

# b = 1
# sq, skv = 7, 5
# hq, hkv = 1, 1
# hd = 4

# bq, bkv = 3, 2
# bqi_, bkvi_ = 1, 1

# window_size = 2
# causal = True

# softmax_scale = None
# softmax_dropout_rate = 0.0
# softmax_cap = 10
# softmax_temp = 1.0
# softmax_clip_range = (0., 1.)

# qkv_pack_format = AttnQKVPackFormat.Q_K_V
# qkv_layout = AttnQKVLayout.BSHD

# seqlens_q = None
# seqlens_kv = None

# group_size = 2
# init_range = (-1.05, 1.05)

# act_dtype=torch.float32

In [7]:
### task2 - toy case2
# Active configuration: every cell below reads these module-level names.

# Tensor sizes: batch, q / kv sequence lengths, q / kv head counts, head dim.
b = 1
sq, skv = 7, 5
hq, hkv = 1, 1
hd = 4

# Block sizes for q / kv, and the single block pair used by the
# random-input online call further down.
bq, bkv = 3, 2
bqi_, bkvi_ = 2, 0

# Masking options forwarded to both attention modules.
window_size = 1
causal = False

# Softmax options; dropout rate and clip range are only passed to the
# offline module.
softmax_scale = None
softmax_dropout_rate = 0.0
softmax_cap = None
softmax_temp = 0.9
softmax_clip_range = (0., 1.)

qkv_pack_format = AttnQKVPackFormat.Q_K_V
qkv_layout = AttnQKVLayout.BSHD

# Per-sequence lengths; must stay None unless the layout is THD.
seqlens_q = None
seqlens_kv = None

# Forwarded to both modules unchanged; their meaning is defined in ref.modeling.
group_size = 1
init_range = (-1.25, 1.25)

# dtype of the generated q / k / v tensors.
act_dtype=torch.float32

In [8]:
# Full-length inputs for the offline module; cu_seqlens_* are None unless the
# layout is THD.
q, k, v, cu_seqlens_q, cu_seqlens_kv = \
    construct_offline_attn_args(
        b, sq, skv, hq, hkv, hd, 
        qkv_pack_format, qkv_layout,
        seqlens_q, seqlens_kv,
        dtype=act_dtype,
    )

In [9]:
# Keyword arguments shared by the offline and the online attention module,
# so that both are constructed with identical settings.
common_kwargs = {
    "head_dim": hd,
    "num_q_head": hq,
    "num_kv_head": hkv,
    "window_size": window_size,
    "causal": causal,
    "softmax_scale": softmax_scale,
    "softmax_cap": softmax_cap,
    "softmax_temp": softmax_temp,
    "group_size": group_size,
    "eps": 1e-5,
    "init_range": init_range,
    "init_seed": 42,
    # NOTE(review): fixed to float32 / cpu independently of act_dtype - confirm intended.
    "dtype": torch.float32,
    "device": "cpu",
}

In [10]:
# Offline module. Pack format, layout, dropout and clip range are passed only
# here; everything else comes from common_kwargs.
off_swa = OfflineSlidingWindowAttn(
    qkv_pack_format=qkv_pack_format,
    qkv_layout=qkv_layout,
    softmax_dropout_rate=softmax_dropout_rate,
    softmax_dropout_seed=42,
    softmax_clip_range=softmax_clip_range,
    **common_kwargs,
)

In [11]:
# Offline output; used further down as the reference for the block-wise result.
o = off_swa(q, k, v, cu_seqlens_q, cu_seqlens_kv)
o, o.shape

(tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000]],
 
          [[ 0.0109, -0.3387, -1.3407, -0.5854]],
 
          [[ 0.1202, -0.1787, -0.9227, -0.5619]],
 
          [[ 0.0766, -0.1375, -0.8620, -0.5507]],
 
          [[-0.1484,  0.8362,  0.2361, -0.3500]],
 
          [[-0.3544,  0.7886,  0.1344, -0.2238]],
 
          [[-0.1641,  0.8789,  0.1191,  0.0121]]]],
        grad_fn=<TransposeBackward0>),
 torch.Size([1, 7, 1, 4]))

In [12]:
# Online (block-wise) module: takes block sizes and full sequence lengths on
# top of the shared settings.
on_swa = OnlineSlidingWindowAttn(
    block_size_q=bq,
    block_size_kv=bkv,
    seqlen_q=sq,
    seqlen_kv=skv,
    **common_kwargs,
)

In [13]:
on_swa.global_attn_mask, on_swa.global_attn_mask.shape

(tensor([[[[-inf, -inf, -inf, -inf, -inf, -inf],
           [0., -inf, -inf, -inf, -inf, -inf],
           [0., 0., -inf, -inf, -inf, -inf],
           [0., 0., 0., -inf, -inf, -inf],
           [-inf, 0., 0., 0., -inf, -inf],
           [-inf, -inf, 0., 0., 0., -inf],
           [-inf, -inf, -inf, 0., 0., -inf],
           [-inf, -inf, -inf, -inf, -inf, -inf],
           [-inf, -inf, -inf, -inf, -inf, -inf]]]]),
 torch.Size([1, 1, 9, 6]))

In [14]:
# Crop the block-padded mask back to the real sequence lengths; uses the
# configured sq / skv instead of literal 7 / 5 so it follows the config cell.
on_swa.global_attn_mask[..., :sq, :skv]

tensor([[[[-inf, -inf, -inf, -inf, -inf],
          [0., -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf],
          [-inf, 0., 0., 0., -inf],
          [-inf, -inf, 0., 0., 0.],
          [-inf, -inf, -inf, 0., 0.]]]])

In [15]:
# Number of blocks along each sequence axis (ceil division).
nbq = (sq + bq - 1) // bq
nbk = (skv + bkv - 1) // bkv

# Zero-pad dim 1 (the sequence axis for the BSHD layout this cell assumes) up
# to a whole number of blocks. The pad amount is computed from the tensor's
# current length, so running this cell again does not pad a second time.
q = F.pad(q, pad=(0, 0, 0, 0, 0, nbq*bq - q.shape[1]), mode="constant", value=0)
k = F.pad(k, pad=(0, 0, 0, 0, 0, nbk*bkv - k.shape[1]), mode="constant", value=0)
v = F.pad(v, pad=(0, 0, 0, 0, 0, nbk*bkv - v.shape[1]), mode="constant", value=0)

# Accumulators for the block-wise pass: zero output and -inf log-sum-exp.
o_ = torch.zeros_like(o)
lse_ = torch.full((b, hq, sq), float("-inf"), dtype=o_.dtype, device=o_.device)

q, k, v, o_, lse_

(tensor([[[[ 1.9269,  1.4873,  0.9007, -2.1055]],
 
          [[ 0.6784, -1.2345, -0.0431, -1.6047]],
 
          [[-0.7521,  1.6487, -0.3925, -1.4036]],
 
          [[-1.1109,  0.0915, -2.3169, -0.2168]],
 
          [[-1.3847, -0.8712, -0.2234,  1.7174]],
 
          [[-0.5920, -0.0631, -0.8286,  0.3309]],
 
          [[-1.5576,  0.9956, -0.8798, -0.6011]],
 
          [[ 0.0000,  0.0000,  0.0000,  0.0000]],
 
          [[ 0.0000,  0.0000,  0.0000,  0.0000]]]]),
 tensor([[[[ 1.3123,  0.6872, -1.0892, -0.3553]],
 
          [[ 1.4451,  0.8564,  2.2181,  0.5232]],
 
          [[ 0.3466, -0.1973, -1.0546,  1.2780]],
 
          [[-0.1722,  0.5238,  0.0566,  0.4263]],
 
          [[ 0.5750, -0.6417, -2.2064, -0.7508]],
 
          [[ 0.0000,  0.0000,  0.0000,  0.0000]]]]),
 tensor([[[[ 1.0868e-02, -3.3874e-01, -1.3407e+00, -5.8537e-01]],
 
          [[ 6.4076e-01,  5.8325e-01,  1.0669e+00, -4.5015e-01]],
 
          [[-6.7875e-01,  5.7432e-01,  1.8775e-01, -3.5762e-01]],
 
          [[ 2

In [None]:
# Feed every (q-block, kv-block) pair through the online module. `o_` and
# `lse_` are passed in and printed afterwards without being reassigned; the
# output shows them changing between steps.
for bqi in range(nbq):
    q_rows = slice(bqi * bq, (bqi + 1) * bq)
    for bkvi in range(nbk):
        kv_rows = slice(bkvi * bkv, (bkvi + 1) * bkv)
        print(f"block idx: q: {bqi} | kv: {bkvi}")
        q_, k_, v_ = q[:, q_rows], k[:, kv_rows], v[:, kv_rows]
        on_swa(
            q=q_,
            k=k_,
            v=v_,
            global_o=o_,
            global_lse=lse_,
            block_idx_q=bqi,
            block_idx_kv=bkvi,
        )
        print(f"all global o: {o_} | all global lse: {lse_}")

block idx: q: 0 | kv: 0
all global o: tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0109, -0.3387, -1.3407, -0.5854]],

         [[ 0.1202, -0.1787, -0.9227, -0.5619]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]]]], grad_fn=<CopySlices>) | all global lse: tensor([[[  -inf, 0.6894, 1.0614,   -inf,   -inf,   -inf,   -inf]]],
       grad_fn=<CopySlices>)
block idx: q: 0 | kv: 1
all global o: tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0109, -0.3387, -1.3407, -0.5854]],

         [[ 0.1202, -0.1787, -0.9227, -0.5619]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000]]]], grad_fn=<CopySlices>) | all global lse: tensor([[[  -inf, 0.6894, 1.0614,   -inf,   

In [17]:
print(f"all global o: {o_}\n\nall global lse: {lse_}")

all global o: tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.0109, -0.3387, -1.3407, -0.5854]],

         [[ 0.1202, -0.1787, -0.9227, -0.5619]],

         [[ 0.0766, -0.1375, -0.8620, -0.5507]],

         [[-0.1484,  0.8362,  0.2361, -0.3500]],

         [[-0.3544,  0.7886,  0.1344, -0.2238]],

         [[-0.1641,  0.8789,  0.1191,  0.0121]]]], grad_fn=<CopySlices>)

all global lse: tensor([[[  -inf, 0.6894, 1.0614, 1.1172, 1.5804, 1.5584, 0.7451]]],
       grad_fn=<CopySlices>)


In [None]:
from torch.testing import assert_close

# assert_close takes (actual, expected) and scales the relative tolerance by
# `expected`: the block-wise result `o_` is what is being checked, the offline
# output `o` is the reference.
assert_close(actual=o_, expected=o)

In [19]:
# Single-step check: one (bqi_, bkvi_) block of q / k / v together with
# randomly initialised global_o and global_lse.
# Note that this rebinds q, k and v to single blocks.
q, k, v, global_o, global_lse = construct_online_attn_args(
    b, sq, skv, hq, hkv, hd, 
    bq, bkv, bqi_, bkvi_,
    dtype=act_dtype,
)

In [20]:
# One online step for block (bqi_, bkvi_). The return value is not used; the
# next cell prints global_o / global_lse to inspect the result.
on_swa(
    q, 
    k,
    v,
    global_o,
    global_lse,
    block_idx_q=bqi_,
    block_idx_kv=bkvi_,
)

In [21]:
print(f"all global o: {global_o}\n\nall global lse: {global_lse}")

all global o: tensor([[[[ 0.7262,  0.0912, -0.3891,  0.5279]],

         [[ 1.0311, -0.7048,  1.0131, -0.3308]],

         [[ 1.0950,  0.3399,  0.7200,  0.4114]],

         [[-0.9727,  0.9585,  1.6192,  1.4506]],

         [[ 0.2695, -0.2104, -0.7328,  0.1043]],

         [[ 0.3488,  0.9676, -0.4657,  1.6048]],

         [[-2.4801, -0.4175, -1.1955,  0.8123]]]], grad_fn=<CopySlices>)

all global lse: tensor([[[0.9545, 0.6099, 0.5643, 0.0594, 0.7099, 0.4250, 0.2709]]],
       grad_fn=<CopySlices>)


In [18]:
# Build cumulative sequence offsets [0, l0, l0+l1, ...] from per-sequence lengths.
a = [1,3,4]
# cumsum promotes int32 input to int64 unless dtype is given, and concat would
# then promote the whole result; pass dtype so the offsets really are int32.
cu_seqlens = torch.concat([
    torch.zeros(1, dtype=torch.int32, device="cpu"),
    torch.tensor(a, dtype=torch.int32, device="cpu").cumsum(dim=0, dtype=torch.int32)
], dim=0)
cu_seqlens

tensor([0, 1, 4, 8])

In [12]:
# Type-promotion check: float32 * bfloat16 gives float32 (see the output below).
a = torch.randn(3, dtype=torch.float32)
b = torch.randn(3, dtype=torch.bfloat16)

(a * b).dtype

torch.float32

In [13]:
# Softmax over a row that is entirely -inf yields NaN (see the output below).
a = torch.tensor([float('-inf'), float('-inf'), float('-inf')])
F.softmax(a, dim=-1)

tensor([nan, nan, nan])

In [14]:
# (-inf) - (-inf) is NaN, and adding a small epsilon does not change that.
a = torch.tensor([float('-inf'), float('-inf'), float('-inf')])
a - a + 1e-10

tensor([nan, nan, nan])

In [15]:
a = torch.tensor([float('-inf'), float('-inf'), float('-inf')])
def safe_subtract(
    a: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """Compute `a - b` element-wise, mapping (-inf) - (-inf) to -inf.

    A plain subtraction yields NaN wherever both operands are -inf. Those
    positions are replaced by -inf one element at a time, so a row that
    contains only some (-inf, -inf) pairs is handled as well. `a` and `b`
    may be broadcast against each other.

    Args:
        a: minuend.
        b: subtrahend.

    Returns:
        Tensor with the broadcast shape of `a` and `b`.
    """
    neg_inf = float('-inf')
    both_neg_inf = (a == neg_inf) & (b == neg_inf)

    sub = a - b
    return torch.where(both_neg_inf, torch.full_like(sub, neg_inf), sub)

safe_subtract(a, a)

tensor([-inf, -inf, -inf])

In [16]:
import torch
import torch.nn.functional as F

def safe_softmax(a, dim=-1):
    """Softmax along `dim` that returns zeros where a slice is entirely -inf.

    A regular softmax produces NaN for such slices; every other slice is left
    exactly as `F.softmax` computes it.
    """
    fully_masked = (a == float('-inf')).all(dim=dim, keepdim=True)
    probs = F.softmax(a, dim=dim)
    # Overwrite the NaN slices with zeros; the mask broadcasts along `dim`.
    return probs.masked_fill(fully_masked, 0.0)

# Example
a = torch.tensor([
    [10, float('-inf'), float('-inf')],
    [11, 31, 25],
    [float('-inf'), float('-inf'), float('-inf')]
])

safe_softmax(a, dim=-1)

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.0561e-09, 9.9753e-01, 2.4726e-03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00]])

In [17]:
def f(q, kv, _v):
    """Return `(q, *chunks)`, where `chunks` are the size-2 pieces of `kv` along dim -2.

    The third argument is accepted but not used.
    """
    chunks = torch.split(kv, split_size_or_sections=2, dim=-2)
    return (q, *chunks)

In [18]:
q = torch.randn(2, 4, 4)
kv = torch.randn(2, 4, 4)

f(q, kv, None)[1].shape

torch.Size([2, 2, 4])

In [19]:
[x+1 for x in [1,2,3]]

[2, 3, 4]

In [20]:
a = torch.tensor([1,2,3,4,5], dtype=torch.float)

In [21]:
sm_a = F.softmax(a, dim=-1)
sm_a

tensor([0.0117, 0.0317, 0.0861, 0.2341, 0.6364])

In [22]:
lse_a = a.exp().sum().log()
lse_a

tensor(5.4519)

In [23]:
lse_a = torch.logsumexp(a, dim=-1)
lse_a

tensor(5.4519)

In [24]:
a1 = a[:3]
a2 = a[3:]
a1, a2

(tensor([1., 2., 3.]), tensor([4., 5.]))

In [25]:
sm_a1 = F.softmax(a1, dim=-1)
sm_a2 = F.softmax(a2, dim=-1)
sm_a1, sm_a2

(tensor([0.0900, 0.2447, 0.6652]), tensor([0.2689, 0.7311]))

In [26]:
lse_a1 = a1.exp().sum().log()
lse_a2 = a2.exp().sum().log()
lse_a1, lse_a2

(tensor(3.4076), tensor(5.3133))

In [27]:
lse_a1 = torch.logsumexp(a1, dim=-1)
lse_a2 = torch.logsumexp(a2, dim=-1)
lse_a1, lse_a2

(tensor(3.4076), tensor(5.3133))

In [28]:
lse_a_ = (lse_a1.exp() + lse_a2.exp()).log()
lse_a_

tensor(5.4519)

In [29]:
# Merge two partial log-sum-exps as max + log(1 + exp(min - max)); the
# exponent is never positive, so exp cannot overflow.
max_lse = torch.max(lse_a1, lse_a2)
min_lse = torch.min(lse_a1, lse_a2)

# log1p keeps precision when exp(min - max) is tiny, where log(1 + x) rounds to 0.
lse_a_ = max_lse + torch.log1p(torch.exp(min_lse - max_lse)) # stable version
lse_a_

tensor(5.4519)

In [30]:
# Rescale each partial softmax by exp(lse_part - lse_total); the outputs match
# the corresponding slices of the full softmax `sm_a` printed earlier.
sm_a1_ = sm_a1 * (lse_a1 - lse_a_).exp()
sm_a2_ = sm_a2 * (lse_a2 - lse_a_).exp()
sm_a1_, sm_a2_

(tensor([0.0117, 0.0317, 0.0861]), tensor([0.2341, 0.6364]))

In [31]:
import torch
a = torch.tensor([
    [True, True, False]
]).bool()

b = torch.zeros((1,3))

In [32]:
b.masked_fill_(a, 1)
b

tensor([[1., 1., 0.]])

In [14]:
import torch

sq, skv = 7, 5
w = 2
causal = True

# Both index ranges end at `maxs`, i.e. the shorter axis is aligned to the
# bottom-right corner of the mask.
maxs = max(sq, skv)
if w is None:
    # no window: let it span everything
    w = maxs

qi = torch.arange(maxs - sq, maxs).unsqueeze(1)   # [sq, 1] row index of each q
kj = torch.arange(maxs - skv, maxs).unsqueeze(0)  # [1, skv] col index of each k

# Half-open range [lb, ub) of visible kj for every qi:
#   causal:     [i - w, i]
#   non-causal: [i - w, i + w]
lb = (qi - w).clamp(min=maxs - skv)
if causal:
    ub = qi + 1
else:
    ub = (qi + w + 1).clamp(max=maxs)

# 0 marks a position that is kept, -inf one that is masked out.
keep = (kj >= lb) & (kj < ub)
attn_mask = torch.zeros((sq, skv), dtype=torch.float).masked_fill(
    ~keep,
    float("-inf")
)

# shape [sq, skv]; unsqueeze twice to broadcast as (1, 1, sq, skv)
attn_mask

tensor([[-inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf],
        [0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [-inf, 0., 0., 0., -inf],
        [-inf, -inf, 0., 0., 0.]])