# GQA (Grouped Query Attention) Forward Shape Trace

In [None]:
import torch
from torch import nn
import math

import sys
sys.path.insert(0, '/root/vermind')

from vermind_models import VerMindConfig, RMSNorm, precompute_freqs_cis, apply_rotary_pos_emb, repeat_kv

In [None]:
class GQAAttention(nn.Module):
    """Grouped Query Attention (GQA) layer that prints a shape trace.

    The query projection produces ``num_attention_heads`` heads, while the
    key and value projections produce only ``num_key_value_heads`` heads.
    Each key/value head is then repeated ``n_rep`` times (via ``repeat_kv``)
    so that it lines up with the query heads. For example, with 8 query
    heads and 2 key/value heads, every key/value head is shared by
    8 / 2 = 4 query heads.

    Every stage of ``forward`` prints the tensor shapes it produces, so the
    class is meant for inspection rather than for training.
    """

    def __init__(self, args: VerMindConfig):
        """Build the projections and print the head configuration.

        Args:
            args: Configuration object; this constructor reads its
                ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads`` and ``dropout`` attributes.

        Raises:
            ValueError: If ``num_attention_heads`` is not a positive
                multiple of ``num_key_value_heads``.
        """
        super().__init__()
        # Fail early with a clear message: with floor division below, an
        # incompatible head count would only surface later as a shape
        # mismatch inside the attention matmul.
        if (args.num_key_value_heads <= 0
                or args.num_attention_heads % args.num_key_value_heads != 0):
            raise ValueError(
                "num_attention_heads must be a positive multiple of "
                f"num_key_value_heads, got {args.num_attention_heads} and "
                f"{args.num_key_value_heads}"
            )

        self.num_attention_heads = args.num_attention_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = args.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        print(f"=== GQA Attention Configuration ===")
        print(f"  num_attention_heads (Q): {self.num_attention_heads}")
        print(f"  num_key_value_heads (K/V): {self.num_key_value_heads}")
        print(f"  n_rep (repetition): {self.n_rep}")
        print(f"  head_dim: {self.head_dim}")
        print(f"  hidden_size: {args.hidden_size}")
        print()

        # K and V project to fewer features than Q because they have fewer heads.
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        # Stored from the config; no dropout is applied anywhere in this class.
        self.dropout = args.dropout

    def forward(self,
                x: torch.Tensor,
                position_embeddings: tuple = None,
                attention_mask: torch.Tensor = None):
        """Run causal grouped-query attention and print each stage's shape.

        Args:
            x: Input of shape ``[bsz, seq_len, hidden_size]``.
            position_embeddings: Optional ``(cos, sin)`` pair handed to
                ``apply_rotary_pos_emb``; rotary embedding is skipped when
                this is ``None``.
            attention_mask: Accepted for interface compatibility but not
                applied; only the causal mask built below is used.

        Returns:
            The output of ``o_proj``, shape ``[bsz, seq_len, hidden_size]``.
        """
        bsz, seq_len, hidden_size = x.shape
        print(f"[Input] x shape: {x.shape}  (batch_size={bsz}, seq_len={seq_len}, hidden_size={hidden_size})")

        # ========== 1. Linear Projections ==========
        xq = self.q_proj(x)
        xk = self.k_proj(x)
        xv = self.v_proj(x)

        print(f"\n[1] Linear Projections:")
        print(f"    xq (Q after q_proj): {xq.shape}  → [bsz, seq_len, num_heads * head_dim]")
        print(f"    xk (K after k_proj): {xk.shape}  → [bsz, seq_len, num_kv_heads * head_dim]")
        print(f"    xv (V after v_proj): {xv.shape}  → [bsz, seq_len, num_kv_heads * head_dim]")

        # ========== 2. Reshape to (batch, seq, n_heads, head_dim) ==========
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        print(f"\n[2] Reshape to heads:")
        print(f"    xq: {xq.shape}  → [bsz, seq_len, n_local_heads, head_dim]")
        print(f"    xk: {xk.shape}  → [bsz, seq_len, n_local_kv_heads, head_dim]")
        print(f"    xv: {xv.shape}  → [bsz, seq_len, n_local_kv_heads, head_dim]")

        # ========== 3. RoPE (Rotary Position Embedding) ==========
        if position_embeddings is not None:
            cos, sin = position_embeddings
            xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
            print(f"\n[3] RoPE applied:")
            print(f"    xq: {xq.shape}  (unchanged)")
            print(f"    xk: {xk.shape}  (unchanged)")

        # ========== 4. Transpose for attention computation ==========
        # K/V are expanded to the query head count first, then the head axis
        # is moved in front of the sequence axis.
        xq = xq.transpose(1, 2)  # [bsz, n_heads, seq_len, head_dim]
        xk = repeat_kv(xk, self.n_rep).transpose(1, 2)
        xv = repeat_kv(xv, self.n_rep).transpose(1, 2)

        print(f"\n[4] Transpose & repeat_kv:")
        print(f"    xq: {xq.shape}  → [bsz, n_local_heads, seq_len, head_dim]")
        print(f"    xk (after repeat_kv): {xk.shape}  → [bsz, n_local_heads, seq_len, head_dim]")
        print(f"    xv (after repeat_kv): {xv.shape}  → [bsz, n_local_heads, seq_len, head_dim]")
        print(f"    Note: KV heads repeated {self.n_rep} times ({self.n_local_kv_heads} → {self.n_local_heads})")

        # ========== 5. Scaled Dot-Product Attention ==========
        # scores = Q @ K^T / sqrt(d)
        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.head_dim)

        print(f"\n[5] Attention Scores:")
        print(f"    scores (Q @ K^T): {scores.shape}  → [bsz, n_heads, seq_len, seq_len]")

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere. It is
        # created in the dtype of `scores` so that adding it does not promote
        # half-precision scores to float32, which would make the matmul with
        # `xv` below fail on mismatched dtypes.
        causal_mask = torch.triu(
            torch.full(
                (seq_len, seq_len),
                float("-inf"),
                dtype=scores.dtype,
                device=scores.device,
            ),
            diagonal=1,
        )
        scores = scores + causal_mask

        # Normalize over the key axis.
        scores = torch.nn.functional.softmax(scores, dim=-1)

        # Output = scores @ V
        output = torch.matmul(scores, xv)

        print(f"    output (scores @ V): {output.shape}  → [bsz, n_heads, seq_len, head_dim]")

        # ========== 6. Reshape and Output Projection ==========
        # Move heads back next to head_dim and merge them into one feature axis.
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)

        print(f"\n[6] Reshape output:")
        print(f"    output (transpose + reshape): {output.shape}  → [bsz, seq_len, n_heads * head_dim]")

        output = self.o_proj(output)

        print(f"\n[7] Output Projection:")
        print(f"    output (after o_proj): {output.shape}  → [bsz, seq_len, hidden_size]")
        print(f"\n{'='*50}")

        return output

