In [1]:
'''
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
(https://arxiv.org/pdf/2305.13245)

The TLDR of GQA vs normal MHSA is that storing lots of keys and values for many heads 
takes up a lot of memory, so we just share keys and values across multiple heads, grouping
the keys/vals for a layer by the queries (which we keep one set of per head)
so the MHSA is now "grouped" and less expensive, especially for longer sequences
when the memory required to store lots of keys and values (see KV cache) grows large. 
'''
import torch as t
import torch.nn as nn 
import torch.nn.functional as F
import einops
from jaxtyping import Float


In [105]:
class MHA(nn.Module):
    def __init__(self, d_model, d_head, num_heads, num_groups):
        super().__init__()
        self.d_head = d_head
        self.num_head = num_heads
        # self.num_groups = num_groups
        # TODO: why Linear vs Parameter
        # shape: (num_heads, d_model, d_head)
        self.W_Q = nn.Parameter(
            t.randn((num_heads, d_model, d_head))
        )
        self.W_K = nn.Parameter(
            t.randn(num_heads, d_model, d_head)
        )
        self.W_V = nn.Parameter(
            t.randn(num_heads, d_model, d_head)
        )
        self.W_O = nn.Parameter(
            t.randn(num_heads * d_head, d_model)
        )
        
    def forward(self, embed: Float[t.Tensor, "batch seq_len d_model"]) -> Float[t.Tensor, "batch seq_len d_model"]:
        Q = einops.einsum(embed, self.W_Q, "b s dm, nh dm dh -> b nh s dh")
        K = einops.einsum(embed, self.W_K, "b s dm, nh dm dh -> b nh s dh")
        V = einops.einsum(embed, self.W_V, "b s dm, nh dm dh -> b nh s dh")
        # head_to_group = t.arange(0, self.num_head) % self.num_groups
        K_t = einops.rearrange(K, "b nh s dh -> b nh dh s")
        Q_K = (
            einops.einsum(Q, K_t, "b nh s_q dh, b nh dh s_k -> b nh s_q s_k")
            / t.sqrt(t.tensor(self.d_head))
        )
        softmaxed = t.softmax(Q_K, dim=-1)
        # TODO: softmax?
        """
        Q = b, n_q, d_head
        K = b d_head n_k
        b n_q n_k
        """
        # RuntimeError: einsum(): subscript b has size 2 for operand 1 which does not broadcast with previously seen size 4
        attn = einops.einsum(softmaxed, V, "b nh s_q s_k,  b nh s_k dh -> b nh s_q dh")
        attn_unraveled = einops.rearrange(attn, "b nh s_q dh -> b s_q (nh dh)") # TODO: is this order correct?
        output = einops.einsum(attn_unraveled, self.W_O, "b s_q nh_dh, nh_dh dm -> b s_q dm")
        
        expected = F.scaled_dot_product_attention(
            query=Q,
            key=K,
            value=V,
        )
        print(expected.shape, attn.shape)
        assert t.allclose(expected, attn)
        return output

In [None]:
class GQA(nn.Module):
    def __init__(self, d_model, d_head, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0
        self.d_model = d_model
        self.d_head = d_head
        self.num_heads = num_heads
        self.num_groups = num_groups
        # TODO: why Linear vs Parameter
        # shape: (num_heads, d_model, d_head)
        self.W_Q = nn.Parameter(
            t.randn((num_heads, d_model, d_head))
        )
        self.W_K = nn.Parameter(
            t.randn(num_groups, d_model, d_head)
        )
        self.W_V = nn.Parameter(
            t.randn(num_groups, d_model, d_head)
        )
        self.W_O = nn.Parameter(
            t.randn(num_heads * d_head, d_model)
        )
        
    def forward(self, embed: Float[t.Tensor, "batch seq_len d_model"]) -> Float[t.Tensor, "batch seq_len d_model"]:
        Q = einops.einsum(embed, self.W_Q, "b s dm, nh dm dh -> b nh s dh")
        K = einops.einsum(embed, self.W_K, "b s dm, ng dm dh -> b ng s dh")
        V = einops.einsum(embed, self.W_V, "b s dm, ng dm dh -> b ng s dh")
        # head_to_group = t.arange(0, self.num_head) % self.num_groups
        K_t = einops.rearrange(K, "b ng s dh -> b ng dh s")
        group_repeats = self.num_heads // self.num_groups
        K_t_interleaved = K_t.repeat_interleave(repeats=group_repeats, dim=1)
        # RuntimeError: einsum(): subscript b has size 8 for operand 1 which does not broadcast with previously seen size 4
        Q_K = (
            einops.einsum(Q, K_t_interleaved, "b nh s_q dh, b nh dh s_k -> b nh s_q s_k")
            / t.sqrt(t.tensor(self.d_head))
        )
        softmaxed = t.softmax(Q_K, dim=-1)
        V_interleaved = V.repeat_interleave(repeats = group_repeats, dim=1)
        attn = einops.einsum(softmaxed, V_interleaved, "b nh s_q s_k,  b nh s_k dh -> b nh s_q dh")
        attn_unraveled = einops.rearrange(attn, "b nh s_q dh -> b s_q (nh dh)") # TODO: is this order correct?
        output = einops.einsum(attn_unraveled, self.W_O, "b s_q nh_dh, nh_dh dm -> b s_q dm")
        
        expected = F.scaled_dot_product_attention(
            query=Q,
            key=K,
            value=V,
            enable_gqa=True,
        )
        assert t.allclose(expected, attn)
        return output

In [111]:
mha = GQA(
    d_model=3, d_head=2, num_heads=4, num_groups=2,
)


In [112]:
resid = t.tensor(
    [
        [
            [1.2, 2.3, 3.4],
            [8.7, 7.6, 6.5],
            [5.9, 9.2, 4.2],
        ]
    ],
)

In [113]:
mha(resid)

torch.Size([1, 4, 3, 2]) torch.Size([1, 4, 2, 3])
torch.Size([1, 4, 3, 2]) torch.Size([1, 4, 3, 2])


tensor([[[-0.6028, 14.6338,  4.0317],
         [-0.6007, 14.6337,  4.0316],
         [-0.6007, 14.6337,  4.0316]]], grad_fn=<ViewBackward0>)

In [None]:
num_heads = 6
num_groups = 2
head_to_group = t.arange(0, num_heads) % num_groups

In [4]:
head_to_group

tensor([0, 1, 0, 1, 0, 1])

In [89]:
# (nh, s, dh)  (ng, s, dh)
# (6, 2, 1), (3, 1, 2)
# (6, 2, 1), (6, 1, 2) -> (6, 2, 2)
q = t.randn(6, 2, 1)
k_t = t.randn(3, 1, 2)
rep_kt = k_t.repeat_interleave(repeats=6 // 3, dim=0)

In [91]:
(q @ rep_kt).shape

torch.Size([6, 2, 2])

torch.Size([6, 1, 2])

In [83]:
import numpy as np
np.array([2, 4]) * np.array([7, 8, 2, 3])

ValueError: operands could not be broadcast together with shapes (2,) (4,) 