In [None]:
请问这个代码是否有误：
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHA(nn.Module):
    """Multi-head self-attention.

    ``hidden_dim`` is split into ``num_heads`` heads of size
    ``hidden_dim // num_heads``; each head has its own query, key and value
    slice, and scores are scaled by ``sqrt(head_dim)``. The concatenated head
    outputs are returned directly: unlike the usual Transformer layer there is
    no output projection (W_o) here.

    Args:
        hidden_dim: model size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, input_tensor, attn_mask=None):
        """Attend over ``input_tensor``.

        Args:
            input_tensor: (B, L, hidden_dim) activations.
            attn_mask: optional mask broadcastable to (B, num_heads, L, L),
                e.g. (L, L) or (B, 1, L, L). A boolean mask marks with True the
                query/key pairs that may attend; a floating mask is added to the
                scaled scores before the softmax. ``None`` (default) means full
                bidirectional attention.

        Returns:
            (B, L, hidden_dim) tensor of concatenated head outputs.
        """
        B, L, H = input_tensor.shape
        head_dim = H // self.num_heads

        # (B, L, H) -> (B, num_heads, L, head_dim)
        K = self.k(input_tensor).reshape(B, L, self.num_heads, head_dim).permute(0, 2, 1, 3)
        Q = self.q(input_tensor).reshape(B, L, self.num_heads, head_dim).permute(0, 2, 1, 3)
        V = self.v(input_tensor).reshape(B, L, self.num_heads, head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores: (B, num_heads, L, L).
        scores = Q @ K.transpose(2, 3) / ((self.hidden_dim // self.num_heads) ** 0.5)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(~attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask
        attn_score = F.softmax(scores, dim=-1)

        # (B, num_heads, L, head_dim) -> (B, L, num_heads, head_dim) -> merge heads.
        outputs = (attn_score @ V).transpose(1, 2).reshape(B, L, H)
        return outputs

class MQA(nn.Module):
    """Multi-query attention (MQA).

    Queries use ``num_heads`` heads of size ``hidden_dim // num_heads``; a
    single key head and a single value head of that same size are shared by
    every query head (via ``unsqueeze(1)`` and broadcasting). Scores are scaled
    by ``sqrt(head_dim)``. There is no output projection (W_o); the
    concatenated head outputs are returned directly.

    Args:
        hidden_dim: model size; must be divisible by ``num_heads``.
        num_heads: number of query heads.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        self.q = nn.Linear(hidden_dim, hidden_dim)
        # Keys/values are projected to a single head of size head_dim.
        self.k = nn.Linear(hidden_dim, hidden_dim // self.num_heads)
        self.v = nn.Linear(hidden_dim, hidden_dim // self.num_heads)

    def forward(self, input_tensor, attn_mask=None):
        """Attend over ``input_tensor``.

        Args:
            input_tensor: (B, L, hidden_dim) activations.
            attn_mask: optional mask broadcastable to (B, num_heads, L, L),
                e.g. (L, L) or (B, 1, L, L). A boolean mask marks with True the
                query/key pairs that may attend; a floating mask is added to the
                scaled scores before the softmax. ``None`` (default) means full
                bidirectional attention.

        Returns:
            (B, L, hidden_dim) tensor of concatenated head outputs.
        """
        B, L, H = input_tensor.shape
        head_dim = H // self.num_heads

        # (B, num_heads, L, head_dim)
        Q = self.q(input_tensor).reshape(B, L, self.num_heads, head_dim).permute(0, 2, 1, 3)
        # Shared key/value head: (B, 1, L, head_dim), broadcast across all query heads.
        K = self.k(input_tensor).unsqueeze(1)
        V = self.v(input_tensor).unsqueeze(1)

        # (B, num_heads, L, L)
        scores = Q @ K.transpose(2, 3) / ((self.hidden_dim // self.num_heads) ** 0.5)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(~attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask
        attn_score = F.softmax(scores, dim=-1)

        outputs = (attn_score @ V).transpose(1, 2).reshape(B, L, H)
        return outputs

class GQA(nn.Module):
    """Grouped-query attention (GQA).

    ``num_heads`` query heads are split into ``num_kv_heads`` groups of
    ``num_heads // num_kv_heads`` consecutive heads; every query head in a group
    shares one key/value head. Head size is ``hidden_dim // num_heads`` and
    scores are scaled by ``sqrt(head_dim)``. There is no output projection
    (W_o); the concatenated head outputs are returned directly.

    Args:
        hidden_dim: model size; must be divisible by ``num_heads``.
        num_heads: number of query heads; must be divisible by ``num_kv_heads``.
        num_kv_heads: number of key/value heads (groups).
    """

    def __init__(self, hidden_dim, num_heads, num_kv_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # Total key/value width: one head_dim slice per KV head.
        self.group_dim = (hidden_dim // self.num_heads) * num_kv_heads

        self.q = nn.Linear(hidden_dim, self.hidden_dim)
        self.k = nn.Linear(hidden_dim, self.group_dim)
        self.v = nn.Linear(hidden_dim, self.group_dim)

    def forward(self, input_tensor, attn_mask=None):
        """Attend over ``input_tensor``.

        Args:
            input_tensor: (B, L, hidden_dim) activations.
            attn_mask: optional per-query-head mask broadcastable to
                (B, num_heads, L, L), e.g. (L, L) or (B, 1, L, L). A boolean
                mask marks with True the query/key pairs that may attend; a
                floating mask is added to the scaled scores before the softmax.
                ``None`` (default) means full bidirectional attention.

        Returns:
            (B, L, hidden_dim) tensor of concatenated head outputs.
        """
        B, L, H = input_tensor.shape
        head_dim = H // self.num_heads
        heads_per_group = self.num_heads // self.num_kv_heads

        # (B, num_kv_heads, heads_per_group, L, head_dim)
        Q = (
            self.q(input_tensor)
            .reshape(B, L, self.num_kv_heads, heads_per_group, head_dim)
            .permute(0, 2, 3, 1, 4)
        )
        # (B, num_kv_heads, 1, L, head_dim): the singleton axis broadcasts over heads_per_group.
        K = self.k(input_tensor).reshape(B, L, 1, self.num_kv_heads, head_dim).permute(0, 3, 2, 1, 4)
        V = self.v(input_tensor).reshape(B, L, 1, self.num_kv_heads, head_dim).permute(0, 3, 2, 1, 4)

        # (B, num_kv_heads, heads_per_group, L, L)
        scores = Q @ K.transpose(-2, -1) / ((self.hidden_dim // self.num_heads) ** 0.5)
        if attn_mask is not None:
            # Lay the mask out per query head, then split the head axis exactly like Q
            # (head h -> group h // heads_per_group, slot h % heads_per_group).
            mask = attn_mask.expand(B, self.num_heads, L, L).reshape(
                B, self.num_kv_heads, heads_per_group, L, L
            )
            if mask.dtype == torch.bool:
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask
        attn_score = F.softmax(scores, dim=-1)

        # (B, kv, g, L, head_dim) -> (B, L, kv, g, head_dim) -> merge heads.
        outputs = (attn_score @ V).permute(0, 3, 1, 2, 4).reshape(B, L, H)
        return outputs
    

自带KV cache的整合版本

In [None]:
# 假设有一层 Self-Attention
class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Each call projects the new tokens to queries/keys/values, prepends any
    cached keys/values, attends, and returns the output together with the
    updated cache so the next decoding step can reuse it. There is no output
    projection (W_o); the merged head outputs are returned directly.

    Args:
        dim: model size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.n_heads = n_heads

    def _split_heads(self, t):
        """Reshape [batch, seq, dim] -> [batch, n_heads, seq, head_dim]."""
        batch, seq, dim = t.shape
        return t.reshape(batch, seq, self.n_heads, dim // self.n_heads).transpose(1, 2)

    def forward(self, x, past_kv=None, causal=True):
        """Attend from the new tokens ``x`` over cached + new keys/values.

        Args:
            x: [batch, seq_len, dim] new tokens.
            past_kv: optional ``(past_k, past_v)``, each [batch, past_len, dim],
                as returned by a previous call.
            causal: if True (default), new token i (absolute position
                ``past_len + i``) only attends to keys at positions
                ``<= past_len + i``. This keeps a full-sequence pass consistent
                with token-by-token decoding through the cache.

        Returns:
            ``(attn_output, (k, v))``: attn_output is [batch, seq_len, dim];
            k and v are [batch, past_len + seq_len, dim] and can be passed back
            as ``past_kv`` on the next call.
        """
        batch, seq_len, dim = x.shape
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        if past_kv is not None:
            # Prepend the cached keys/values along the sequence axis.
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        total_len = k.size(1)
        past_len = total_len - seq_len

        qh, kh, vh = self._split_heads(q), self._split_heads(k), self._split_heads(v)
        head_dim = dim // self.n_heads

        # [batch, n_heads, seq_len, total_len]; scale by the per-head size.
        attn_weights = torch.matmul(qh, kh.transpose(-2, -1)) / (head_dim ** 0.5)
        if causal:
            # Offset the triangle by past_len; with a single new token every key is visible.
            allowed = torch.ones(
                seq_len, total_len, dtype=torch.bool, device=attn_weights.device
            ).tril(diagonal=past_len)
            attn_weights = attn_weights.masked_fill(~allowed, float("-inf"))
        attn_output = torch.matmul(F.softmax(attn_weights, dim=-1), vh)

        # Merge heads back to [batch, seq_len, dim].
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, dim)

        # The cache keeps the un-split [batch, total_len, dim] layout.
        return attn_output, (k, v)

# 使用时：每次 forward 会返回新的缓存，供下一步复用

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GQA(nn.Module):
    """Grouped-query attention (GQA) with an optional key/value cache.

    ``head_num`` query heads are split into ``group_num`` groups; every query
    head in a group shares one key/value head, so the key/value projections
    produce ``group_num * head_dim`` features instead of ``hidden_dim``.

    Args:
        hidden_dim: model size; must be divisible by ``head_num``.
        group_num: number of key/value heads (groups); must divide ``head_num``.
        head_num: number of query heads.

    Raises:
        ValueError: on incompatible ``hidden_dim`` / ``head_num`` / ``group_num``.
    """

    def __init__(self, hidden_dim, group_num, head_num):
        super().__init__()
        if hidden_dim % head_num != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})")
        if head_num % group_num != 0:
            raise ValueError(f"head_num ({head_num}) must be divisible by group_num ({group_num})")
        self.hidden_dim = hidden_dim
        self.group_num = group_num
        self.head_num = head_num
        self.w_q = nn.Linear(self.hidden_dim, self.hidden_dim)
        # One head_dim-wide key/value slice per group: hidden_dim // heads_per_group.
        self.w_k = nn.Linear(self.hidden_dim, self.hidden_dim // (head_num // group_num))
        self.w_v = nn.Linear(self.hidden_dim, self.hidden_dim // (head_num // group_num))
        self.o = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, q, kv=None, k_cache=None, v_cache=None, attn_mask=None):
        """Attend from ``q`` over (cached +) new keys/values.

        Args:
            q: (B, Lq, hidden_dim) query-side input.
            kv: (B, Lkv_new, hidden_dim) key/value-side input; defaults to ``q``
                (self-attention).
            k_cache, v_cache: optional projected keys/values from previous calls,
                each (B, past_len, hidden_dim // (head_num // group_num)). Must be
                given together.
            attn_mask: optional mask of shape (Lq, Lkv), (B, Lq, Lkv), or
                broadcastable to (B, head_num, Lq, Lkv), where Lkv = past_len +
                Lkv_new. A boolean mask marks with True the positions that may be
                attended; a floating mask is added to the scaled scores. When
                given it replaces the default causal mask.

        Returns:
            ``(outputs, k_cache, v_cache)``: outputs is (B, Lq, hidden_dim) after
            the output projection; the caches hold all keys/values seen so far and
            can be passed back on the next call.

        Raises:
            ValueError: if only one of ``k_cache`` / ``v_cache`` is given.
        """
        if kv is None:
            kv = q
        if (k_cache is None) != (v_cache is None):
            raise ValueError("k_cache and v_cache must be given together")

        B, Lq, H = q.shape
        head_dim = H // self.head_num
        heads_per_group = self.head_num // self.group_num

        Q = self.w_q(q)
        K = self.w_k(kv)
        V = self.w_v(kv)

        # `is not None`, not truthiness: bool() of a multi-element tensor raises.
        if k_cache is not None:
            past_len = k_cache.shape[1]
            k_cache = torch.cat([k_cache, K], dim=1)
            v_cache = torch.cat([v_cache, V], dim=1)
        else:
            past_len = 0
            k_cache, v_cache = K, V
        Lkv = k_cache.shape[1]

        Q = Q.reshape(B, Lq, self.group_num, heads_per_group, head_dim).permute(0, 2, 3, 1, 4)  # bgplh
        K = k_cache.reshape(B, Lkv, self.group_num, head_dim).permute(0, 2, 1, 3)  # bgth
        V = v_cache.reshape(B, Lkv, self.group_num, head_dim).permute(0, 2, 1, 3)  # bgth

        # Scale by sqrt(head_dim) only (head_dim * group_num would over-smooth the softmax).
        attn = torch.einsum("bgplh,bgth->bgplt", Q, K) / (head_dim ** 0.5)  # bgplt

        if attn_mask is None:
            # Causal mask offset by the cache length: new query i sits at absolute
            # position past_len + i and may see keys 0 .. past_len + i.
            allowed = torch.ones(Lq, Lkv, dtype=torch.bool, device=attn.device).tril(diagonal=past_len)
            attn = attn.masked_fill(~allowed, float("-inf"))
        else:
            mask = attn_mask.unsqueeze(1) if attn_mask.dim() == 3 else attn_mask
            # Per query head -> split the head axis exactly like Q (g, p).
            mask = mask.expand(B, self.head_num, Lq, Lkv).reshape(
                B, self.group_num, heads_per_group, Lq, Lkv
            )
            if mask.dtype == torch.bool:
                attn = attn.masked_fill(~mask, float("-inf"))
            else:
                attn = attn + mask
        attn = F.softmax(attn, dim=-1)

        outputs = torch.einsum("bgplt,bgth->bgplh", attn, V)  # bgplh
        outputs = outputs.permute(0, 3, 1, 2, 4).reshape(B, Lq, H)
        outputs = self.o(outputs)
        return outputs, k_cache, v_cache
    
# Smoke test: random activations of shape (batch=10, seq_len=512, hidden=768)
# through a GQA layer with 4 query heads sharing 2 key/value groups. The call is
# left as the cell's last expression so the notebook displays the returned
# (outputs, k_cache, v_cache) tuple.
q = torch.randn((10, 512, 768))
gpa = GQA(head_num=4, group_num=2, hidden_dim=768)
gpa(q)

(tensor([[[-6.4724e-01,  3.3060e-01,  7.0604e-01,  ...,  1.2995e-01,
           -3.8271e-02,  1.6505e-01],
          [-4.8879e-02,  2.2648e-01,  3.0446e-01,  ...,  1.2712e-01,
           -6.9186e-01,  8.4487e-02],
          [ 1.2842e-01,  1.2270e-01,  3.8339e-01,  ...,  1.7273e-01,
           -4.2586e-01,  4.5509e-02],
          ...,
          [-4.9807e-02,  9.0177e-04,  3.2556e-02,  ..., -1.9173e-02,
            2.7522e-02, -2.4489e-02],
          [-4.6460e-02,  1.7933e-04,  4.1181e-02,  ..., -1.8496e-02,
            2.1442e-02, -1.6814e-02],
          [-4.9998e-02, -6.3287e-03,  3.9826e-02,  ..., -1.0800e-02,
            1.6482e-02, -2.4108e-02]],
 
         [[ 3.8783e-01,  1.4279e-01, -2.3531e-02,  ...,  2.9379e-01,
            2.1458e-01, -2.8641e-01],
          [ 2.4096e-01,  3.3994e-01,  2.4054e-01,  ..., -2.4949e-02,
            1.5015e-01, -5.4550e-02],
          [ 2.7038e-01,  1.9394e-01,  1.2818e-01,  ..., -1.3305e-02,
            1.6878e-01,  7.3579e-02],
          ...,
    