# Native-Sparse-Attention

author: [dhcode-cpp](https://github.com/dhcode-cpp)

blog: [【手撕NSA】DeepSeek新作-原生稀疏注意力-超长文(附代码)](https://zhuanlan.zhihu.com/p/24841366485)

1. Compress Attention
2. Selection Attention
3. Sliding Window Attenion
4. Gated Aggregation
5. Stride sletection attention


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Config

In [2]:
t = 32 # token ids 
l = 8 # block
d = 8 # sliding stride
block_nums = t // l
dim = 16 # embedding dimension
heads = 4
head_dim = dim//heads
batch_size = 1

In [3]:
X = torch.randn(batch_size, t, dim)

Wq = torch.randn(dim, dim)
Wk = torch.randn(dim, dim)
Wv = torch.randn(dim, dim)

Q = X @ Wq
K = X @ Wk
V = X @ Wv

# skip apply rope

In [4]:
print(Q.shape)
print(K.shape)
print(V.shape)

torch.Size([1, 32, 16])
torch.Size([1, 32, 16])
torch.Size([1, 32, 16])


## Attention with different KV-len

## Token Compression

In [5]:
d = 4
max_idx = round(( t - l ) / d)
print(max_idx)
print(torch.arange(max_idx) * d + 1)
print(torch.arange(max_idx) * d + l)

6
tensor([ 1,  5,  9, 13, 17, 21])
tensor([ 8, 12, 16, 20, 24, 28])


In [6]:
d = l
max_idx = round(( t ) / d)
print(max_idx)
print(torch.arange(max_idx) * d + 1)
print(torch.arange(max_idx) * d + l)

4
tensor([ 1,  9, 17, 25])
tensor([ 8, 16, 24, 32])


In [7]:
W_K_cmp = torch.randn(l, 1)
W_V_cmp = torch.randn(l, 1)
W_pe = torch.randn(l, dim)

In [8]:
K_cmp = []
V_cmp = []
for i in range(max_idx):
    cur_K = K[:, i * d + 0: i * d + l , :] + W_pe.unsqueeze(0)
    cur_V = V[:, i * d + 0: i * d + l , :] + W_pe.unsqueeze(0)
    cur_K = cur_K.transpose(1, 2) @ W_K_cmp 
    cur_V = cur_V.transpose(1, 2) @ W_V_cmp
    K_cmp.append(cur_K)
    V_cmp.append(cur_V)

K_cmp = torch.cat(K_cmp, dim = 2).transpose(1,2)
V_cmp = torch.cat(V_cmp, dim = 2).transpose(1,2)
print(K_cmp.shape)
print(V_cmp.shape)

torch.Size([1, 4, 16])
torch.Size([1, 4, 16])


In [9]:
# multi-head attn
Q_mha = Q.view(1, t, heads, head_dim).transpose(1,2)
K_cmp_mha = K_cmp.view(1, block_nums, heads, head_dim).transpose(1,2)
V_cmp_mha = V_cmp.view(1, block_nums, heads, head_dim).transpose(1,2)
score_cmp = Q_mha @ K_cmp_mha.transpose(2,3) # bs, head, q_len, k_cmp_len
print(score_cmp.shape)

torch.Size([1, 4, 32, 4])


In [10]:
p_cmp = F.softmax(score_cmp, dim = -1) 
o_cmp = p_cmp @ V_cmp_mha
print(o_cmp.shape)

o_cmp = o_cmp.transpose(2, 1).reshape(batch_size, t, dim)
print(o_cmp.shape)

torch.Size([1, 4, 32, 4])
torch.Size([1, 32, 16])


## Token Selection

In [11]:
print(p_cmp.shape)
p_slc = p_cmp.sum(dim = 1)
print(p_slc.shape)

torch.Size([1, 4, 32, 4])
torch.Size([1, 32, 4])


In [12]:
select_top_k = 2
value, idx = torch.topk(p_slc, dim = 2, k = select_top_k)
print(idx[0,0,:])
idx.shape

tensor([1, 0])


torch.Size([1, 32, 2])

In [13]:
idx_slc_start = idx * d
idx_slc_end = idx * d + l
K_slc = torch.randn(batch_size, t, d * select_top_k, dim)
V_slc = torch.randn(batch_size, t, d * select_top_k, dim)
for i in range(batch_size):
    for j in range(t):
        for k in range(select_top_k):
            K_slc[i, j, k * d : k * d + l, :] = K[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
            V_slc[i, j, k * d : k * d + l, :] = V[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
print(K_slc.shape)
print(V_slc.shape)

torch.Size([1, 32, 16, 16])
torch.Size([1, 32, 16, 16])


In [14]:
# shared head KV
# IN GQA Group: [1-head KV & N-head Q] ----repeat kv-head---> [N-head KV & N-head Q]

V_slc_mha = V_slc.view(batch_size, t, select_top_k * d, heads, head_dim).transpose(2,3)
V_slc = V_slc_mha.sum(dim = 2, keepdim = True)
print(V_slc.shape) # bs, seq_len, head, select_seq_len, head_dim

K_slc_mha = K_slc.view(batch_size, t, select_top_k * d, heads, head_dim).transpose(2,3)
K_slc = K_slc_mha.sum(dim = 2, keepdim = True)
print(V_slc.shape) # bs, seq_len, head, select_seq_len, head_dim

torch.Size([1, 32, 1, 16, 4])
torch.Size([1, 32, 1, 16, 4])


In [15]:
# debug Q-1 and KV-16 attention
print(Q_mha.shape) # bs, head, seq, head_dim
print(Q_mha[:, :, 5, :].shape) # t=5
print(K_slc[:, 5, :, :, :].shape) # t=5

print(Q_mha[:, :, 5, :].unsqueeze(dim = 2).repeat(1, 1, select_top_k * d, 1).shape) # t=5
print(K_slc[:, 5, :, :, :].repeat(1, heads, 1, 1).shape) # t=5

Q_slc_j = Q_mha[:, :, 5, :].unsqueeze(dim = 2)
K_slc_j = K_slc[:, 5, :, :, :].repeat(1, heads, 1, 1)

attn_score_j = Q_slc_j @ K_slc_j.transpose(2,3)
print(attn_score_j.shape) # bs, head, seq_q, seq_slc_k

V_slc_j = V_slc[:, 5, :, :, :].repeat(1, heads, 1, 1)
print(V_slc_j.shape)

o_j = (attn_score_j @ V_slc_j).transpose(1,2).view(batch_size, 1, dim)
print(o_j.shape) # bs, j, dim

torch.Size([1, 4, 32, 4])
torch.Size([1, 4, 4])
torch.Size([1, 1, 16, 4])
torch.Size([1, 4, 16, 4])
torch.Size([1, 4, 16, 4])
torch.Size([1, 4, 1, 16])
torch.Size([1, 4, 16, 4])
torch.Size([1, 1, 16])


In [16]:
o_slc = torch.zeros(batch_size, t, dim)
for j in range(t):
    Q_mha[:, :, j, :].unsqueeze(dim = 2)
    K_slc_j = K_slc[:, j, :, :, :].repeat(1, heads, 1, 1)
    V_slc_j = V_slc[:, j, :, :, :].repeat(1, heads, 1, 1)
    
    attn_score_j = Q_slc_j @ K_slc_j.transpose(2,3)
    p_slc_j = F.softmax(attn_score_j, dim = -1) 
    # print(p_slc.shape)

    o_slc_j = p_slc_j @ V_slc_j # bs, seq, dim   
    # print(o_slc_j.shape)

    o_slc_j = o_slc_j.transpose(1,2).view(batch_size, 1, dim)
    o_slc[:, j, :] = o_slc_j
    
print(o_slc.shape)

torch.Size([1, 32, 16])


### Token Selection details

1. NSA using GQA, so we have many group(>1)
2. every group has indipendent KV-Selection
3. In group, caluculative n-heads-Q and 1-heads-KV attention
4. In group 1-heads-kv repeat to n-heads-kv, but in NSA kernel, the 1-heads-kv send to SRAM shared memory. this procedure make less meomery asscess.

## sliding window attention

In [17]:
# built sliding window attention
def get_window_mask(seq_len, window):
    mask = torch.ones(seq_len, seq_len)
    mask = torch.tril(mask)
    win_mask = torch.ones(seq_len - window, seq_len - window)
    win_mask = 1.0 - torch.tril(win_mask)
    mask[window:, :seq_len - window] = win_mask
    return mask
print(get_window_mask(7, 3)) # test
window_mask = get_window_mask(t, 8)

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [0., 1., 1., 1., 0., 0., 0.],
        [0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 1., 1., 1., 0.],
        [0., 0., 0., 0., 1., 1., 1.]])


In [18]:
# simplify multihead attention
S = Q @ K.transpose(1,2) / math.sqrt(dim)
S = F.softmax(S, dim = -1)
S = S * window_mask #
print(S)
o_win = S @ V
print(o_win.shape)

tensor([[[1.0644e-15, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0635e-09, 5.6887e-15, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0207e-12, 5.2067e-10, 3.8002e-10,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.3458e-08,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.9893e-18,
          4.5925e-41, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.0657e-10,
          1.3605e-06, 3.3918e-06]]])
torch.Size([1, 32, 16])


## Gated Aggregation

In [19]:
W_gated = torch.randn(dim, 3) # 3: cmp, slc, win
gate = X @ W_gated
gate = F.sigmoid(gate)
print(gate.shape)

torch.Size([1, 32, 3])


In [20]:
o_list = [o_cmp, o_slc, o_win]
o_star = torch.zeros(batch_size, t, dim)
for i in range(3):
    o_star += gate[:, :, i].unsqueeze(2) * o_list[i]
print(o_star.shape)

torch.Size([1, 32, 16])


## stride sletection attention

$$
\mathbf{p}_t^\text{slc}[j] = \sum_{m=0}^{\frac{l'}{d}-1}\sum_{n=0}^{\frac{l}{d} -1} \mathbf{p}_t^\text{cmp}\left[\frac{l'}{d}j+m +n \right],
$$

In [21]:
d = 8
t = 512 + d
l_cmp = 16
l_slc = 8 # from paper：“l‘ denote the selection block size”, or setting l_slc = {4, 8, 16, 32, ...}
m_max = l_slc // d
n_max = l_cmp // d
print(m_max)
print(n_max)

1
2


In [22]:
# original is t token, compress -> t_cmp token.
t_cmp = (t - d) // l_cmp
print(t_cmp)

32


In [23]:
t_cmp = (t - d) // l_cmp

p_cmp = torch.randn(t_cmp)
p_slc = torch.zeros_like(p_cmp)
j_factor = l_slc // d 

for j in range(t_cmp):
    for m in range(m_max):
        for n in range(n_max):
            idx = j_factor * j + m + n
            if idx >= t_cmp:
                continue
            else:
                p_slc[j] += p_cmp[idx]

print(p_slc)

tensor([ 0.4276,  1.5061,  0.4480,  0.7817, -0.2250,  0.4883, -0.0217, -1.8421,
        -2.7458, -3.0163, -2.0566, -0.7802,  0.8661, -0.6864, -0.7723,  0.7143,
        -1.6664, -0.5904,  3.4035,  3.3638, -0.3060, -0.9972, -0.0107,  0.6096,
        -0.0962, -2.1244, -1.7315,  0.7371,  0.3672, -0.5261,  1.1155,  0.6176])
