## MHA attention 操作

In [2]:
import torch
from torch.nn.functional import scaled_dot_product_attention

# Inputs are stored as (batch_size, seq_len, num_heads, head_dim).
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 8, 64)
value = torch.randn(1, 256, 8, 64)

# scaled_dot_product_attention attends over the second-to-last axis, i.e. it
# expects (batch_size, num_heads, seq_len, head_dim). Passing the tensors as-is
# would silently attend across the 8 heads instead of the 256 positions (the
# output shape is the same either way), so swap the axes going in and out.
output = scaled_dot_product_attention(
    query.transpose(1, 2),
    key.transpose(1, 2),
    value.transpose(1, 2),
).transpose(1, 2)  # back to (batch_size, seq_len, num_heads, head_dim)
print(output.shape) # torch.Size([1, 256, 8, 64])

torch.Size([1, 256, 8, 64])


## 逐步拆解 GQA

In [7]:
import torch

# Layout of every tensor: (batch_size, seq_len, num_heads, head_dim).
# The query carries 8 heads while key/value carry only 2.
query = torch.randn((1, 256, 8, 64))
key = torch.randn((1, 256, 2, 64))
value = torch.randn((1, 256, 2, 64))

# Number of query heads served by each key/value head: 8 // 2 = 4.
num_head_groups = query.size(2) // key.size(2)
print(num_head_groups)

4


In [8]:
from einops import rearrange

# Move the head axis in front of the sequence axis:
# (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim).
# NOTE: this cell reassigns query/key/value in place, so it is not idempotent;
# re-run the previous cell before running this one again.
query = rearrange(query, "b n h d -> b h n d")  # [1, 8, 256, 64]
key = rearrange(key, "b s h d -> b h s d")      # [1, 2, 256, 64]
value = rearrange(value, "b s h d -> b h s d")  # [1, 2, 256, 64]

# Split the 8 query heads into h=2 (one per kv head) times g=num_head_groups=4,
# and move g to the front so the remaining h axis lines up with key/value.
query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)  # [1, 4, 2, 256, 64]
print(query.shape)

torch.Size([1, 4, 2, 256, 64])


这一步比较关键，实际上进行了 2 步操作：
1. 矩阵乘法（query 与 key 的转置相乘，即对 d 维求和）：(1, 4, 2, 256, 64) @ (1, 2, 64, 256) --> (1, 4, 2, 256, 256)
2. 沿着 g 的维度进行求和操作，最终得到 (1, 2, 256, 256)

In [9]:
from einops import einsum
# g indexes the query heads inside one group (size num_head_groups)
# h indexes the key/value heads (NOT the hidden dim)
# d is the per-head hidden dim
# n and s are equal and stand for the query / key sequence length
 
# g and d do not appear in the output pattern, so einsum sums over both:
# the Q.K^T products are accumulated over d and then over the g query heads
# of each group, leaving one score matrix per key/value head.
scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
print(scores.shape) # torch.Size([1, 2, 256, 256])

torch.Size([1, 2, 256, 256])


In [11]:
import torch.nn.functional as F

# Scale the scores by sqrt(head_dim), then normalise over the key axis (last dim).
scale = query.shape[-1] ** 0.5
attention = (scores / scale).softmax(dim=-1)

# Weighted sum of the values: a plain batched matrix multiplication over s.
out = einsum(attention, value, "b h n s, b h s d -> b h n d")

# Restore the (batch_size, seq_len, num_kv_heads, hidden_dim) layout.
out = rearrange(out, "b h n d -> b n h d")
print(out.shape) # torch.Size([1, 256, 2, 64])

torch.Size([1, 256, 2, 64])


