# Grouped Query Attention

Grouped Query Attention (GQA) 是 Transformer 架构中一种注意力机制的变体，介于多头注意力(MHA)和多查询注意力(MQA)之间。在 GQA 中：
- 查询头(Q)的数量保持与标准多头注意力相同
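- 键头(K)和值头(V)的数量减少为 n_kv_heads 个：查询头被分成 n_kv_heads 组，每组 n_heads / n_kv_heads 个查询头共享同一对 K/V 头（n_kv_heads = 1 时即为 MQA，n_kv_heads = n_heads 时即为 MHA）
- 这种设计可以在保持接近 MHA 的模型性能的同时，减小 KV 缓存的大小以及推理时的内存带宽需求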



In [2]:
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class ModelArgs:
    """Hyper-parameters for the toy GQA Transformer used in this notebook.

    With the defaults:
      head_dim = dim // n_heads        = 18 // 6 = 3
      n_rep    = n_heads // n_kv_heads = 6 // 2  = 3   (query heads per KV head)

    Being a dataclass, any field can be overridden at construction time,
    e.g. ``ModelArgs(n_kv_heads=1)`` for MQA-style sharing; attributes can
    also still be reassigned after construction.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads`` or ``n_heads``
            is not divisible by ``n_kv_heads``.
    """

    dim: int = 18                  # embedding dimension
    n_layers: int = 1              # number of Transformer layers
    n_heads: int = 6               # number of query (Q) heads
    n_kv_heads: int = 2            # number of key/value (KV) heads shared by the Q heads
    vocab_size: int = -1
    multiple_of: int = 10
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_batch_size: int = 2        # batch dimension of the KV cache
    max_seq_len: int = 17          # sequence dimension of the KV cache
    model_parallel_size: int = 1   # number of shards (GPUs); 1 means a single GPU

    def __post_init__(self) -> None:
        # Attention derives head_dim and n_rep with integer division; non-divisible
        # settings would yield truncated projections or mismatched head counts.
        if self.dim % self.n_heads != 0:
            raise ValueError(
                f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
            )
        if self.n_kv_heads is not None and (
            self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0
        ):
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a multiple of a positive "
                f"n_kv_heads ({self.n_kv_heads})"
            )

# 每个头的维度：head_dim = dim // n_heads = 18 // 6 = 3

# Q头的总维度：n_heads * head_dim = 6*3 = 18

# KV头的总维度：n_kv_heads * head_dim = 2*3 = 6

# 每个KV头被复制的份数：n_rep = n_heads // n_kv_heads = 6 // 2 = 3


In [3]:
# KV头复制函数
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand grouped key/value heads so that there is one per query head.

    Args:
        x: 4-D tensor indexed as (batch, seq_len, n_kv_heads, head_dim).
        n_rep: number of query heads that share each KV head.

    Returns:
        ``x`` itself when ``n_rep == 1``. Otherwise, a tensor of shape
        (batch, seq_len, n_kv_heads * n_rep, head_dim) in which every KV head
        appears ``n_rep`` times in a row (h0, h0, ..., h1, h1, ...).
    """
    batch, seq_len, kv_heads, dim_per_head = x.shape
    if n_rep == 1:
        return x
    # Add a repeat axis right after the head axis and broadcast it without
    # copying; folding it into the head axis then materialises the copies.
    widened = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, dim_per_head)
    return widened.reshape(batch, seq_len, kv_heads * n_rep, dim_per_head)

# Example: 2 KV heads, each expanded into 3 copies -> 6 heads in total.
k = torch.randn(1, 7, 2, 3)          # (batch=1, seq_len=7, n_kv_heads=2, head_dim=3)
repeat_k = repeat_kv(k, n_rep=3)
print(repeat_k.shape)                 # expected: torch.Size([1, 7, 6, 3])
# First token of the first batch element: rows 0-2 equal KV head 0, rows 3-5 equal KV head 1.
print(repeat_k[0, 0])

torch.Size([1, 7, 6, 3])
tensor([[-0.3614,  1.3549,  0.1143],
        [-0.3614,  1.3549,  0.1143],
        [-0.3614,  1.3549,  0.1143],
        [ 0.9436, -0.4128,  0.4709],
        [ 0.9436, -0.4128,  0.4709],
        [ 0.9436, -0.4128,  0.4709]])


In [4]:
class Attention(nn.Module):
    """Grouped Query Attention (GQA) with a key/value cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each KV head
    is repeated ``n_rep`` times (via ``repeat_kv``) before scaled dot-product
    attention, so that the counts match.

    ``model_parallel_size`` simulates tensor-parallel sharding. This module
    holds only one shard's ``n_heads // model_parallel_size`` query heads and
    ``n_kv_heads // model_parallel_size`` KV heads, and its projections are
    sized for exactly those heads. With ``model_parallel_size == 1`` it is the
    full layer. In a real tensor-parallel setup, the per-shard ``wo`` outputs
    would additionally be summed across shards; that communication is not
    implemented here.

    Raises:
        ValueError: if the head counts cannot be split evenly across shards,
            or the local query heads are not a multiple of the local KV heads.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # A missing n_kv_heads means one KV head per query head (plain MHA).
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.model_parallel_size = args.model_parallel_size
        if args.n_heads % self.model_parallel_size or self.n_kv_heads % self.model_parallel_size:
            raise ValueError(
                f"n_heads ({args.n_heads}) and n_kv_heads ({self.n_kv_heads}) must both be "
                f"divisible by model_parallel_size ({self.model_parallel_size})"
            )
        self.n_local_heads = args.n_heads // self.model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
        if self.n_local_kv_heads == 0 or self.n_local_heads % self.n_local_kv_heads:
            raise ValueError(
                f"local query heads ({self.n_local_heads}) must be a multiple of a "
                f"positive number of local KV heads ({self.n_local_kv_heads})"
            )
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # Linear projections sized for this shard's heads. They are identical to
        # the full layer when model_parallel_size == 1, and they keep forward()'s
        # per-shard reshapes consistent when it is > 1.
        self.wq = nn.Linear(args.dim, self.n_local_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_local_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_local_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_local_heads * self.head_dim, args.dim, bias=False)

        # KV cache, indexed as (batch, position, kv_head, head_dim), holding only
        # the (fewer) KV heads. That is where GQA saves memory.
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len,
                                    self.n_local_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len,
                                    self.n_local_kv_heads, self.head_dim))

    def forward(self, x: torch.Tensor, start_pos: int,
                freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Run one attention step and update the KV cache.

        Args:
            x: input indexed as (batch, seq_len, dim). The batch must not exceed
                ``max_batch_size``, and ``start_pos + seq_len`` must not exceed
                ``max_seq_len``; otherwise the cache write fails.
            start_pos: cache position at which this call's tokens are written.
            freqs_cis: accepted for interface compatibility but NOT used; no
                rotary position embedding is applied here.
            mask: optional additive mask added to the attention scores. It must
                broadcast to (batch, n_local_heads, seq_len, start_pos + seq_len).

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape

        # Project the inputs, then split the feature axis into heads.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Keep the cache on the same device/dtype as the activations
        # (a no-op when they already match).
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Write this step's keys/values, then read back every position so far.
        # NOTE(review): when grad is enabled, the cache is written in place with
        # tensors that carry autograd history. This looks intended for inference
        # (e.g. under torch.no_grad()) rather than repeated training steps.
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, :start_pos + seqlen]
        values = self.cache_v[:bsz, :start_pos + seqlen]

        # Repeat each KV head n_rep times so that there is one per query head.
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # Move the head axis in front of the sequence axis: (batch, heads, seq, head_dim).
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention; softmax is computed in float32 for stability.
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores += mask
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)

        # Merge the heads back into the feature axis and apply the output projection.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

In [5]:
class TransformerBlock(nn.Module):
    """Minimal block: GQA attention wrapped in a residual connection.

    This block contains only the attention sub-layer plus a residual add. There
    is no normalisation and no feed-forward sub-layer. ``layer_id`` is accepted
    but neither stored nor used.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)

    def forward(self, x: torch.Tensor, start_pos: int,
                freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Return ``x + attention(x, ...)``, so the output shape matches ``x``."""
        attn_out = self.attention(x, start_pos, freqs_cis, mask)
        return x + attn_out

In [6]:
# Pretend the attention layer is split across 2 GPUs (tensor parallelism).
config = ModelArgs()
config.model_parallel_size = 2

# Each shard keeps 6 // 2 = 3 query heads and 2 // 2 = 1 KV head.
attn_parallel = Attention(config)
print(f'每个GPU上的Q头数: {attn_parallel.n_local_heads}')
print(f'每个GPU上的KV头数: {attn_parallel.n_local_kv_heads}')

每个GPU上的Q头数: 3
每个GPU上的KV头数: 1


In [None]:
分配策略：
```
GPU0: Q1, Q2, Q3, K1, V1
GPU1: Q4, Q5, Q6, K2, V2
```

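计算过程：
1. 每个GPU获取自己负责的Q头和KV头
2. GPU0将K1,V1各复制为3份，GPU1将K2,V2各复制为3份
3. 各自计算注意力：
   - GPU0: Q1,Q2,Q3 与 K1(复制3份),V1(复制3份) 计算
   - GPU1: Q4,Q5,Q6 与 K2(复制3份),V2(复制3份) 计算
4. 合并结果：最简单的做法是 GPU1 将输出发送到 GPU0，由 GPU0 拼接所有头的输出并通过 Wo 线性层。实际的张量并行实现（如 Llama）通常按行切分 Wo：每个 GPU 只用 Wo 中对应本地头的那部分计算部分结果，再通过 all-reduce 将各 GPU 的部分结果相加。由于 Wo·[h_GPU0; h_GPU1] = Wo_0·h_GPU0 + Wo_1·h_GPU1，两种做法结果相同，但后者无需把所有输出汇总到单个 GPU。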


In [7]:
# Build one block with the default configuration (dim=18).
config = ModelArgs()
llama_block = TransformerBlock(1, config)

# Random input indexed as (batch=1, seq_len=7, dim=config.dim).
x_src = torch.randn(1, 7, config.dim)

# Forward pass writing to cache position 0. There are no rotary frequencies
# and no mask, so every token attends to all 7 tokens (not causal).
y = llama_block(x_src, start_pos=0, freqs_cis=None, mask=None)

# The residual block preserves the input shape: both print torch.Size([1, 7, 18]).
print(f"输入形状: {x_src.shape}")
print(f"输出形状: {y.shape}")

输入形状: torch.Size([1, 7, 18])
输出形状: torch.Size([1, 7, 18])
