In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import math
from torch import nn

In [12]:
class GroupQueryAttn(nn.Module):
    """Grouped-query attention (GQA).

    ``n_heads`` query heads share ``n_groups`` key/value heads: each consecutive
    run of ``n_heads // n_groups`` query heads uses the same key/value head.
    ``n_groups == n_heads`` is ordinary multi-head attention and
    ``n_groups == 1`` is multi-query attention.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of query heads; must divide ``d_model``.
        n_groups: Number of key/value heads; must divide ``n_heads``.
    """

    def __init__(self, d_model, n_heads, n_groups):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # Without this check, a bad config can make the head-splitting `view`
        # in forward() succeed with a wrong time dimension instead of failing.
        assert n_heads % n_groups == 0, "n_heads must be divisible by n_groups"

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_groups = n_groups
        self.n_heads_groups = n_heads // n_groups  # query heads per key/value group
        self.head_dim = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        # Keys/values are projected to only n_groups heads (the GQA saving).
        self.w_k = nn.Linear(d_model, n_groups * self.head_dim)
        self.w_v = nn.Linear(d_model, n_groups * self.head_dim)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def expand(self, data):
        """Repeat each key/value group for the query heads that share it.

        Args:
            data: Tensor of shape (batch, n_groups, time, head_dim).

        Returns:
            Tensor of shape (batch, n_heads, time, head_dim) in which head ``h``
            holds group ``h // n_heads_groups``.
        """
        batch, _, time, _ = data.shape
        data = data[:, :, None, :, :].expand(
            batch, self.n_groups, self.n_heads_groups, time, self.head_dim
        )
        # reshape copies the broadcast (stride-0) view into real memory.
        return data.reshape(batch, self.n_heads, time, self.head_dim)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Query input, expected shape (batch, q_len, d_model).
            k: Key input, expected shape (batch, kv_len, d_model).
            v: Value input, expected shape (batch, kv_len, d_model).
            mask: Optional tensor broadcastable to (batch, n_heads, q_len, kv_len).
                Positions where ``mask == 0`` receive no attention weight.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch = q.shape[0]
        # Project and split into heads: (batch, heads, time, head_dim).
        q = self.w_q(q).view(batch, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.w_k(k).view(batch, -1, self.n_groups, self.head_dim).transpose(1, 2)
        v = self.w_v(v).view(batch, -1, self.n_groups, self.head_dim).transpose(1, 2)

        # Share every key/value group across its n_heads_groups query heads.
        k = self.expand(k)
        v = self.expand(v)

        score = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value: a literal such as -1e9
            # overflows float16 and makes masked_fill raise.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        attn = self.softmax(score) @ v
        # Merge heads back: (batch, q_len, n_heads * head_dim) == (batch, q_len, d_model).
        attn = attn.transpose(1, 2).contiguous().view(batch, -1, self.d_model)
        return self.w_combine(attn)
        
        
        
        

In [13]:
torch.manual_seed(0)  # fix the RNG so the input (and, on Run All, the layer weights) are reproducible
X = torch.randn(1, 32, 16)  # (batch, seq_len, d_model)
X.shape

In [14]:
# Attention hyper-parameters: model width, query heads, key/value groups.
d_model, n_head, n_groups = 16, 8, 4

In [15]:
attention = GroupQueryAttn(d_model, n_head, n_groups)
with torch.no_grad():  # inference-only demo: no autograd graph needed
    output = attention(X, X, X)  # self-attention: queries, keys and values all come from X
assert output.shape == X.shape, "attention must preserve (batch, seq_len, d_model)"
output.shape