接下来手动实现一个 GQA 类

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class GQA(nn.Module):
    """Causal grouped-query self-attention.

    The query is projected to ``num_heads`` heads while key and value are
    projected to only ``num_groups`` heads; each key/value head is shared by
    ``num_heads // num_groups`` consecutive query heads.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_groups: Number of key/value heads. ``None`` (or 0) falls back to
            ``num_heads``, i.e. ordinary multi-head attention.
        dropout: Dropout probability applied to the attention weights while
            the module is in training mode. Defaults to 0.1.
    """

    def __init__(self, dim, num_heads, num_groups=None, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.num_groups = num_groups if num_groups else num_heads
        self.dropout = dropout

        assert self.num_heads % self.num_groups == 0, "num_heads must be divisible by num_groups"
        # Key/value projections are narrower than the query projection:
        # they only produce num_groups heads.
        self.w_query = nn.Linear(dim, dim)
        self.w_key = nn.Linear(dim, self.num_groups * self.head_dim)
        self.w_value = nn.Linear(dim, self.num_groups * self.head_dim)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, query, key, value):
        """Apply causal grouped-query attention.

        Args:
            query, key, value: Tensors of shape (batch_size, seq_len, dim).
                key and value are reshaped with the query's batch_size and
                seq_len, so all three must share them.

        Returns:
            Tensor of shape (batch_size, seq_len, dim).
        """
        batch_size, seq_len, _ = query.shape

        # (batch_size, seq_len, heads, head_dim) -> (batch_size, heads, seq_len, head_dim)
        query = self.w_query(query).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.w_key(key).view(batch_size, seq_len, self.num_groups, self.head_dim).permute(0, 2, 1, 3)
        value = self.w_value(value).view(batch_size, seq_len, self.num_groups, self.head_dim).permute(0, 2, 1, 3)

        # Repeat every key/value head expand_ratio times so that query head i
        # reads key/value head i // expand_ratio:
        # (batch_size, num_groups, seq_len, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
        expand_ratio = self.num_heads // self.num_groups
        key = key.unsqueeze(2).expand(-1, -1, expand_ratio, -1, -1).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
        value = value.unsqueeze(2).expand(-1, -1, expand_ratio, -1, -1).reshape(batch_size, self.num_heads, seq_len, self.head_dim)

        # (batch_size, num_heads, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: True strictly above the diagonal marks future positions.
        # It is created directly on the input's device to avoid a host-to-device
        # copy on every forward pass.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=query.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        attn_out = torch.matmul(attn, value)  # (batch_size, num_heads, seq_len, head_dim)
        # Merge the heads back: (batch_size, seq_len, dim)
        attn_out = attn_out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        out = self.o_proj(attn_out)

        return out


In [11]:
# Example usage: 8 query heads sharing 2 key/value groups.
embed_dim = 512
num_heads = 8
num_groups = 2

# Build the module and switch it to eval mode so attention dropout is disabled.
gqa = GQA(embed_dim, num_heads, num_groups)
gqa.eval()

batch_size = 4
seq_len = 32
# Three independent random inputs, drawn in query -> key -> value order.
query, key, value = (
    torch.randn(batch_size, seq_len, embed_dim) for _ in range(3)
)

output = gqa(query, key, value)
print(output.shape)  # Should be (batch_size, seq_len, embed_dim)


torch.Size([4, 32, 512])


In [12]:
import torch
seq_len = 4
# Ones strictly above the main diagonal mark the "future" positions to hide.
mask = torch.ones(seq_len, seq_len).triu(diagonal=1)
print(mask) # upper triangular mask

