In [1]:
import torch
import torch.nn.functional as F

In [2]:
Q = torch.randn(100, 128).cuda().to(torch.bfloat16)
K = torch.randn(100, 128).cuda().to(torch.bfloat16)
V = torch.randn(100, 128).cuda().to(torch.bfloat16)
L = Q.shape[0]
S = K.shape[0]
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()

In [3]:
full_attention = torch.matmul(F.softmax(torch.matmul(Q, K.T) * temp_mask, dim = -1), V)
full_attention.shape

torch.Size([100, 128])

In [4]:
full_attention.argmax(-1)

tensor([ 89,  49,  89,  49,  97,  49,  12,  28,  49,  97,  74,  65,  57,  74,
         65,  74,  63,  74,  22,  74,  74,  70, 101,  63,  74,  74,  28,  30,
         20,  32,  22, 108,  67,  49,  97,  61,  39,  68,  22, 105, 105,  97,
        105,  89,  88,  97,  63, 101,  94, 105,  29,  39,   7,  63, 120, 101,
         97,  65,  29,  70,   8,  74,  19, 111,  20,  60,  33,  52, 101, 113,
         49, 126,  21,  63,  54,  29,  94,  83,  83,  75,  58,  89,  89, 110,
         58,  68,  40,  14,  61,  49,  12,  89,  21,  89, 126, 107,   5,  97,
         60,  93], device='cuda:0')

In [5]:
chunk_size = 5
Q_blocks = torch.chunk(Q, chunk_size)
K_blocks = torch.chunk(K, chunk_size)
V_blocks = torch.chunk(V, chunk_size)

In [6]:
attn_bias_blocks = torch.chunk(temp_mask, chunk_size)
seq_chunk = Q.shape[0] // chunk_size

In [7]:
block_attentions = []
block_maxes = []

Q_block = Q_blocks[0]
attn_bias_block = attn_bias_blocks[0]

for no, (K_block, V_block) in enumerate(zip(K_blocks, V_blocks)):
    # Compute attention scores
    
    attn_bias_b = attn_bias_block[:, no * seq_chunk: (no + 1) * seq_chunk]
    scores = torch.matmul(Q_block, K_block.T) * attn_bias_b

    # Compute block-wise max
    block_max = scores.max(dim=-1, keepdim=True)[0]
    block_maxes.append(block_max)

    # Compute block-wise attention
    block_attention = torch.matmul(F.softmax(scores - block_max, dim=-1), V_block)
    block_attentions.append(block_attention)

# Compute global max
global_max = torch.max(torch.cat(block_maxes, dim=-1), dim=-1, keepdim=True)[0]

# Scale and combine block attentions
scaled_attentions = [
    torch.exp(block_max - global_max) * block_attention
    for block_max, block_attention in zip(block_maxes, block_attentions)
]

output = sum(scaled_attentions)

In [8]:
(torch.sign(full_attention[:output.shape[0]]) == torch.sign(output)).float().mean()

tensor(0.9906, device='cuda:0')

In [9]:
(full_attention[:block_max.shape[0]] - output).mean()

tensor(0.0063, device='cuda:0', dtype=torch.bfloat16)

In [10]:
print(full_attention[:output.shape[0]].argmax(-1), output.argmax(-1))

tensor([89, 49, 89, 49, 97, 49, 12, 28, 49, 97, 74, 65, 57, 74, 65, 74, 63, 74,
        22, 74], device='cuda:0') tensor([34, 49, 89, 49, 97, 49, 12, 28, 49, 97, 74, 65, 57, 74, 65, 74, 63, 74,
        22, 74], device='cuda:0')


In [11]:
output.argmax(-1)

tensor([34, 49, 89, 49, 97, 49, 12, 28, 49, 97, 74, 65, 57, 74, 65, 74, 63, 74,
        22, 74], device='cuda:0')

In [12]:
full_attention[:output.shape[0]].argmax(-1) == output.argmax(-1)

tensor([False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
       device='cuda:0')