In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiQueryAttention(nn.Module):
    """Multi-query attention: independent per-head queries, one shared key/value head.

    Each of the ``heads`` query heads has its own ``head_dim -> head_dim``
    projection, applied to that head's slice of the query input. A single key
    projection and a single value projection (each ``embed_size -> head_dim``)
    are shared by every query head. Sharing one K/V head across all query
    heads is what distinguishes multi-query from multi-head attention.

    Args:
        embed_size: Model dimension of the inputs and of the output.
        heads: Number of query heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        if self.head_dim * heads != embed_size:
            raise ValueError("Embedding size must be divisible by heads.")

        # One independent query projection per head. Each acts only on its
        # head's head_dim-wide slice of the query input.
        self.query_layers = nn.ModuleList(
            [nn.Linear(self.head_dim, self.head_dim, bias=False) for _ in range(heads)]
        )
        # A single key head and a single value head, shared by all query heads.
        self.key = nn.Linear(embed_size, self.head_dim, bias=False)
        self.value = nn.Linear(embed_size, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, value, key, query, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            value: Tensor of shape (N, value_len, embed_size).
            key: Tensor of shape (N, key_len, embed_size); key_len must equal value_len.
            query: Tensor of shape (N, query_len, embed_size).
            mask: Optional tensor broadcastable to (N, heads, query_len, key_len).
                Positions where ``mask == 0`` (or ``False``) are excluded from
                attention. ``None`` (default) attends to every key.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N, query_len = query.shape[0], query.shape[1]

        # Split the model dimension into per-head slices:
        # (N, query_len, heads, head_dim).
        query_heads = query.reshape(N, query_len, self.heads, self.head_dim)

        # Project each slice with its own layer and stack the heads on dim 1.
        # This gives (N, heads, query_len, head_dim). The head axis must come
        # before the sequence axis so the matmuls below contract over the
        # right dimensions.
        queries = torch.stack(
            [layer(query_heads[:, :, i, :]) for i, layer in enumerate(self.query_layers)],
            dim=1,
        )

        # Shared K/V of shape (N, 1, len, head_dim). The singleton head axis
        # broadcasts across all query heads inside torch.matmul.
        keys = self.key(key).unsqueeze(1)
        values = self.value(value).unsqueeze(1)

        # Scaled dot-product scores: (N, heads, query_len, key_len).
        attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf.
            # A fully masked row then yields uniform weights instead of NaN.
            attention_scores = attention_scores.masked_fill(
                mask == 0, torch.finfo(attention_scores.dtype).min
            )
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Weighted sum of values, giving (N, heads, query_len, head_dim).
        # Then move heads next to head_dim and merge them:
        # (N, query_len, heads * head_dim).
        context_layer = torch.matmul(attention_probs, values)
        context_layer = context_layer.transpose(1, 2).reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(context_layer)