cache_len = 2
# A single new query row against (1 + cache_len) keys: shifting the diagonal
# by cache_len leaves nothing masked, so the new token sees every cached key.
mask_w_cache = torch.ones(1, 1 + cache_len).triu(diagonal=1 + cache_len)
print(mask_w_cache) # upper triangular mask with cache length

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])
tensor([[0., 0., 0.]])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class GQA_W_KVcache(nn.Module):
    """Causal grouped-query self-attention with an optional key/value cache.

    With ``use_cache`` set to True, the projected keys and values of every
    call are appended to ``k_cache`` / ``v_cache`` (stored per key/value group,
    before expansion to ``num_heads``), so a later call only needs to be fed
    the new tokens. Call ``reset_cache()`` before starting a new sequence.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_groups: Number of key/value heads. ``None`` (or 0) falls back to
            ``num_heads``.
        dropout: Dropout probability applied to the attention weights while
            the module is in training mode. Defaults to 0.1.
    """

    def __init__(self, dim, num_heads, num_groups=None, dropout=0.1):
        # Zero-argument super(): naming another class here raises TypeError
        # because this class does not inherit from it.
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.num_groups = num_groups if num_groups else num_heads
        self.dropout = dropout

        assert self.num_heads % self.num_groups == 0, "num_heads must be divisible by num_groups"
        self.w_query = nn.Linear(dim, dim)
        self.w_key = nn.Linear(dim, self.num_groups * self.head_dim)
        self.w_value = nn.Linear(dim, self.num_groups * self.head_dim)
        self.o_proj = nn.Linear(dim, dim)

        # Cached projections, each (batch_size, cached_len, num_groups, head_dim).
        self.k_cache = None
        self.v_cache = None
        self.use_cache = False

    def reset_cache(self):
        """Drop all cached keys/values so the next call starts a new sequence."""
        self.k_cache = None
        self.v_cache = None

    def forward(self, query, key, value):
        """Attend the new tokens to the cached prefix plus themselves.

        Args:
            query, key, value: Tensors of shape (batch_size, seq_len, dim)
                holding only the new tokens. key and value are reshaped with
                the query's batch_size and seq_len, so all three must share
                them.

        Returns:
            Tensor of shape (batch_size, seq_len, dim) for the new tokens.
        """
        batch_size, seq_len, _ = query.shape

        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
        query = self.w_query(query).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # (batch_size, seq_len, num_groups, head_dim)
        key = self.w_key(key).view(batch_size, seq_len, self.num_groups, self.head_dim)
        value = self.w_value(value).view(batch_size, seq_len, self.num_groups, self.head_dim)

        if self.use_cache:
            if self.k_cache is not None:
                # Prepend everything seen so far along the sequence axis.
                key = torch.cat([self.k_cache, key], dim=1)
                value = torch.cat([self.v_cache, value], dim=1)
            # Always store the extended tensors; otherwise tokens fed after the
            # first call would be forgotten on the following step.
            # NOTE(review): the cached tensors keep their autograd history when
            # gradients are enabled; run under torch.no_grad() for inference.
            self.k_cache = key
            self.v_cache = value

        total_len = key.shape[1]          # cached prefix + new tokens
        cache_len = total_len - seq_len   # tokens that precede the new ones

        key = key.permute(0, 2, 1, 3)      # (batch_size, num_groups, total_len, head_dim)
        value = value.permute(0, 2, 1, 3)  # (batch_size, num_groups, total_len, head_dim)

        # Repeat every key/value head expand_ratio times:
        # (batch_size, num_groups, total_len, head_dim) -> (batch_size, num_heads, total_len, head_dim)
        expand_ratio = self.num_heads // self.num_groups
        key = key.unsqueeze(2).expand(-1, -1, expand_ratio, -1, -1).reshape(batch_size, self.num_heads, -1, self.head_dim)
        value = value.unsqueeze(2).expand(-1, -1, expand_ratio, -1, -1).reshape(batch_size, self.num_heads, -1, self.head_dim)

        # (batch_size, num_heads, seq_len, total_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if seq_len > 1:
            # New token i sits at absolute position cache_len + i, so it may see
            # key columns 0 .. cache_len + i. Shifting the diagonal by cache_len
            # masks exactly the columns after that. A single new token
            # (seq_len == 1) sees every key and needs no mask.
            mask = torch.triu(
                torch.ones(seq_len, total_len, device=query.device),
                diagonal=1 + cache_len,
            ).bool()
            scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)  # (batch_size, num_heads, seq_len, total_len)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        attn_out = torch.matmul(attn, value)  # (batch_size, num_heads, seq_len, head_dim)
        # Merge the heads back: (batch_size, seq_len, dim)
        attn_out = attn_out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        out = self.o_proj(attn_out)

        return out
