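# 3. Multi-Head Attention

In the previous tutorial we studied self-attention. We now turn to **multi-head attention**, a key refinement in the Transformer architecture that lets the model attend to several kinds of information at the same time.

## 3.1 Why Multi-Head Attention?

Single-head self-attention has a limitation: each query position produces one softmax distribution over the keys, so a single head has to blend every kind of relation into one attention pattern. In practice we want the model to attend to several aspects of the input at once:

1. **Syntactic relations**: links between subject, predicate, and object
2. **Semantic relations**: synonyms, antonyms, related concepts
3. **Positional relations**: distance and word order
4. **Other relations**: sentiment, abstract concepts, and so on

**The core idea of multi-head attention**: let the model learn several different attention patterns in parallel, each in its own lower-dimensional subspace, and then fuse the results.

![Multi-head attention diagram](images/multi_head_attention.png)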

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Optional

# Fix the random seeds and set the plotting style
torch.manual_seed(42)
np.random.seed(42)
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
sns.set_style("whitegrid")

print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

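## 3.2 The Math of Multi-Head Attention

Multi-head attention is computed in three steps:

1. **Project into multiple subspaces**:
   $$Q_i = XW_i^Q, \quad K_i = XW_i^K, \quad V_i = XW_i^V$$
   where $i = 1, 2, \ldots, h$ and $h$ is the number of heads (for self-attention, the same input $X$ supplies queries, keys, and values).

2. **Compute attention for each head**:
   $$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i$$

3. **Concatenate and project**:
   $$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where:
- $d_k = d_{model} / h$ is the dimension of each head
- $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}$
- $W^O \in \mathbb{R}^{h d_k \times d_{model}}$, which is $\mathbb{R}^{d_{model} \times d_{model}}$ because $h d_k = d_{model}$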
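The formulas use $h$ separate matrices $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, whereas implementations such as the `MultiHeadAttention` class in this section pack all of them into a single $d_{model} \times d_{model}$ linear layer and then split its output into heads. The small sketch below (arbitrary sizes, bias omitted for clarity) checks that the two views agree:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h = 12, 3
d_k = d_model // h
X = torch.randn(5, d_model)  # 5 tokens

# One packed projection d_model -> h * d_k, then split the last dim into h heads
packed = nn.Linear(d_model, d_model, bias=False)
Q_packed = packed(X).view(5, h, d_k)  # [tokens, heads, d_k]

# The same computation written as h separate matrices W_i^Q of shape [d_model, d_k]
matches = []
for i in range(h):
    # nn.Linear computes X @ weight.T, so head i owns rows i*d_k:(i+1)*d_k of weight
    W_i = packed.weight[i * d_k:(i + 1) * d_k].T
    Q_i = X @ W_i
    matches.append(torch.allclose(Q_i, Q_packed[:, i], atol=1e-6))

print(matches)  # [True, True, True]
```

This is why a single `view(..., num_heads, d_k)` after one linear layer is enough to obtain all $Q_i$ at once.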

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.
    """
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head
        
        # Linear projections; the h per-head matrices are packed into one d_model x d_model layer
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, 
                mask: Optional[torch.Tensor] = None):
        """
        Forward pass.
        
        Args:
            query: [batch_size, q_len, d_model]
            key: [batch_size, kv_len, d_model]
            value: [batch_size, kv_len, d_model]
            mask: [batch_size, q_len, kv_len], [batch_size, 1, q_len, kv_len], or None;
                  positions where mask == 0 are blocked
        
        Returns:
            output: [batch_size, q_len, d_model]
            attention_weights: [batch_size, num_heads, q_len, kv_len]
        """
        batch_size, q_len, _ = query.size()
        kv_len = key.size(1)  # may differ from q_len in cross-attention
        
        # 1. Linear projections to Q, K, V
        Q = self.W_q(query)  # [batch_size, q_len, d_model]
        K = self.W_k(key)    # [batch_size, kv_len, d_model]
        V = self.W_v(value)  # [batch_size, kv_len, d_model]
        
        # 2. Split d_model into num_heads heads of size d_k
        Q = Q.view(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)
        # Shape: [batch_size, num_heads, seq_len, d_k]
        
        # 3. Scaled dot-product attention for all heads at once
        attention_output, attention_weights = self.scaled_dot_product_attention(
            Q, K, V, mask
        )
        
        # 4. Concatenate the heads back into d_model
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, q_len, self.d_model
        )
        
        # 5. Final output projection W^O
        output = self.W_o(attention_output)
        
        return output, attention_weights
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Scaled dot-product attention.
        
        Args:
            Q: [batch_size, num_heads, q_len, d_k]
            K, V: [batch_size, num_heads, kv_len, d_k]
            mask: [batch_size, 1, q_len, kv_len], [batch_size, q_len, kv_len], or None
        
        Returns:
            output: [batch_size, num_heads, q_len, d_k]
            attention_weights: [batch_size, num_heads, q_len, kv_len] (before dropout)
        """
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply the mask (if provided)
        if mask is not None:
            if mask.dim() == 3:  # [batch_size, q_len, kv_len]
                mask = mask.unsqueeze(1)  # [batch_size, 1, q_len, kv_len]
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Attention weights; each row sums to 1
        attention_weights = F.softmax(scores, dim=-1)
        
        # Dropout only affects the weights used to mix the values, so the returned
        # weights remain proper distributions for visualization and analysis
        output = torch.matmul(self.dropout(attention_weights), V)
        
        return output, attention_weights

# Test multi-head attention
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

# Random input
x = torch.randn(batch_size, seq_len, d_model)

# Build the multi-head attention layer
mha = MultiHeadAttention(d_model, num_heads)

# Self-attention: query, key, and value are the same input
output, attention_weights = mha(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"Dimension per head: {mha.d_k}")
print(f"Number of heads: {mha.num_heads}")
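One way to sanity-check a hand-written layer like the one above is to compare it with PyTorch's built-in `nn.MultiheadAttention`. The sketch below re-implements the forward pass as a small function (`manual_mha`, a hypothetical helper) that reuses the built-in layer's packed weights, so both must produce the same output. It assumes PyTorch 1.11 or later for the `average_attn_weights` argument; the sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, num_heads, seq_len, batch_size = 16, 4, 5, 2
d_k = d_model // num_heads

# Reference layer; batch_first=True keeps the [batch, seq, d_model] layout
ref = nn.MultiheadAttention(d_model, num_heads, dropout=0.0, batch_first=True)
ref.eval()

# The built-in layer stores W^Q, W^K, W^V stacked in one (3 * d_model, d_model) matrix
W_q, W_k, W_v = ref.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = ref.in_proj_bias.chunk(3, dim=0)

def manual_mha(x):
    """Hand-written multi-head self-attention that reuses the reference weights."""
    B, S, _ = x.shape
    def split(t):  # [B, S, d_model] -> [B, h, S, d_k]
        return t.view(B, S, num_heads, d_k).transpose(1, 2)
    Q = split(F.linear(x, W_q, b_q))
    K = split(F.linear(x, W_k, b_k))
    V = split(F.linear(x, W_v, b_v))
    weights = F.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    heads = (weights @ V).transpose(1, 2).reshape(B, S, d_model)  # concatenate heads
    return ref.out_proj(heads), weights

x = torch.randn(batch_size, seq_len, d_model)
with torch.no_grad():
    ref_out, ref_w = ref(x, x, x, need_weights=True, average_attn_weights=False)
    my_out, my_w = manual_mha(x)

print(torch.allclose(ref_out, my_out, atol=1e-5))
print(ref_w.shape)  # torch.Size([2, 4, 5, 5])
```

With `average_attn_weights=False` the built-in layer returns per-head weights, in the same `[batch, heads, q_len, kv_len]` layout as the class above.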

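## 3.3 Visualizing Multi-Head Attention

Let's visualize the attention patterns of several heads. Keep in mind that the layer here is untrained and the input is random noise, so any differences between heads come from random initialization rather than learned behavior; applied to a trained model, the same plots can reveal genuine specialization.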

In [None]:
def visualize_multi_head_attention(attention_weights, num_heads_to_show=4, 
                                   tokens=None, title="Multi-head attention"):
    """
    Plot multi-head attention weights as heatmaps.
    
    Args:
        attention_weights: [batch_size, num_heads, seq_len, seq_len]
        num_heads_to_show: number of heads to plot
        tokens: optional list of tokens used as tick labels
        title: figure title
    """
    # Use the first example in the batch
    weights = attention_weights[0].detach().cpu().numpy()  # [num_heads, seq_len, seq_len]
    
    num_heads_to_show = min(num_heads_to_show, weights.shape[0])
    n_cols = (num_heads_to_show + 1) // 2
    
    # Two rows of subplots; squeeze=False always returns a 2-D array of axes
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 8), squeeze=False)
    
    for i in range(num_heads_to_show):
        row, col = divmod(i, n_cols)
        
        # Heatmap of one head
        if tokens is not None:
            sns.heatmap(weights[i], annot=True, fmt='.2f', cmap='Blues',
                        xticklabels=tokens, yticklabels=tokens, ax=axes[row, col])
        else:
            sns.heatmap(weights[i], annot=True, fmt='.2f', cmap='Blues', ax=axes[row, col])
        
        axes[row, col].set_title(f'Head {i+1}')
        axes[row, col].set_xlabel('Key position (attended to)')
        axes[row, col].set_ylabel('Query position')
    
    # Hide unused subplots
    for i in range(num_heads_to_show, axes.size):
        row, col = divmod(i, n_cols)
        axes[row, col].set_visible(False)
    
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Visualize the multi-head attention weights
visualize_multi_head_attention(attention_weights, num_heads_to_show=4,
                               title="First 4 heads of 8-head attention")

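## 3.4 Analyzing the Attention Patterns of Individual Heads

Let's compute a few summary statistics that characterize each head's attention pattern: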

In [None]:
def analyze_attention_heads(attention_weights):
    """
    Compute summary statistics for the attention pattern of each head.
    The type labels use hand-picked thresholds and are only rough heuristics.
    """
    # Use the first example in the batch
    weights = attention_weights[0].detach().cpu().numpy()  # [num_heads, seq_len, seq_len]
    num_heads, seq_len, _ = weights.shape
    
    print("=== Multi-head attention pattern analysis ===")
    print(f"Heads: {num_heads}, sequence length: {seq_len}")
    print()
    
    # Per-head statistics
    head_features = []
    
    for i in range(num_heads):
        head_weight = weights[i]
        
        # How strongly each token attends to itself
        diagonal_mean = np.mean(np.diag(head_weight))
        # Mean weight on the n*(n-1) off-diagonal entries (attention to other tokens)
        off_diagonal_mean = (head_weight.sum() - np.trace(head_weight)) / (seq_len * (seq_len - 1))
        # Mean row entropy in nats; at most ln(seq_len), reached by uniform attention
        entropy = -np.sum(head_weight * np.log(head_weight + 1e-9), axis=-1).mean()
        max_attention = np.max(head_weight)  # largest single attention weight
        
        # Locality: average share of each row's weight within distance 2 of the query
        locality_score = 0
        for row in range(seq_len):
            for col in range(seq_len):
                distance = abs(row - col)
                if distance <= 2:  # neighbors within distance 2
                    locality_score += head_weight[row, col]
        locality_score /= seq_len
        
        head_features.append({
            'head': i + 1,
            'self_attention': diagonal_mean,
            'cross_attention': off_diagonal_mean,
            'entropy': entropy,
            'max_attention': max_attention,
            'locality': locality_score
        })
        
        # Rough head type
        if diagonal_mean > 0.3:
            head_type = "self-attending"
        elif locality_score > 0.5:
            head_type = "local"
        elif entropy > 2.0:
            head_type = "global"
        else:
            head_type = "mixed"
        
        print(f"Head {i+1} ({head_type}):")
        print(f"  Self-attention: {diagonal_mean:.3f}")
        print(f"  Attention to other tokens: {off_diagonal_mean:.3f}")
        print(f"  Attention entropy: {entropy:.3f}")
        print(f"  Locality score: {locality_score:.3f}")
        print(f"  Max attention: {max_attention:.3f}")
        print()
    
    return head_features

# Analyze the attention heads
head_features = analyze_attention_heads(attention_weights)
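The entropy threshold of 2.0 used above is easier to interpret with reference values: a query that attends uniformly over $n$ positions has entropy $\ln n$, and a query that attends to a single position has entropy 0. The sketch below (`row_entropy` is a hypothetical helper repeating the same formula) computes both, and shows that for $n \le 7$ the maximum $\ln n$ is already below 2.0, so the "global" label can only trigger for longer sequences.

```python
import numpy as np

def row_entropy(weights):
    """Mean row entropy in nats, same formula as in analyze_attention_heads."""
    return -np.sum(weights * np.log(weights + 1e-9), axis=-1).mean()

seq_len = 10
uniform = np.full((seq_len, seq_len), 1.0 / seq_len)  # every query spreads its attention evenly
one_hot = np.eye(seq_len)                              # every query attends only to itself

uniform_h = row_entropy(uniform)
one_hot_h = row_entropy(one_hot)
print(f"uniform: {uniform_h:.3f} (ln {seq_len} = {np.log(seq_len):.3f})")
print(f"one-hot: {abs(one_hot_h):.3f}")  # abs() only hides a -0.000 caused by the 1e-9 epsilon

# Largest possible entropy for a few sequence lengths
for n in (6, 7, 8, 10):
    print(n, round(float(np.log(n)), 3))
```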

## 3.5 Single-Head vs. Multi-Head Attention

Let's compare single-head and multi-head attention side by side:

In [None]:
# Single-head attention class (from the previous tutorial)
class SingleHeadAttention(nn.Module):
    def __init__(self, d_model):
        super(SingleHeadAttention, self).__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        output = self.W_o(output)
        
        return output, attention_weights.unsqueeze(1)  # add a head dimension for comparison

# Comparison experiment
d_model = 64  # a smaller dimension keeps the plots readable
seq_len = 8
batch_size = 1

# Test input
test_input = torch.randn(batch_size, seq_len, d_model)

# Single-head attention
single_head = SingleHeadAttention(d_model)
single_output, single_weights = single_head(test_input)

# Multi-head attention
multi_head = MultiHeadAttention(d_model, num_heads=4)
multi_output, multi_weights = multi_head(test_input, test_input, test_input)

print(f"Single-head output shape: {single_output.shape}")
print(f"Multi-head output shape: {multi_output.shape}")
print(f"Single-head attention weights shape: {single_weights.shape}")
print(f"Multi-head attention weights shape: {multi_weights.shape}")

# Side-by-side visualization
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Single-head attention
sns.heatmap(single_weights[0, 0].detach().numpy(), annot=True, fmt='.2f', 
           cmap='Blues', ax=axes[0, 0])
axes[0, 0].set_title('Single-head attention')

# First two heads of multi-head attention
for i in range(2):
    sns.heatmap(multi_weights[0, i].detach().numpy(), annot=True, fmt='.2f',
               cmap='Blues', ax=axes[0, i+1])
    axes[0, i+1].set_title(f'Multi-head attention - head {i+1}')

# Single-head weights vs. multi-head weights averaged over heads
single_avg = single_weights[0, 0].detach().numpy()
multi_avg = multi_weights[0].mean(dim=0).detach().numpy()

sns.heatmap(single_avg, annot=True, fmt='.2f', cmap='Reds', ax=axes[1, 0])
axes[1, 0].set_title('Single-head attention weights')

sns.heatmap(multi_avg, annot=True, fmt='.2f', cmap='Reds', ax=axes[1, 1])
axes[1, 1].set_title('Multi-head attention (mean over heads)')

# Difference
diff = multi_avg - single_avg
sns.heatmap(diff, annot=True, fmt='.2f', cmap='RdBu_r', center=0, ax=axes[1, 2])
axes[1, 2].set_title('Difference (multi-head - single-head)')

plt.tight_layout()
plt.show()
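The comparison above is fair in terms of model size: `SingleHeadAttention` and the 4-head `MultiHeadAttention` have exactly the same number of parameters, because splitting into heads only reshapes the four $d_{model} \times d_{model}$ projections. A quick count for several head counts at $d_{model} = 64$ (`mha_param_count` is a hypothetical helper):

```python
import torch.nn as nn

def mha_param_count(d_model, num_heads):
    """Weights plus biases of W_q, W_k, W_v, W_o for a given head count."""
    assert d_model % num_heads == 0
    d_k = d_model // num_heads
    # Each projection maps d_model -> num_heads * d_k = d_model, whatever num_heads is
    per_projection = d_model * (num_heads * d_k) + num_heads * d_k
    return 4 * per_projection

counts = {h: mha_param_count(64, h) for h in (1, 2, 4, 8)}
print(counts)  # {1: 16640, 2: 16640, 4: 16640, 8: 16640}

# Cross-check against four real nn.Linear(64, 64) layers
torch_count = sum(p.numel() for _ in range(4) for p in nn.Linear(64, 64).parameters())
print(torch_count)  # 16640
```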

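## 3.6 Variants of Multi-Head Attention

Let's implement two common variants, multi-query attention (MQA) and grouped-query attention (GQA), both of which reduce the number of key/value heads: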
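MQA and GQA differ only in how many K/V heads exist and which query head reads which K/V head. In the `GroupedQueryAttention` class, consecutive query heads share a K/V head, produced with `repeat_interleave`. The sketch below shows that index mapping on head indices alone, using the same head counts as this section's test (8 query heads, 2 K/V heads):

```python
import torch

num_heads, num_kv_heads = 8, 2
group_size = num_heads // num_kv_heads  # 4 query heads share each K/V head

# Tag each K/V head with its own index: tensor([0, 1])
kv_head_ids = torch.arange(num_kv_heads)

# repeat_interleave copies each entry group_size times in a row,
# the same call GroupedQueryAttention applies along the head dimension of K and V
expanded = kv_head_ids.repeat_interleave(group_size, dim=0)
print(expanded.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]

# So query head i reads K/V head i // group_size
print(all(expanded[i].item() == i // group_size for i in range(num_heads)))  # True

# Standard MHA (one K/V head per query head) and MQA (a single shared K/V head)
# are the two extreme settings of num_kv_heads
for kv in (num_heads, 1):
    print(kv, torch.arange(kv).repeat_interleave(num_heads // kv).tolist())
```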

In [None]:
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA).
    Every head has its own query projection, but all heads share one key head and one value head.
    """
    def __init__(self, d_model, num_heads):
        super(MultiQueryAttention, self).__init__()
        
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # num_heads query projections, but only one key and one value projection
        self.W_q = nn.Linear(d_model, d_model)   # all query heads
        self.W_k = nn.Linear(d_model, self.d_k)  # single key head
        self.W_v = nn.Linear(d_model, self.d_k)  # single value head
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value):
        batch_size, q_len, _ = query.size()
        
        # Per-head queries: [batch_size, num_heads, q_len, d_k]
        Q = self.W_q(query).view(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # One key/value head, broadcast to all heads without copying:
        # [batch_size, 1, kv_len, d_k] -> [batch_size, num_heads, kv_len, d_k]
        K = self.W_k(key).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        V = self.W_v(value).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        
        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        # Concatenate heads and project
        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        output = self.W_o(output)
        
        return output, attention_weights

class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA).
    Query heads are split into num_kv_heads groups; the heads in a group share one key/value head.
    """
    def __init__(self, d_model, num_heads, num_kv_heads):
        super(GroupedQueryAttention, self).__init__()
        
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # K/V heads must have the same size d_k as the query heads, otherwise Q_i K_i^T is undefined
        self.d_k = d_model // num_heads
        self.group_size = num_heads // num_kv_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value):
        batch_size, q_len, _ = query.size()
        kv_len = key.size(1)
        
        # Query: [batch_size, num_heads, q_len, d_k]
        Q = self.W_q(query).view(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Key and Value: [batch_size, num_kv_heads, kv_len, d_k]
        K = self.W_k(key).view(batch_size, kv_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, kv_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        
        # Repeat each K/V head for the group_size query heads that share it
        K = K.repeat_interleave(self.group_size, dim=1)
        V = V.repeat_interleave(self.group_size, dim=1)
        
        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        # Concatenate heads and project
        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        output = self.W_o(output)
        
        return output, attention_weights

# Test the attention variants
d_model = 64
num_heads = 8
seq_len = 6
batch_size = 1

test_input = torch.randn(batch_size, seq_len, d_model)

# Standard multi-head attention
standard_mha = MultiHeadAttention(d_model, num_heads)
standard_out, standard_weights = standard_mha(test_input, test_input, test_input)

# Multi-query attention
mqa = MultiQueryAttention(d_model, num_heads)
mqa_out, mqa_weights = mqa(test_input, test_input, test_input)

# Grouped-query attention (8 query heads sharing 2 key/value heads)
gqa = GroupedQueryAttention(d_model, num_heads, num_kv_heads=2)
gqa_out, gqa_weights = gqa(test_input, test_input, test_input)

print("Attention variant comparison:")
print(f"Standard multi-head attention parameters: {sum(p.numel() for p in standard_mha.parameters()):,}")
print(f"Multi-query attention parameters: {sum(p.numel() for p in mqa.parameters()):,}")
print(f"Grouped-query attention parameters: {sum(p.numel() for p in gqa.parameters()):,}")

print(f"\nOutput shapes:")
print(f"Standard multi-head: {standard_out.shape}")
print(f"Multi-query: {mqa_out.shape}")
print(f"Grouped-query: {gqa_out.shape}")
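The parameter counts printed above differ only in `W_k` and `W_v`. In practice, a larger payoff of MQA and GQA shows up in autoregressive decoding, where the keys and values of all previous tokens are cached for every layer. The sketch below estimates that cache; the helper name `kv_cache_bytes`, the 2048-token context, and 2-byte (fp16/bf16) storage are illustrative assumptions, not values from this tutorial.

```python
def kv_cache_bytes(num_kv_heads, d_k, seq_len, num_layers=1, batch_size=1, bytes_per_value=2):
    """Size of the cached K and V tensors.

    2 (K and V) x batch x layers x seq_len x num_kv_heads x d_k values;
    bytes_per_value=2 assumes fp16/bf16 storage.
    """
    return 2 * batch_size * num_layers * seq_len * num_kv_heads * d_k * bytes_per_value

d_model, num_heads = 512, 8
d_k = d_model // num_heads  # 64
sizes = {}
for name, kv_heads in [("MHA", 8), ("GQA (2 KV heads)", 2), ("MQA", 1)]:
    sizes[name] = kv_cache_bytes(kv_heads, d_k, seq_len=2048)
    print(f"{name}: {sizes[name] / 1024:.0f} KiB per layer")
# MHA: 4096 KiB, GQA: 1024 KiB, MQA: 512 KiB
```

The cache shrinks in proportion to `num_kv_heads`, which is why GQA is often described as a middle ground: it keeps several K/V heads for quality while cutting memory traffic during decoding.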

## 3.7 Diversity Across Attention Heads

Let's measure how different the heads' attention patterns are from one another:

In [None]:
def analyze_head_diversity(attention_weights):
    """
    Measure how different the attention heads are from one another.
    """
    # Use the first example in the batch
    weights = attention_weights[0].detach().cpu().numpy()  # [num_heads, seq_len, seq_len]
    num_heads, seq_len, _ = weights.shape
    
    print("=== Attention head diversity ===")
    
    # Pairwise similarity matrix between heads
    similarity_matrix = np.zeros((num_heads, num_heads))
    
    for i in range(num_heads):
        for j in range(num_heads):
            # Cosine similarity of the flattened attention matrices
            flat_i = weights[i].flatten()
            flat_j = weights[j].flatten()
            
            dot_product = np.dot(flat_i, flat_j)
            norm_i = np.linalg.norm(flat_i)
            norm_j = np.linalg.norm(flat_j)
            
            similarity_matrix[i, j] = dot_product / (norm_i * norm_j)
    
    # Plot the similarity matrix
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Similarity heatmap
    sns.heatmap(similarity_matrix, annot=True, fmt='.3f', cmap='RdYlBu_r',
                xticklabels=[f'Head {i+1}' for i in range(num_heads)],
                yticklabels=[f'Head {i+1}' for i in range(num_heads)], ax=ax1)
    ax1.set_title('Head similarity matrix')
    
    # Diversity metrics, computed from the off-diagonal entries only
    off_diagonal = similarity_matrix[~np.eye(num_heads, dtype=bool)]
    avg_similarity = np.mean(off_diagonal)
    diversity_score = 1 - avg_similarity  # higher means more diverse heads
    
    # Distribution of pairwise similarities
    ax2.hist(off_diagonal, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax2.axvline(avg_similarity, color='red', linestyle='--', 
                label=f'Mean similarity: {avg_similarity:.3f}')
    ax2.set_xlabel('Pairwise head similarity')
    ax2.set_ylabel('Count')
    ax2.set_title('Distribution of pairwise head similarity')
    ax2.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Most similar pair of distinct heads: exclude the diagonal before taking the argmax
    masked = similarity_matrix.copy()
    np.fill_diagonal(masked, -np.inf)
    i_max, j_max = np.unravel_index(np.argmax(masked), masked.shape)
    
    print(f"Mean pairwise similarity: {avg_similarity:.3f}")
    print(f"Diversity score: {diversity_score:.3f}")
    print(f"Most similar pair: head {i_max + 1} and head {j_max + 1}")
    
    # Uniqueness of each head: 1 - mean similarity to the other heads
    uniqueness_scores = []
    for i in range(num_heads):
        others_similarity = [similarity_matrix[i, j] for j in range(num_heads) if i != j]
        uniqueness = 1 - np.mean(others_similarity)
        uniqueness_scores.append(uniqueness)
    
    print("\nUniqueness score per head:")
    for i, score in enumerate(uniqueness_scores):
        print(f"Head {i+1}: {score:.3f}")
    
    return similarity_matrix, diversity_score

# Analyze the multi-head attention weights computed earlier
similarity_matrix, diversity_score = analyze_head_diversity(attention_weights)
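The double loop in `analyze_head_diversity` can be replaced by a single matrix product: after L2-normalizing each flattened head, the Gram matrix of the normalized rows is exactly the cosine-similarity matrix. The sketch below (hypothetical helper names, synthetic data in which heads 2 and 4 are made identical on purpose) also reports the most similar pair of distinct heads with 1-based numbering.

```python
import numpy as np

def head_similarity(weights):
    """Cosine similarity between heads; weights: [num_heads, seq_len, seq_len]."""
    flat = weights.reshape(weights.shape[0], -1)               # one row per head
    unit = flat / np.linalg.norm(flat, axis=1, keepdims=True)  # L2-normalize each head
    return unit @ unit.T                                       # [num_heads, num_heads]

def most_similar_pair(sim):
    """1-based indices of the most similar pair of distinct heads."""
    masked = sim.copy()
    np.fill_diagonal(masked, -np.inf)                          # ignore self-similarity
    i, j = np.unravel_index(np.argmax(masked), masked.shape)
    return tuple(sorted((int(i) + 1, int(j) + 1)))

rng = np.random.default_rng(0)
w = rng.random((4, 6, 6))
w /= w.sum(axis=-1, keepdims=True)  # rows sum to 1, like softmax output
w[3] = w[1]                         # heads 2 and 4 identical on purpose
sim = head_similarity(w)
print(np.round(sim, 3))
print(most_similar_pair(sim))  # (2, 4)
```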

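## 3.8 Computational Efficiency

Let's benchmark the forward-pass time and parameter count of each variant. Note that in these simple implementations MQA and GQA expand K and V back to all heads before the matrix multiplications, so they will not necessarily run faster here; in practice their main advantage is a smaller key/value cache during autoregressive decoding.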
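A rough FLOP count explains why splitting into heads costs about the same as one full-width head: each head works in $d_k = d_{model}/h$ dimensions, so the $h$ score computations together cost as much as one $d_{model}$-wide computation. The helper below is a hypothetical estimate that counts a multiply-add as 2 FLOPs and ignores softmax, scaling, and biases.

```python
def attention_flops(seq_len, d_model, num_heads):
    """Approximate FLOPs of one multi-head self-attention forward pass (batch size 1)."""
    d_k = d_model // num_heads
    projections = 4 * 2 * seq_len * d_model * d_model       # W_q, W_k, W_v, W_o
    scores = num_heads * 2 * seq_len * seq_len * d_k        # Q_i K_i^T for every head
    weighted_sum = num_heads * 2 * seq_len * seq_len * d_k  # softmax(...) V_i for every head
    return projections + scores + weighted_sum

for h in (1, 8):
    print(h, attention_flops(seq_len=128, d_model=512, num_heads=h))
# both lines print 301989888
```

Measured differences between single-head and multi-head layers therefore come mostly from overheads such as reshapes and many smaller matrix multiplications, not from extra arithmetic.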

In [None]:
import time

def benchmark_attention_variants():
    """
    Benchmark forward-pass time and parameter count of the attention variants.
    """
    d_model = 512
    seq_len = 128
    batch_size = 16
    num_heads = 8
    
    # Test data
    test_data = torch.randn(batch_size, seq_len, d_model)
    
    # Models to compare
    models = {
        'Single-head': SingleHeadAttention(d_model),
        'Multi-head': MultiHeadAttention(d_model, num_heads),
        'Multi-query': MultiQueryAttention(d_model, num_heads),
        'Grouped-query': GroupedQueryAttention(d_model, num_heads, num_kv_heads=2)
    }
    
    results = {}
    
    print("=== Attention variant benchmark ===")
    print(f"Config: batch_size={batch_size}, seq_len={seq_len}, d_model={d_model}")
    print()
    
    def run(name, model):
        # SingleHeadAttention takes one input; the other variants take (query, key, value)
        if name == 'Single-head':
            return model(test_data)
        return model(test_data, test_data, test_data)
    
    for name, model in models.items():
        model.eval()
        
        with torch.no_grad():  # inference only, no autograd bookkeeping
            # Warm-up
            for _ in range(5):
                run(name, model)
            
            # Timed runs
            num_runs = 50
            start_time = time.perf_counter()
            for _ in range(num_runs):
                output, weights = run(name, model)
            end_time = time.perf_counter()
        
        avg_time = (end_time - start_time) / num_runs * 1000  # milliseconds
        
        # Parameter count
        num_params = sum(p.numel() for p in model.parameters())
        
        results[name] = {
            'time': avg_time,
            'params': num_params,
            'output_shape': output.shape
        }
        
        print(f"{name}:")
        print(f"  Mean forward time: {avg_time:.2f} ms")
        print(f"  Parameters: {num_params:,}")
        print(f"  Output shape: {output.shape}")
        print()
    
    # Plot the results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    names = list(results.keys())
    times = [results[name]['time'] for name in names]
    params = [results[name]['params'] for name in names]
    
    # Forward time
    bars1 = ax1.bar(names, times, color=['red', 'blue', 'green', 'orange'])
    ax1.set_ylabel('Forward time (ms)')
    ax1.set_title('Forward time by attention variant')
    ax1.tick_params(axis='x', rotation=45)
    
    # Value labels on top of the bars
    for bar, time_val in zip(bars1, times):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                f'{time_val:.1f}', ha='center', va='bottom')
    
    # Parameter count
    bars2 = ax2.bar(names, [p/1000 for p in params], color=['red', 'blue', 'green', 'orange'])
    ax2.set_ylabel('Parameters (K)')
    ax2.set_title('Parameter count by attention variant')
    ax2.tick_params(axis='x', rotation=45)
    
    # Value labels on top of the bars
    for bar, param_val in zip(bars2, params):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                f'{param_val/1000:.1f}K', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    return results

# Run the benchmark
benchmark_results = benchmark_attention_variants()

## 3.9 Exercises and Experiments

### Exercise 1: Position-aware multi-head attention

In [None]:
class PositionalMultiHeadAttention(nn.Module):
    """
    Position-aware multi-head attention.
    Adds a learned bias for every (head, query position, key position) triple to the
    content-based attention scores. The bias table has num_heads * max_seq_len**2
    parameters, e.g. 4,000,000 for 4 heads and max_seq_len=1000.
    """
    def __init__(self, d_model, num_heads, max_seq_len=1000):
        super(PositionalMultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Standard Q, K, V, and output projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # Position bias: [num_heads, max_seq_len, max_seq_len]
        self.position_bias = nn.Parameter(torch.randn(num_heads, max_seq_len, max_seq_len))
        
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        
        # Q, K, V: [batch_size, num_heads, seq_len, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Content-based attention scores
        content_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Position bias for the current length, broadcast over the batch
        position_scores = self.position_bias[:, :seq_len, :seq_len].unsqueeze(0)
        
        # Total attention scores
        total_scores = content_scores + position_scores
        
        # Attention weights
        attention_weights = F.softmax(total_scores, dim=-1)
        
        # Weighted sum of the values
        output = torch.matmul(attention_weights, V)
        
        # Concatenate heads and project
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(output)
        
        return output, attention_weights

# Test position-aware attention
pos_mha = PositionalMultiHeadAttention(d_model=64, num_heads=4)
test_input = torch.randn(1, 8, 64)

pos_output, pos_weights = pos_mha(test_input)
print(f"Position-aware multi-head attention output shape: {pos_output.shape}")
print(f"Position-aware attention weights shape: {pos_weights.shape}")

# Visualize the first 8x8 block of each head's position bias
# (untrained, so these are just the random initial values)
position_bias = pos_mha.position_bias[:4, :8, :8].detach().numpy()

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes = axes.flatten()

for i in range(4):
    sns.heatmap(position_bias[i], annot=True, fmt='.2f', cmap='RdBu_r',
                center=0, ax=axes[i])
    axes[i].set_title(f'Position bias of head {i+1}')

plt.tight_layout()
plt.show()
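The absolute pairwise bias above needs `num_heads * max_seq_len**2` parameters and learns an unrelated value for every position pair. A common lighter alternative is to learn one bias per head for each relative distance $j - i$, in the spirit of the bucketed relative position bias used in T5. The class below is a hypothetical, simplified sketch without bucketing; it stores `num_heads * (2 * max_seq_len - 1)` values and returns a `[num_heads, seq_len, seq_len]` bias that could be added to `content_scores` in place of `position_scores`.

```python
import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    """Hypothetical per-head bias indexed by relative distance j - i (no bucketing)."""

    def __init__(self, num_heads, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len
        # One learnable value per head and per distance in [-(L-1), L-1]
        self.bias = nn.Parameter(torch.zeros(num_heads, 2 * max_seq_len - 1))

    def forward(self, seq_len):
        assert seq_len <= self.max_seq_len
        pos = torch.arange(seq_len)
        rel = pos[None, :] - pos[:, None]  # entry (i, j) = j - i
        idx = rel + self.max_seq_len - 1   # shift into [0, 2 * max_seq_len - 2]
        return self.bias[:, idx]           # [num_heads, seq_len, seq_len]

rpb = RelativePositionBias(num_heads=4, max_seq_len=1000)
with torch.no_grad():
    rpb.bias.normal_()  # random values, only so the check below is not trivially all-zero
b = rpb(8)
print(b.shape)                                   # torch.Size([4, 8, 8])
print(sum(p.numel() for p in rpb.parameters()))  # 7996
# Same relative distance gives the same bias: each diagonal of the matrix is constant
print(torch.equal(b[:, 0, 1], b[:, 3, 4]))       # True
```

Compared with 4,000,000 parameters for the absolute table, this uses 7,996, and the same bias automatically applies to any sequence length up to `max_seq_len`.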

### Exercise 2: Functional analysis of attention heads

In [None]:
def functional_analysis_of_heads(attention_weights, tokens=None):
    """
    Analyze the functional specialization of each attention head.
    The type labels use hand-picked thresholds and are only rough heuristics.
    """
    weights = attention_weights[0].detach().cpu().numpy()  # [num_heads, seq_len, seq_len]
    num_heads, seq_len, _ = weights.shape
    
    print("=== Functional analysis of attention heads ===")
    
    for head_idx in range(num_heads):
        head_weight = weights[head_idx]
        
        # Attention to the token itself
        diagonal_strength = np.mean(np.diag(head_weight))
        
        # Attention to immediately adjacent tokens
        adjacent_strength = 0
        for i in range(seq_len - 1):
            adjacent_strength += head_weight[i, i+1] + head_weight[i+1, i]
        adjacent_strength /= (2 * (seq_len - 1))
        
        # Mean weight on all other tokens (off-diagonal entries)
        global_strength = (head_weight.sum() - np.trace(head_weight)) / (seq_len * (seq_len - 1))
        
        # Direction: attention to later positions (forward) vs. earlier positions (backward)
        forward_strength = 0
        backward_strength = 0
        
        for i in range(seq_len):
            for j in range(seq_len):
                if i < j:  # query i attends to a later key j
                    forward_strength += head_weight[i, j]
                elif i > j:  # query i attends to an earlier key j
                    backward_strength += head_weight[i, j]
        
        forward_strength /= (seq_len * (seq_len - 1) / 2)
        backward_strength /= (seq_len * (seq_len - 1) / 2)
        
        # Rough functional type
        if diagonal_strength > 0.4:
            head_type = "self-attending"
        elif adjacent_strength > 0.3:
            head_type = "adjacent"
        elif forward_strength > backward_strength * 1.5:
            head_type = "forward-looking"
        elif backward_strength > forward_strength * 1.5:
            head_type = "backward-looking"
        else:
            head_type = "global"
        
        print(f"\nHead {head_idx + 1} - {head_type}:")
        print(f"  Self-attention strength: {diagonal_strength:.3f}")
        print(f"  Adjacent-token strength: {adjacent_strength:.3f}")
        print(f"  Global strength: {global_strength:.3f}")
        print(f"  Forward strength: {forward_strength:.3f}")
        print(f"  Backward strength: {backward_strength:.3f}")
    
    # Plot each head's pattern
    n_cols = (num_heads + 1) // 2
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 8), squeeze=False)
    
    for i in range(num_heads):
        row, col = divmod(i, n_cols)
        
        # Attention pattern
        axes[row, col].imshow(weights[i], cmap='Blues')
        axes[row, col].set_title(f'Head {i+1}')
        
        # Overlay the main diagonal (red) and the two adjacent diagonals (green)
        axes[row, col].plot([0, seq_len-1], [0, seq_len-1], 'r--', alpha=0.5, linewidth=1)
        if seq_len > 1:
            axes[row, col].plot([0, seq_len-2], [1, seq_len-1], 'g--', alpha=0.5, linewidth=1)
            axes[row, col].plot([1, seq_len-1], [0, seq_len-2], 'g--', alpha=0.5, linewidth=1)
    
    # Hide unused subplots
    for i in range(num_heads, axes.size):
        row, col = divmod(i, n_cols)
        axes[row, col].set_visible(False)
    
    plt.tight_layout()
    plt.show()

# Analyze the attention heads computed earlier
functional_analysis_of_heads(attention_weights)
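Because the heads analyzed above are untrained, it is hard to tell whether the metrics measure what they claim. One option is to feed hand-built patterns whose behavior is known in advance. The sketch below (hypothetical helper `direction_scores`, using the same forward/backward definitions as `functional_analysis_of_heads`) checks a pure self-attention pattern, a "previous token" pattern, and a uniform causal pattern.

```python
import numpy as np

def direction_scores(w):
    """Mean weight above the diagonal (query attends to later keys) and below it."""
    n = w.shape[0]
    pairs = n * (n - 1) / 2
    forward = np.triu(w, k=1).sum() / pairs
    backward = np.tril(w, k=-1).sum() / pairs
    return forward, backward

n = 6
identity = np.eye(n)                 # pure self-attention
previous = np.eye(n, k=-1)           # token i attends to token i - 1
previous[0, 0] = 1.0                 # the first token has no predecessor, so it attends to itself
causal = np.tril(np.ones((n, n)))
causal /= causal.sum(axis=-1, keepdims=True)  # uniform over past and current tokens

for name, w in [("identity", identity), ("previous", previous), ("causal", causal)]:
    fwd, bwd = direction_scores(w)
    print(f"{name:8s} diag={np.diag(w).mean():.3f} forward={fwd:.3f} backward={bwd:.3f}")
```

The uniform causal pattern has a mean diagonal weight of about 0.408, so with the thresholds in `functional_analysis_of_heads` it would be labeled self-attending even though most of its weight goes to earlier tokens, a reminder that these labels are heuristics.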

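## Summary

In this tutorial we took a detailed look at multi-head attention:

### Key concepts:
1. **Multiple heads**: learn several attention patterns in parallel
2. **Subspace projection**: project the input into different representation subspaces of size $d_k = d_{model}/h$
3. **Fusing heads**: concatenate the outputs of all heads and apply the output projection $W^O$
4. **Head diversity**: different heads can capture different kinds of relations

### Formula:
$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$
$$\text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

### Main advantages:
- **Expressiveness**: can capture several kinds of dependencies at once
- **Parallelism**: all heads are computed together in batched matrix multiplications, at roughly the same cost as a single full-width head
- **Specialization**: different heads can focus on different linguistic phenomena
- **Robustness**: multiple heads provide redundant and complementary information

### Variants and optimizations:
- **Multi-Query Attention (MQA)**: all query heads share one key/value head, reducing K/V parameters and the size of the KV cache
- **Grouped-Query Attention (GQA)**: groups of query heads share key/value heads, trading off between MHA and MQA
- **Position-aware attention**: models position explicitly through a learned bias on the attention scores

### Practical insights:
- Analyses of trained models suggest that different heads often take on different linguistic functions
- Some heads track syntactic relations, others semantic ones
- Head diversity is a useful diagnostic, although it is not by itself a measure of model quality

Multi-head attention is a key source of the Transformer's expressive power. In the next tutorial we will study **Positional Encoding**, which supplies the token-order information that attention alone lacks: without it, self-attention treats its input as an unordered set.

### Next:
- [04-positional-encoding.ipynb](04-positional-encoding.ipynb) - Positional encoding in depth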