In [1]:
import torch
import triton
import triton.language as tl
import math
from copy import deepcopy
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
# from nsa_attn import NsaAttention
from flash_attn import flash_attn_func as fa2
from flash_attn_interface import flash_attn_func as fa3
from exp_family import compute_p, compute_select_p, _attention, fused_p, _cattention, select_for_bwd, select_for_fwd_bwd
from select_attn import select_attn
import random
from fla.ops.nsa import parallel_nsa
from pku_nsa import parallel_nsa_topk

# Seed every RNG this notebook draws from so the random inputs below (and any
# later use of the `random` module) are identical on each Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

n = 1024 * 4
kernel_size = 32
stride = 16
select_size = 64
top_n = 16
b, qh, kh, d, vd = 1, 64, 4, 128, 128
sm_scale = d ** -0.5
device = 'cuda'
dtype = torch.bfloat16
# Number of length-`kernel_size` windows taken every `stride` positions along `n`.
num_blocks = (n - kernel_size) // stride + 1
# q uses `qh` heads; k / v / ck / cv use `kh` heads.
# q, k: (b, n, heads, d)   v: (b, n, kh, vd)
q = torch.randn(b, n, qh, d, device=device, dtype=dtype)
k = torch.randn(b, n, kh, d, device=device, dtype=dtype)
v = torch.randn(b, n, kh, vd, device=device, dtype=dtype)
# ck / cv have one entry per window instead of one per position: (b, num_blocks, kh, d|vd).
ck = torch.randn(b, num_blocks, kh, d, device=device, dtype=dtype)
cv = torch.randn(b, num_blocks, kh, vd, device=device, dtype=dtype)
# Synthetic fp32 values in [4, 5) with shape (b, qh, n), passed as `lse` to compute_p / fused_p.
lse = torch.rand(b, qh, n, device=device, dtype=torch.float32) + 4

  from .autonotebook import tqdm as notebook_tqdm


# cmp_attn

In [9]:
# Compare the two _cattention variants (they differ only in the trailing
# argument, 1 vs 2): first their outputs, then their backward latency.
q.requires_grad_(True)
ck.requires_grad_(True)
cv.requires_grad_(True)
# Independent leaf copies so each variant owns its own autograd graph.
q1 = deepcopy(q)
ck1 = deepcopy(ck)
cv1 = deepcopy(cv)
y1, lse1 = _cattention.apply(q, ck, cv, kernel_size, stride, sm_scale, 1)
# Fix: build y2 from the copies. It used to be built from q / ck / cv, while the
# second benchmark cleared the grads of q1 / ck1 / cv1. Those tensors were not in
# y2's graph, so gradients kept accumulating on q / ck / cv during the timed runs
# and the two timings were not measured under the same conditions.
y2, lse2 = _cattention.apply(q1, ck1, cv1, kernel_size, stride, sm_scale, 2)
dy = torch.rand_like(y1)
print(torch.allclose(y1, y2), torch.allclose(lse1, lse2))

# grad_to_none resets .grad on the listed tensors before every timed call, so
# each backward writes fresh gradients instead of accumulating into old ones.
print(triton.testing.do_bench(lambda: y1.backward(dy, retain_graph=True), grad_to_none=[q, ck, cv]))
print(triton.testing.do_bench(lambda: y2.backward(dy, retain_graph=True), grad_to_none=[q1, ck1, cv1]))

True True
3.978924512863159
3.32159686088562


In [4]:
# Forward latency of both _cattention variants (trailing argument 1, then 2),
# followed by flash-attn 2 on the same q / ck / cv for reference.
for cattn_variant in (1, 2):
    cattn_fwd_time = triton.testing.do_bench(
        lambda: _cattention.apply(q, ck, cv, kernel_size, stride, sm_scale, cattn_variant)
    )
    print(cattn_fwd_time)
fa2_fwd_time = triton.testing.do_bench(lambda: fa2(q, ck, cv, causal=False))
print(fa2_fwd_time)

