## MHA attention 操作

In [2]:
import torch
from torch.nn.functional import scaled_dot_product_attention

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 8, 64)
value = torch.randn(1, 256, 8, 64)

# scaled_dot_product_attention expects (..., seq_len, head_dim): the sequence axis must be
# second-to-last. Passing (b, n, h, d) directly would attend across the 8 heads instead of
# the 256 positions, and the output shape would still look correct. So move heads in front
# of the sequence axis, run attention, then restore (batch_size, seq_len, num_heads, head_dim).
output = scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
).transpose(1, 2)
print(output.shape) # torch.Size([1, 256, 8, 64])

torch.Size([1, 256, 8, 64])


## 逐步拆解 GQA

In [7]:
import torch

# Toy GQA inputs: 8 query heads share 2 key/value heads.
# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)

# GQA needs the query-head count to be a multiple of the KV-head count; otherwise the
# floor division below would silently truncate the group size.
assert query.shape[2] % key.shape[2] == 0, "num query heads must be divisible by num kv heads"
num_head_groups = query.shape[2] // key.shape[2]
print(num_head_groups) # each group is of size 4 since there are 2 kv_heads

4


In [8]:
from einops import rearrange

# Put the head axis ahead of the sequence axis so each head's (seq, head_dim) matrix is trailing.
query = rearrange(query, "b n h d -> b h n d")  # [1, 8, 256, 64]
key, value = (rearrange(t, "b s h d -> b h s d") for t in (key, value))  # each [1, 2, 256, 64]

# Split the 8 query heads into (kv head h, group slot g) and move g to the front.
# With "(h g)", query head index = h * num_head_groups + g, so heads 0-3 pair with kv head 0
# and heads 4-7 with kv head 1.
# NOTE: this cell rebinds query/key/value, so re-running it without re-running the previous
# cell fails.
query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)  # [1, 4, 2, 256, 64]
print(query.shape)

torch.Size([1, 4, 2, 256, 64])


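这一步比较关键。GQA 中同一组的 g 个 query 头共享同一个 KV 头，因此计算注意力分数时：
1. 矩阵乘法 (1, 4, 2, 256, 64) @ (1, 2, 256, 64) --> (1, 4, 2, 256, 256)。key 没有 g 维，会沿 g 维广播，使同组的 4 个 query 头都与同一个 key 头相乘；
2. 结果必须保留 g 维，即每个 query 头各自得到一张 (256, 256) 的分数图。注意不能沿 g 维求和得到 (1, 2, 256, 256)：求和会把同组 4 个 query 头的分数混成一张，最终只剩 2 个输出头，已不再是 GQA。正确的输出应保留全部 8 个 query 头。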

In [9]:
from einops import einsum
# g stands for the number of groups
# h stands for the hidden dim
# n and s are equal and stands for sequence length
 
scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
print(scores.shape) # torch.Size([1, 2, 256, 256])

torch.Size([1, 2, 256, 256])


In [11]:
import torch.nn.functional as F

scale = query.size(-1) ** 0.5
attention = F.softmax(scores / scale, dim=-1)

# here we do just a standard matrix multiplication
out = einsum(attention, value, "b h n s, b h s d -> b h n d")

# finally, just reshape back to the (batch_size, seq_len, num_kv_heads, hidden_dim)
out = rearrange(out, "b h n d -> b n h d")
print(out.shape) # torch.Size([1, 256, 2, 64])

torch.Size([1, 256, 2, 64])
