In [1]:
class GroupQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, head_num, group_num):
        super(GroupQueryAttention, self).__init__()
        self.hidden_size = hidden_size
        self.head_num = head_num
        self.group_num = group_num
        self.head_dim = self.hidden_size // head_num

        self.q_linear = torch.nn.Linear(hidden_size, hidden_size)
        self.k_linear = torch.nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.v_linear = torch.nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.o_linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, mask=None):
        batch_size = hidden_state.size()[0]

        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        query = self.split_head(query)
        key = self.split_head(key, self.group_num)
        value = self.split_head(value, self.group_num)

        atten_score = torch.matmul(query, key.transpose(-1,-2)) / torch.sqrt(torch.tensor(self.head_dim))

        if mask:
            atten_score += mask*(-1e-9)

        atten_score = torch.softmax(atten_score, dim=-1)
        output = torch.matmul(atten_score, value)
        output = output.transpose(-1,-2).contiguous().view(batch_size, -1, self.group_num*self.head_dim)

        return self.o_linear(output)

    def split_head(self, x, group_num):
        batch_size, seq_len = x.size()[:2]

        if group_num:
            x = x.view(batch_size, -1, self.group_num, self.head_dim).transpose(1,2)
            x = x[:,:,None,:,:].expand(batch_size, self.group_num, self.head_num//self.group_num, seq_len, self.head_dim).reshape(batch_size, self.head_num, seq_len, self.head_dim)
        else:
            x = x.view(batch_size, -1, self.head_num, self.head_dim).transpose(1,2)

        return x



NameError: name 'torch' is not defined

In [None]:
import torch
import torch.nn as nn

class CausalGroupQueryAttention(torch.nn.Module):
    """
    实现了一个基于组查询的因果注意力机制模块。

    该模块通过将隐藏状态线性变换为查询、键和值，然后在多个注意力头上进行组查询，
    最后通过一个线性变换整合结果，输出最终的注意力结果。特别地，该模块实现了因果注意力，
    确保每个位置只能关注到它之前的位置。

    参数:
    hidden_size (int): 隐藏状态的维度。
    head_num (int): 注意力头的数量。
    group_num (int): 查询组的数量。
    """
    def __init__(self, hidden_size, head_num, group_num):
        super(CausalGroupQueryAttention, self).__init__()
        # 初始化隐藏状态维度、注意力头数量和查询组数量
        self.hidden_size = hidden_size
        self.head_num = head_num
        self.group_num = group_num
        # 计算每个注意力头的维度
        self.head_dim = self.hidden_size // head_num

        # 初始化线性变换层，用于生成查询、键、值和最终的输出整合
        self.q_linear = torch.nn.Linear(hidden_size, hidden_size)
        self.k_linear = torch.nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.v_linear = torch.nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.o_linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, mask=None):
        """
        前向传播函数，实现了组查询因果注意力机制。

        参数:
        hidden_state (Tensor): 输入的隐藏状态，形状为[batch_size, seq_len, hidden_size]。
        mask (Tensor, optional): 注意力掩码，用于指定某些位置的注意力得分为极小值，以避免在计算注意力时考虑这些位置。

        返回:
        Tensor: 注意力机制整合后的输出，形状为[batch_size, seq_len, hidden_size]。
        """
        batch_size, seq_len, _ = hidden_state.size()

        # 对隐藏状态进行线性变换，生成查询、键和值
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        # 将查询、键和值分割为多个注意力头
        query = self.split_head(query)
        key = self.split_head(key, self.group_num)
        value = self.split_head(value, self.group_num)

        # 计算注意力分数
        atten_score = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        # 生成因果掩码，确保每个位置只能看到它之前的位置
        causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=hidden_state.device)).view(1, 1, seq_len, seq_len)
        causal_mask = causal_mask * -1e9

        if mask is not None:
            mask = mask[:, None, None, :] + causal_mask
        else:
            mask = causal_mask

        # 应用掩码atten_score的形状 batch_size, num_heads, seq_len, seq_len
        # 这里mask通过广播机制，加到了atten_score上
        atten_score += mask

        # 计算注意力权重
        atten_weights = torch.softmax(atten_score, dim=-1)

        # 计算加权和
        output = torch.matmul(atten_weights, value)
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.group_num * self.head_dim)

        return self.o_linear(output)

    def split_head(self, x, group_num=None):
        """
        将输入张量分割为多个注意力头。

        参数:
        x (Tensor): 输入张量，形状为[batch_size, seq_len, hidden_size]。
        group_num (int, optional): 查询组的数量。

        返回:
        Tensor: 分割后的张量，形状为[batch_size, head_num, seq_len, head_dim]。
        """
        batch_size, seq_len = x.size()[:2]

        if group_num is not None:
            x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1, 2)
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.head_num // group_num, seq_len, self.head_dim).reshape(batch_size, self.head_num, seq_len, self.head_dim)
        else:
            x = x.view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)

        return x