0.3595395088195801
0.3616417944431305
0.428196519613266


# compute_attn_p

In [3]:
# Latency of compute_p for methods 3 and 4.
# Methods 1 and 2 are switched off in this run; add them to the tuple to time them again.
for p_method in (3, 4):
    print(triton.testing.do_bench(lambda: compute_p(q, ck, lse, kernel_size, stride, method=p_method)))

11.327719688415527
10.787739753723145


# compute_select_indices

In [4]:
# Output of compute_p (method 4); the recorded shape for this config is [1, 4, 255, 4096].
prob = compute_p(q, ck, lse, kernel_size, stride, method=4)
# Same values with the last two axes swapped, stored contiguously.
prob2 = torch.transpose(prob, -1, -2).contiguous()
prob.shape

torch.Size([1, 4, 255, 4096])

In [3]:
# Latency of each compute_select_p variant.
# `prob2` (last two axes swapped) feeds the default method and method 4;
# `prob` feeds methods 2 and 3.
# top_n is read from the config cell rather than repeated as a literal 16, so
# these timings stay in step with the select_attn section when it is changed.
print(triton.testing.do_bench(lambda: compute_select_p(prob2, kernel_size, stride, select_size)))
print(triton.testing.do_bench(lambda: compute_select_p(prob, kernel_size, stride, select_size, method=2)))
print(triton.testing.do_bench(lambda: compute_select_p(prob, kernel_size, stride, select_size, top_n=top_n, method=3)))
print(triton.testing.do_bench(lambda: compute_select_p(prob2, kernel_size, stride, select_size, top_n=top_n, method=4)))

5.338936805725098
2.9004673957824707
2.299271583557129
0.619713544845581


In [5]:
# Latency of select_for_fwd_bwd on the transposed probabilities.
# top_n is read from the config cell rather than repeated as a literal 16.
print(triton.testing.do_bench(lambda: select_for_fwd_bwd(prob2, kernel_size, stride, select_size, top_n=top_n)))

0.02709120884537697


In [None]:
# select_for_fwd_bwd returns backward indices and counts directly;
# select_for_bwd rebuilds them from the forward indices. The two must agree.
# top_n is read from the config cell rather than repeated as a literal 16.
p, fwd_ind, bwd_ind, count = select_for_fwd_bwd(prob2, kernel_size, stride, select_size, top_n=top_n)
bwd_ind2, count2 = select_for_bwd(fwd_ind.to(torch.int64))
# count2 carries one extra trailing column, dropped before comparing.
print(torch.allclose(count, count2[..., :-1].to(torch.int32)))