In [None]:
# ========== Demo ==========
# Dimensions of the random input used for the trace.
batch_size, seq_len, hidden_size = 2, 8, 512

# 8 query heads share 2 key/value heads.
config = VerMindConfig(
    hidden_size=hidden_size,
    num_attention_heads=8,
    num_key_value_heads=2,
    num_hidden_layers=1,
    vocab_size=6400,
    dropout=0.0,
    flash_attn=False,
)

# Constructing the layer prints its configuration summary.
attention = GQAAttention(config)

# The input is sampled after the layer is built, so the order in which the
# random number generator is consumed stays the same.
x = torch.randn(batch_size, seq_len, hidden_size)

# Rotary position tables, sized for exactly this sequence length.
max_seq_len = seq_len
head_dim = config.hidden_size // config.num_attention_heads
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, max_seq_len)

print("Running GQA forward pass...\n")
output = attention(x, position_embeddings=(freqs_cos, freqs_sin))

## GQA 关键特性总结

| 步骤 | 操作 | Shape 变化 |
|------|------|-----------|
| 1 | Input x | `[bsz, seq_len, hidden_size]` |
| 2 | Q/K/V Linear | Q: `[bsz, seq_len, n_heads*head_dim]`，K/V: `[bsz, seq_len, n_kv_heads*head_dim]` |
| 3 | Reshape to heads | Q: `[bsz, seq_len, n_heads, head_dim]`，K/V: `[bsz, seq_len, n_kv_heads, head_dim]` |
| 4 | RoPE (可选) | Shape 不变 |
| 5 | transpose + repeat_kv | `[bsz, n_heads, seq_len, head_dim]` |
| 6 | Attention scores | `[bsz, n_heads, seq_len, seq_len]` |
| 7 | Output | `[bsz, seq_len, hidden_size]` |

### 核心优势
- **减少 KV cache**: 只需缓存 2 个 KV heads 而不是 8 个，KV cache 显存占用降为原来的 1/4
- **保持质量**: 通过 repeat_kv 复用，每个 KV head 由 4 个 Q heads 共享，而 Q heads 的数量保持不变