In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import math
from torch import nn

In [12]:
class GroupQueryAttn(nn.Module):
    def __init__(self, d_model, n_heads, n_groups):
        super(GroupQueryAttn, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_groups = n_groups
        
        assert d_model % n_heads == 0
        self.n_heads_groups = self.n_heads // self.n_groups
        self.head_dim = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, self.n_groups * self.head_dim)
        self.w_v = nn.Linear(d_model, self.n_groups * self.head_dim)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        
    def expand(self, data):
        batch, time = data.shape[0], data.shape[2]
        data = data[:,:,None,:,:].expand(batch, self.n_groups, self.n_heads_groups, time, self.head_dim).contiguous()
        data = data.view(batch, self.n_groups * self.n_heads_groups, time, self.head_dim)
        return data
        
    def forward(self, q, k, v, mask=None):
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)
        
        batch = q.shape[0]
        q = q.view(batch, -1, self.n_groups * self.n_heads_groups, self.head_dim).permute(0, 2, 1, 3)
        print(q.shape)
        k = k.view(batch, -1, self.n_groups, self.head_dim).permute(0, 2, 1, 3)
        print(k.shape)
        v = v.view(batch, -1, self.n_groups, self.head_dim).permute(0, 2, 1, 3)
        print(v.shape)
        
        k = self.expand(k)
        print(k.shape)
        v = self.expand(v)
        print(v.shape)
        score = q @ k.transpose(2, 3) / math.sqrt(self.head_dim)
        
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        score = self.softmax(score) @ v
        score = score.permute(0, 2, 1, 3).contiguous().view(batch, -1, self.d_model)
        print(score.shape)
        output = self.w_combine(score)
        print(output.shape)
        return output
        
        
        
        

In [13]:
X = torch.randn(1, 32, 16)  # bs, len, dim
X.shape

torch.Size([1, 32, 16])

In [14]:
d_model = 16
n_head = 8
n_groups = 4

In [15]:
attention = GroupQueryAttn(d_model, n_head, n_groups)
output = attention(X, X, X)
print(output, output.shape)

torch.Size([1, 8, 32, 2])
torch.Size([1, 4, 32, 2])
torch.Size([1, 4, 32, 2])
torch.Size([1, 8, 32, 2])
torch.Size([1, 8, 32, 2])
torch.Size([1, 32, 16])
torch.Size([1, 32, 16])
tensor([[[ 0.1481, -0.2424, -0.0826,  0.1685,  0.1371, -0.1359,  0.2079,
          -0.1224, -0.1352, -0.1071,  0.0779,  0.0281, -0.0310, -0.0415,
          -0.0713,  0.2314],
         [ 0.0865, -0.1925, -0.0885,  0.1693,  0.1754, -0.1811,  0.2080,
          -0.0356, -0.1177, -0.1253,  0.0495,  0.0947, -0.0344, -0.0781,
           0.0077,  0.2971],
         [ 0.1401, -0.2101, -0.1021,  0.1176,  0.0850, -0.1783,  0.1779,
          -0.0417, -0.1863, -0.0942,  0.1125,  0.0209, -0.0358, -0.0556,
          -0.0383,  0.2139],
         [ 0.0888, -0.1820, -0.1283,  0.1406,  0.1148, -0.1897,  0.2224,
          -0.0355, -0.1422, -0.0877,  0.1177,  0.0312, -0.0563, -0.0675,
          -0.0258,  0.2513],
         [ 0.1265, -0.1809, -0.0800,  0.1238,  0.1163, -0.1718,  0.1594,
          -0.0184, -0.1565, -0.0904,  0.0933,  0.