# Spot-check 100 sampled (batch, head, row) entries. A dedicated, seeded
# generator makes the sampled positions the same on every run.
spot_rng = random.Random(0)
for _trial in range(100):
    b_idx = spot_rng.randint(0, b - 1)
    h_idx = spot_rng.randint(0, kh - 1)
    row_idx = spot_rng.randint(0, n // select_size - 1)
    # Only the first `cnt` slots of the row are compared.
    cnt = int(count[b_idx, h_idx, row_idx])
    # Sort both sides so the comparison does not depend on the order of the entries.
    val1 = bwd_ind[b_idx, h_idx, row_idx, :cnt].sort(-1).values
    val2 = bwd_ind2[b_idx, h_idx, row_idx, :cnt].sort(-1).values
    # Indices are integers, so require exact equality rather than closeness.
    assert torch.equal(val1.int(), val2.int()), (
        f"backward indices differ at b={b_idx}, h={h_idx}, row={row_idx}"
    )

True


In [None]:
# Latency of fused_p, method 2.
# NOTE(review): top_n=3 here differs from the notebook-level top_n (16) -- confirm this is intentional.
fused_p_method2_time = triton.testing.do_bench(
    lambda: fused_p(q, ck, lse, kernel_size, stride, select_size, top_n=3, sm_scale=sm_scale, method=2)
)
print(fused_p_method2_time)

In [None]:
# Latency of fused_p, method 4.
# NOTE(review): top_n=3 here differs from the notebook-level top_n (16) -- confirm this is intentional.
fused_p_method4_time = triton.testing.do_bench(
    lambda: fused_p(q, ck, lse, kernel_size, stride, select_size, top_n=3, sm_scale=sm_scale, method=4)
)
print(fused_p_method4_time)

In [6]:
# Run compute_select_p with method 3 (on `prob`) and method 4 (on `prob2`, the
# layout with the last two axes swapped) and check the returned p tensors agree.
# top_n is read from the config cell rather than repeated as a literal 16.
p1, indices1 = compute_select_p(prob, kernel_size, stride, select_size, top_n=top_n, method=3)
p2, indices2 = compute_select_p(prob2, kernel_size, stride, select_size, top_n=top_n, method=4, return_p=True)
# Latency of building the backward indices from the method-4 selection.
print(triton.testing.do_bench(lambda: select_for_bwd(indices2)))
# p1 is transposed back so both tensors are compared in the same layout.
torch.allclose(p1.transpose(-1, -2), p2)

9.823797225952148


True

In [7]:
# Fraction of entries where indices2 and fwd_ind hold the same value
# (1.0 means every entry matches).
matching_entries = torch.eq(indices2, fwd_ind).sum()
matching_entries / indices2.numel()

tensor(1., device='cuda:0')

In [24]:
# Compare `count` with count2 at index 0 of its first axis, last column dropped.
# The int32 cast presumably matches count's dtype -- confirm.
count2_trimmed = count2[0, :, :-1].to(torch.int32)
torch.allclose(count, count2_trimmed)

True

# select_attn

In [2]:
# Keep both layouts under the names the rest of the notebook uses: `prob` is
# the raw compute_p output and `prob2` has the last two axes swapped. This cell
# used to rebind `prob` to the swapped layout, which silently broke any earlier
# cell (e.g. compute_select_p with method 2 / 3) re-run after it.
prob = compute_p(q, ck, lse, kernel_size, stride, method=4)
prob2 = prob.transpose(-1, -2).contiguous()
# The first return value is not needed here; it gets a named placeholder so
# IPython's special `_` (last output) is left alone.
_p_unused, indices = compute_select_p(prob2, kernel_size, stride, select_size, top_n=top_n, method=4, return_p=False)
# Recorded shape of `indices` is [1, 4, 4096, 16]; `indices2` swaps axes 1 and 2.
# NOTE: this rebinds `indices2`, which earlier cells use for the un-transposed
# method-4 indices.
indices2 = indices.transpose(1, 2).contiguous()
indices.shape


torch.Size([1, 4, 4096, 16])

In [3]:
# Forward pass through both selection-attention implementations.
# parallel_nsa takes `indices2` (axes 1 and 2 swapped); _attention takes `indices`.
y1 = parallel_nsa(
    q, k, v,
    indices2,
    select_size,
)
y2 = _attention.apply(
    q, k, v,
    select_size,
    indices,
    sm_scale,
    1,
)
# y3 = _attention.apply(q, k, v, select_size, indices, sm_scale, 2)

Triton autotuning for function parallel_nsa_fwd_kernel finished after 0.88s; best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;
Triton autotuning for function _fwd_kernel1 finished after 2.78s; best config selected: num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;


In [3]:
# Forward latency of the parallel_nsa kernel.
parallel_nsa_fwd_time = triton.testing.do_bench(
    lambda: parallel_nsa(q, k, v, indices2, select_size)
)
print(parallel_nsa_fwd_time)

Triton autotuning for function parallel_nsa_fwd_kernel finished after 0.91s; best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;
0.8710374236106873


In [6]:
# Forward latency of _attention with trailing argument 1.
attention_fwd_time = triton.testing.do_bench(
    lambda: _attention.apply(q, k, v, select_size, indices, sm_scale, 1)
)
print(attention_fwd_time)
# Disabled comparisons (other _attention variants and a flash-attn 2 run over
# the first select_size * top_n keys):
# print(triton.testing.do_bench(lambda: _attention.apply(q, k, v, select_size, indices, sm_scale, 2)))
# print(triton.testing.do_bench(lambda: _attention.apply(q, k, v, select_size, indices, sm_scale, 3)))
# print(triton.testing.do_bench(lambda: fa2(q, k[:, :select_size*top_n], v[:, :select_size*top_n])))

1.8875794410705566


In [5]:
# Enable gradients on the inputs, then take independent copies so each
# _attention variant (trailing argument 1 vs 3) gets its own autograd graph.
for leaf in (q, k, v):
    leaf.requires_grad_(True)
q1, k1, v1 = (deepcopy(leaf) for leaf in (q, k, v))
y1 = _attention.apply(q, k, v, select_size, indices, sm_scale, 1)
y2 = _attention.apply(q1, k1, v1, select_size, indices, sm_scale, 3)
# Upstream gradient for the backward benchmarks.
dy = torch.randn_like(y1)

In [6]:
# Backward latency through y1; .grad of q / k / v is reset before every timed call.
y1_bwd_time = triton.testing.do_bench(
    lambda: y1.backward(dy, retain_graph=True),
    grad_to_none=[q, k, v],
)
print(y1_bwd_time)

16.976537704467773


In [33]:
# Repeat of the y1 backward timing; the y2 timing below is disabled.
y1_bwd_rerun_time = triton.testing.do_bench(
    lambda: y1.backward(dy, retain_graph=True),
    grad_to_none=[q, k, v],
)
print(y1_bwd_rerun_time)
# print(triton.testing.do_bench(lambda: y2.backward(dy, retain_graph=True), grad_to_none=[q1, k1, v1]))

18.112581253051758


In [18]:
# Show y1 at index 0 of the first axis and index 0 of the third axis,
# for eyeballing against the matching y2 slice.
y1[0][:, 0]

tensor([[-1.0469e+00,  2.6250e+00, -7.0801e-02,  ..., -4.1211e-01,
         -1.7812e+00,  1.8457e-01],
        [-1.0234e+00,  2.1094e+00,  8.7891e-02,  ...,  5.8594e-01,
         -1.2812e+00,  1.0234e+00],
        [-4.7852e-01, -6.1719e-01, -8.5547e-01,  ...,  8.4375e-01,
         -3.7695e-01,  1.0391e+00],
        ...,
        [-4.9316e-02,  6.3477e-02,  3.4668e-02,  ..., -3.4668e-02,
         -1.4282e-02, -2.5146e-02],
        [ 8.8501e-03,  7.6599e-03, -5.8838e-02,  ..., -2.3071e-02,
          1.0452e-03, -1.3855e-02],
        [ 1.3855e-02,  2.4536e-02, -1.1230e-01,  ...,  2.9907e-03,
         -4.4678e-02,  4.2236e-02]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)

In [19]:
# Show y2 at index 0 of the first axis and index 0 of the third axis,
# for eyeballing against the matching y1 slice.
y2[0][:, 0]

tensor([[-1.0469e+00,  2.6250e+00, -7.0801e-02,  ..., -4.1211e-01,
         -1.7812e+00,  1.8457e-01],
        [-1.0234e+00,  2.1094e+00,  8.7891e-02,  ...,  5.8594e-01,
         -1.2812e+00,  1.0234e+00],
        [-4.7852e-01, -6.1719e-01, -8.5547e-01,  ...,  8.4375e-01,
         -3.7695e-01,  1.0391e+00],
        ...,
        [-4.9316e-02,  6.3477e-02,  3.4668e-02,  ..., -3.4668e-02,
         -1.4282e-02, -2.5146e-02],
        [ 8.8501e-03,  7.6599e-03, -5.8838e-02,  ..., -2.3071e-02,
          1.0452e-03, -1.3855e-02],
        [ 1.3855e-02,  2.4536e-02, -1.1230e-01,  ...,  2.9907e-03,
         -4.4678e-02,  4.2236e-02]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)