In [1]:
import torch
import torch.nn.functional as F

In [2]:
Q = torch.randn(100, 128).cuda().to(torch.bfloat16)
K = torch.randn(100, 128).cuda().to(torch.bfloat16)
V = torch.randn(100, 128).cuda().to(torch.bfloat16)

In [3]:
full_attention = torch.matmul(F.softmax(torch.matmul(Q, K.T), dim = -1), V)
full_attention.shape

torch.Size([100, 128])

In [4]:
chunk_size = 2
Q_blocks = torch.chunk(Q, chunk_size)
K_blocks = torch.chunk(K, chunk_size)
V_blocks = torch.chunk(V, chunk_size)

In [5]:
Q_block = Q_blocks[0]

In [6]:
block_attentions = []
block_maxes = []

for K_block, V_block in zip(K_blocks, V_blocks):
    # Compute attention scores
    scores = torch.matmul(Q_block, K_block.T)

    # Compute block-wise max
    block_max = scores.max(dim=-1, keepdim=True)[0]
    block_maxes.append(block_max)

    # Compute block-wise attention
    block_attention = torch.matmul(F.softmax(scores - block_max, dim=-1), V_block)
    block_attentions.append(block_attention)

# Compute global max
global_max = torch.max(torch.cat(block_maxes, dim=-1), dim=-1, keepdim=True)[0]

# Scale and combine block attentions
scaled_attentions = [
    torch.exp(block_max - global_max) * block_attention
    for block_max, block_attention in zip(block_maxes, block_attentions)
]

output = sum(scaled_attentions)

In [13]:
block_max.shape

torch.Size([50, 1])

In [7]:
output

tensor([[ 0.9531,  0.6719, -1.1797,  ...,  0.8594, -0.8008,  0.1621],
        [ 1.3672, -1.2109, -0.4023,  ...,  2.2188, -1.2812, -1.1328],
        [ 0.8438,  0.2695,  0.0776,  ...,  0.2441,  0.5938,  0.2148],
        ...,
        [-1.8750,  0.6211, -1.1719,  ..., -1.1641, -1.5000,  0.1094],
        [ 0.5312,  0.5664,  2.1250,  ..., -0.6875,  0.5312, -1.2578],
        [ 0.5469, -1.2422,  2.0469,  ..., -0.8984,  0.5156, -0.8438]],
       device='cuda:0', dtype=torch.bfloat16)

In [8]:
full_attention[:50]

tensor([[ 0.9414,  0.6641, -1.1719,  ...,  0.8477, -0.7930,  0.1602],
        [ 1.3672, -1.2109, -0.4023,  ...,  2.2188, -1.2812, -1.1328],
        [ 0.8438,  0.2695,  0.0781,  ...,  0.2441,  0.5938,  0.2148],
        ...,
        [-1.1641,  0.3848, -0.7266,  ..., -0.7227, -0.9297,  0.0659],
        [ 0.5312,  0.5664,  2.1250,  ..., -0.6875,  0.5312, -1.2578],
        [ 0.4629, -1.0547,  1.7344,  ..., -0.7617,  0.4375, -0.7227]],
       device='cuda:0', dtype=torch.bfloat16)

In [14]:
(torch.sign(full_attention[:output.shape[0]]) == torch.sign(output)).float().mean()

tensor(0.9958, device='cuda:0')

In [10]:
(full_attention[:50] - output).mean()

tensor(0.0015, device='cuda:0', dtype=torch.bfloat16)

In [17]:
print(full_attention[:output.shape[0]].argmax(-1), output.argmax(-1))

tensor([122,  84,  27,  20,  98,  60,  36,  65,  39,  48,  31,  91,  48,  69,
         80,  98,  59, 121,   0,  24,  42,  67,  76,  58,  36,  34,  79,   1,
         57,  99,   9,  47,  77, 110,   9,   9, 119,   9,  34,  27,   6,  37,
        104, 121, 103, 123,   0,  56,  67, 104], device='cuda:0') tensor([122,  84,  27,  20,  98,  60,  36,  65,  39,  48,  31,  91,  48,  69,
         80,  98,  59, 121,   0,  24,  42,  39,  76,  58,  36,  34,  79,   1,
         57,  40,   9,  47,  77, 110,   9,   9, 119,   9,  34,  27,   6,  37,
        104, 121, 103, 123,   0,  56,  67, 104], device='cuda:0')


In [16]:
output.argmax(-1)

tensor([122,  84,  27,  20,  98,  60,  36,  65,  39,  48,  31,  91,  48,  69,
         80,  98,  59, 121,   0,  24,  42,  39,  76,  58,  36,  34,  79,   1,
         57,  40,   9,  47,  77, 110,   9,   9, 119,   9,  34,  27,   6,  37,
        104, 121, 103, 123,   0,  56,  67, 104], device='cuda:0')