In [2]:
import torch
import torch.nn as nn
import math

In [8]:
class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA) layer.

    Queries are split into ``num_heads`` heads while keys and values use only
    ``num_key_value_heads`` heads; each key/value head is shared by
    ``num_heads // num_key_value_heads`` consecutive query heads.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        num_heads: Number of query heads. Must divide ``hidden_dim``.
        num_key_value_heads: Number of key/value heads. Must divide ``num_heads``.

    Raises:
        AssertionError: If the divisibility constraints above are violated.
    """

    def __init__(self, hidden_dim, num_heads, num_key_value_heads):
        super().__init__()
        assert num_heads % num_key_value_heads == 0, \
            "num_heads must be divisible by num_key_value_heads"
        # hidden_dim must split evenly across query heads, otherwise the
        # concatenated head outputs cannot be reshaped back to hidden_dim.
        assert hidden_dim % num_heads == 0, \
            "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads  # number of query heads
        self.num_key_value_heads = num_key_value_heads  # number of key / value heads
        self.head_dim = hidden_dim // num_heads

        # Every head (query, key or value) has width head_dim.
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden_dim, num_key_value_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X: torch.Tensor, att_mask: torch.Tensor = None):
        """Apply grouped-query attention to ``X``.

        Args:
            X: Input of shape ``(b, s, hidden_dim)``.
            att_mask: Optional mask broadcastable to ``(b, num_heads, s, s)``.
                Positions where the mask equals 0 are excluded from attention.
                When omitted, a causal (lower-triangular) mask is used.

        Returns:
            Tensor of shape ``(b, s, hidden_dim)``.
        """
        batch_size, seq_length, _ = X.shape

        # Project, then split the feature axis into heads: (b, heads, s, head_dim).
        q_state = self.q_proj(X).view(
            batch_size, seq_length, self.num_heads, self.head_dim
        ).transpose(1, 2)
        k_state = self.k_proj(X).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        v_state = self.v_proj(X).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Repeat each key/value head along the head axis (dim=1) so that it
        # lines up with its group of query heads: (b, num_heads, s, head_dim).
        group_size = self.num_heads // self.num_key_value_heads
        k_state = torch.repeat_interleave(k_state, group_size, dim=1)
        v_state = torch.repeat_interleave(v_state, group_size, dim=1)

        # Scaled dot-product scores: (b, num_heads, s, s).
        att_value = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        if att_mask is None:
            # An (s, s) causal mask broadcasts over batch and head dimensions.
            att_mask = torch.ones(
                seq_length, seq_length, dtype=torch.bool, device=X.device
            ).tril()

        # Out-of-place fill keeps the autograd graph free of in-place edits.
        # Note: a row that is masked everywhere produces NaN after softmax.
        att_value = att_value.masked_fill(att_mask == 0, float('-inf'))

        att_weight = torch.softmax(att_value, dim=-1)  # (b, h, s, s)

        o_state = att_weight @ v_state  # (b, h, s, d)

        # Merge heads back: (b, h, s, d) -> (b, s, h * d) == (b, s, hidden_dim).
        o_state = o_state.transpose(1, 2).contiguous().view(
            batch_size, seq_length, self.hidden_dim
        )

        return self.o_proj(o_state)

In [9]:
# Smoke test: run a random (3, 2, 128) input through the layer and show the output shape
x = torch.rand(3, 2, 128)
net = GroupQueryAttention(hidden_dim=128, num_heads=8, num_key_value_heads=4)
net(x).shape