# Native-Sparse-Attention

author: [dhcode-cpp](https://github.com/dhcode-cpp)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Config

In [208]:
t = 32 # token ids 
l = 8 # block
d = 8 # sliding stride
block_nums = t // l
dim = 16 # embeddin dimension
heads = 4
head_dim = dim//heads
batch_size = 1

In [72]:
X = torch.randn(batch_size, t, dim)

Wq = torch.randn(dim, dim)
Wk = torch.randn(dim, dim)
Wv = torch.randn(dim, dim)

Q = X @ Wq
K = X @ Wk
V = X @ Wv

# skip apply rope

In [82]:
print(Q.shape)
print(K.shape)
print(V.shape)

torch.Size([1, 32, 16])
torch.Size([1, 32, 16])
torch.Size([1, 32, 16])


## Attention with different KV-len

## Token Compression

In [234]:
d = 4
max_idx = round(( t - l ) / d)
print(max_idx)
print(torch.arange(max_idx) * d + 1)
print(torch.arange(max_idx) * d + l)

6
tensor([ 1,  5,  9, 13, 17, 21])
tensor([ 8, 12, 16, 20, 24, 28])


In [235]:
d = l
max_idx = round(( t ) / d)
print(max_idx)
print(torch.arange(max_idx) * d + 1)
print(torch.arange(max_idx) * d + l)

4
tensor([ 1,  9, 17, 25])
tensor([ 8, 16, 24, 32])


In [92]:
W_K_cmp = torch.randn(l, 1)
W_V_cmp = torch.randn(l, 1)
W_pe = torch.randn(l, dim)

In [93]:
K[:, i * d + 0: i * d + l  , :].shape

torch.Size([1, 8, 16])

In [94]:
K_cmp = []
V_cmp = []
for i in range(max_idx):
    cur_K = K[:, i * d + 0: i * d + l , :] + W_pe.unsqueeze(0)
    cur_V = V[:, i * d + 0: i * d + l , :] + W_pe.unsqueeze(0)
    cur_K = cur_K.transpose(1, 2) @ W_K_cmp 
    cur_V = cur_V.transpose(1, 2) @ W_V_cmp
    K_cmp.append(cur_K)
    V_cmp.append(cur_V)

K_cmp = torch.cat(K_cmp, dim = 2).transpose(1,2)
V_cmp = torch.cat(V_cmp, dim = 2).transpose(1,2)
print(K_cmp.shape)
print(V_cmp.shape)

torch.Size([1, 4, 16])
torch.Size([1, 4, 16])


In [121]:
# 多头注意力版本
Q_mha = Q.view(1, t, heads, head_dim).transpose(1,2)
K_cmp_mha = K_cmp.view(1, block_nums, heads, head_dim).transpose(1,2)
V_cmp_mha = V_cmp.view(1, block_nums, heads, head_dim).transpose(1,2)
score_cmp = Q_mha @ K_cmp_mha.transpose(2,3) # bs, head, q_len, k_cmp_len
print(score_cmp.shape)

torch.Size([1, 4, 32, 4])


In [224]:
p_cmp = F.softmax(score_cmp, dim = -1) 
o_cmp = p_cmp @ V_cmp_mha
print(o_cmp.shape)

o_cmp = o_cmp.transpose(2, 1).reshape(batch_size, t, dim)
print(o_cmp.shape)

torch.Size([1, 4, 32, 4])
torch.Size([1, 32, 16])


## Token Selection

In [236]:
print(p_cmp.shape)
p_slc = p_cmp.sum(dim = 1)
print(p_slc.shape)

torch.Size([1, 4, 32, 4])
torch.Size([1, 32, 4])


In [128]:
select_top_k = 2
value, idx = torch.topk(p_slc, dim = 2, k = select_top_k)
print(key[0,0,:])
key.shape

tensor([3, 0])


torch.Size([1, 32, 2])

In [192]:
idx_slc_start = idx * d
idx_slc_end = idx * d + l
K_slc = torch.randn(batch_size, d * select_top_k, dim)
V_slc = torch.randn(batch_size, d * select_top_k, dim)
for i in range(batch_size):
    for j in range(t):
        for k in range(select_top_k):
            K_slc[i, k * d : k * d + l, :] = K[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
            V_slc[i, k * d : k * d + l, :] = V[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
print(K_slc.shape)
print(V_slc.shape)

torch.Size([1, 16, 16])
torch.Size([1, 16, 16])


In [193]:
V_slc_mha = V_slc.view(batch_size, select_top_k * d, heads, head_dim).transpose(1,2)
V_slc = V_slc_mha.sum(dim = 1, keepdim = True)
print(V_slc.shape)

K_slc_mha = K_slc.view(1, select_top_k * d, heads, head_dim).transpose(1,2)
K_slc = K_slc_mha.sum(dim = 1, keepdim = True)
print(V_slc.shape)

torch.Size([1, 1, 16, 4])
torch.Size([1, 1, 16, 4])


In [216]:
# 多头注意力版本
score_cls = Q_mha @ K_slc.transpose(2,3).repeat(1, heads, 1, 1) # bs, head, q_len, 8282
print(score_cls.shape)
p_slc = F.softmax(score_cls, dim = -1) 
print(p_cls.shape)

V_slc_onehead = V_slc.repeat(1, heads, 1, 1)
print(V_slc_onehead.shape)

o_slc = p_slc @ V_slc_onehead # bs, seq, dim   
print(o_cls.shape)

torch.Size([1, 4, 32, 16])
torch.Size([1, 4, 32, 16])
torch.Size([1, 4, 16, 4])
torch.Size([1, 32, 16])


In [217]:
o_slc = o_slc.transpose(2, 1).reshape(batch_size, t, dim)
print(o_cls.shape)

torch.Size([1, 32, 16])


## window attention

In [212]:
window = 8
t_idxs = torch.arange(t)
print(t_idxs)

K_win = torch.randn(batch_size, window, dim)
V_win = torch.randn(batch_size, window, dim)

for i in range(batch_size):
    idx_start = i - window
    if idx_start < 0:
        idx_start = 0  
    K_win[i, :, :] = K[i, idx_start : idx_start + window , :]
    V_win[i, :, :] = V[i, idx_start : idx_start + window , :]
print(K_win.shape)
print(V_win.shape)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
torch.Size([1, 8, 16])
torch.Size([1, 8, 16])


In [215]:
# 我们简化注意力计算
S = Q @ K_win.transpose(1,2)
o_win = S @ V_win
print(o_win.shape)

torch.Size([1, 32, 16])


## Gated Aggregation

In [237]:
W_gated = torch.randn(dim, 3) # 3: cmp, slc, win
gate = X @ W_gated
gate = F.sigmoid(gate)
print(gate.shape)

torch.Size([1, 32, 3])


In [231]:
o_list = [o_cmp, o_slc, o_win]
o_star = torch.zeros(batch_size, t, dim)
for i in range(3):
    o_star += gate[:, :, i].unsqueeze(2) * o_list[i]
print(o_star.shape)

torch.Size([1, 32, 16])
