Group Query Attention（GQA）其实就是把 n_head 个注意力头分为 n 组，每组内共享同一份键和值；组内的键和值通过聚合（取 mean）得到，比如把 W_k·X 得到的各头 K 在组内取 mean。


MHA（Multi-Head Attention）就是一个输入 a 产生 n_head 个查询 q_i（i = 0 ～ n_head−1），每个头各自计算注意力，得到 n_head 个结果，然后直接 concat 起来，再乘输出投影矩阵 W_o（即代码中的 fc）得到最终结果。


In [None]:
import torch
import torch.nn as nn

class GroupQueryAttention(nn.Module):
    """Grouped-query attention: ``n_head`` query heads share ``n_group`` K/V heads.

    Queries are projected to ``n_head`` heads of size ``head_dim``; keys and
    values are projected to only ``n_group`` heads. Each key/value head is then
    repeated ``n_head // n_group`` times so that consecutive query heads in the
    same group attend over the same keys and values.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        n_head: Number of query heads; must divide ``d_model``.
        n_group: Number of key/value heads; must divide ``n_head``.
    """

    def __init__(self, d_model, n_head, n_group) -> None:
        super(GroupQueryAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.n_group = n_group

        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        # Without this the query heads cannot be split evenly across groups
        # and the reshapes in forward() would silently mix up dimensions.
        assert n_head % n_group == 0, "n_head must be divisible by n_group"
        # Number of query heads contained in one group.
        self.n_head_groups = self.n_head // self.n_group
        # Dimension of a single head.
        self.head_dim = d_model // n_head

        self.w_q = nn.Linear(d_model, d_model)
        # Keys/values only need one head per group (n_group heads in total);
        # forward() views the projection as n_group heads of size head_dim.
        self.w_k = nn.Linear(d_model, self.n_group * self.head_dim)
        self.w_v = nn.Linear(d_model, self.n_group * self.head_dim)

        self.fc = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def expand(self, data):
        """Repeat each key/value head so there is one per query head.

        Args:
            data: Tensor of shape ``(batch, n_group, time, head_dim)``.

        Returns:
            Tensor of shape ``(batch, n_head, time, head_dim)`` in which heads
            ``g * n_head_groups`` to ``(g + 1) * n_head_groups - 1`` are copies
            of group ``g``.
        """
        batch, time = data.shape[0], data.shape[2]
        data = data[:, :, None, :, :].expand(
            batch, self.n_group, self.n_head_groups, time, self.head_dim
        ).contiguous()
        data = data.view(batch, self.n_group * self.n_head_groups, time, self.head_dim)
        return data

    def forward(self, q, k, v, mask=None):
        """Compute grouped-query scaled dot-product attention.

        Args:
            q: Query input of shape ``(batch, q_len, d_model)``.
            k: Key input of shape ``(batch, k_len, d_model)``.
            v: Value input of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, q_len, k_len)``; positions where it equals 0
                are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)

        batch = q.shape[0]
        q = q.view(batch, -1, self.n_head, self.head_dim).permute(0, 2, 1, 3)
        # Keys/values carry only n_group heads at this point.
        k = k.view(batch, -1, self.n_group, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(batch, -1, self.n_group, self.head_dim).permute(0, 2, 1, 3)

        k = self.expand(k)
        v = self.expand(v)

        score = q @ k.transpose(-1, -2) / (self.head_dim ** 0.5)
        if mask is not None:
            # The mask must be applied to the scores (not to the mask itself)
            # so that masked positions get ~zero weight after the softmax.
            score = score.masked_fill(mask == 0, -1e9)
        context = self.softmax(score) @ v
        # Merge the heads back: (batch, n_head, q_len, head_dim) -> (batch, q_len, d_model).
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, -1, self.d_model)

        output = self.fc(context)
        return output
        

# GPT GQA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupQueryAttention(nn.Module):
    """Multi-head self-attention over sequence-pooled keys and values.

    Queries, keys and values are all projected to ``num_heads`` heads. The
    keys and values are then split along the sequence axis into ``num_groups``
    contiguous chunks and averaged within each chunk, so every query position
    attends over ``num_groups`` pooled slots instead of ``seq_len`` positions.

    NOTE(review): despite the class name, key/value heads are not shared
    between query heads here; the grouping is over sequence positions.

    Args:
        embed_dim: Size of the last input dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        num_groups: Number of sequence chunks keys/values are pooled into;
            ``seq_len`` has to be divisible by it for the reshape to succeed.
    """

    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads

        # Input projections (query, key, value) followed by the output projection.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, bsz, length):
        # (bsz, length, embed_dim) -> (bsz, num_heads, length, head_dim)
        per_head = projected.view(bsz, length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def _pool_sequence(self, per_head, bsz):
        # (bsz, num_heads, length, head_dim) -> (bsz, num_heads, num_groups, head_dim)
        chunked = per_head.view(bsz, self.num_heads, self.num_groups, -1, self.head_dim)
        return chunked.mean(dim=3)

    def forward(self, x):
        """Apply the attention layer.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, embed_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embed_dim)``.
        """
        bsz, length, _ = x.size()

        q = self._split_heads(self.q_proj(x), bsz, length)
        k = self._pool_sequence(self._split_heads(self.k_proj(x), bsz, length), bsz)
        v = self._pool_sequence(self._split_heads(self.v_proj(x), bsz, length), bsz)

        # Scaled dot-product scores: (bsz, num_heads, length, num_groups).
        scale = self.head_dim ** 0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        # Normalise over the pooled key slots.
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of pooled values: (bsz, num_heads, length, head_dim).
        mixed = torch.matmul(weights, v)

        # Merge heads back into a single embedding axis.
        merged = mixed.transpose(1, 2).contiguous().view(bsz, length, self.embed_dim)
        return self.out_proj(merged)

# Smoke test: push one random batch through the module and report the output shape
if __name__ == "__main__":
    # seq_len is a multiple of num_groups, and embed_dim a multiple of num_heads.
    batch_size, seq_len = 8, 128
    embed_dim, num_heads, num_groups = 256, 8, 4

    x = torch.rand(batch_size, seq_len, embed_dim)
    gqa = GroupQueryAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_groups=num_groups,
    )

    output = gqa(x)
    print("Output shape:", output.shape)
