In [1]:
import argparse
import time
import tiktoken
import torch
import torch.nn as nn

In [3]:
# Why grouped-query attention needs key/value expansion: 4 query heads vs. 2 KV heads.
torch.manual_seed(123)
qh = torch.rand(2, 5, 4, 8)  # (b, num_tokens, num_heads, head_dim)
kh = torch.rand(2, 5, 2, 8)  # (b, num_tokens, num_kv_groups, head_dim)

# A naive matmul fails: the batch dims (2, 5) match, but the inner matrices
# (4, 8) @ (2, 8) are incompatible. Catch the error so "Run All" keeps going.
try:
    qh @ kh
except RuntimeError as err:
    print(f"naive matmul fails: {err}")

# Correct recipe: move heads to dim 1, repeat each KV head once per query head
# in its group, then take per-head dot products over head_dim.
group_size = qh.shape[2] // kh.shape[2]
q_by_head = qh.transpose(1, 2)                                       # (2, 4, 5, 8)
k_by_head = kh.transpose(1, 2).repeat_interleave(group_size, dim=1)  # (2, 4, 5, 8)
gqa_scores = q_by_head @ k_by_head.transpose(2, 3)                   # (2, 4, 5, 5)
gqa_scores.shape

In [2]:
class MultiHeadAttention(nn.Module):
    """Causal grouped-query attention (GQA) with an optional KV cache.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads; each
    key/value head serves ``group_size = num_heads // num_kv_groups``
    consecutive query heads. ``num_kv_groups == num_heads`` gives ordinary
    multi-head attention, ``num_kv_groups == 1`` gives multi-query attention.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: output size; must be divisible by ``num_heads``.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of query heads.
        num_kv_groups: number of key/value heads; must divide ``num_heads``.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, dropout, num_heads, num_kv_groups, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = d_out // num_heads
        self.group_size = num_heads // num_kv_groups  # query heads per KV head

        # Queries need one head_dim slice per query head (d_out in total); keys and
        # values need only one per KV group, so their projections and the KV cache
        # are group_size times narrower than the query projection.
        kv_dim = num_kv_groups * self.head_dim
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, kv_dim, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, kv_dim, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # combines the head outputs
        self.dropout = nn.Dropout(dropout)

        ####################################################
        # KV cache: holds the *un-expanded* keys/values, shape
        # (b, num_kv_groups, cached_tokens, head_dim). Registered as
        # non-persistent buffers so they follow .to(device) but are not saved
        # in state_dict.
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.ptr_current_pos = 0  # absolute position of the next cached query token
        ####################################################

    def forward(self, x, use_cache=False):
        """Apply causal grouped-query attention to ``x``.

        Args:
            x: tensor of shape (b, num_tokens, d_in).
            use_cache: if True, append this call's keys/values to the KV cache
                and attend over everything cached so far. The tokens in ``x`` are
                treated as following the cached ones, at positions
                ``ptr_current_pos ...``. If False, the cache is neither read nor
                written and ``ptr_current_pos`` is reset to 0.

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        # Split the last dim into heads and move heads in front of tokens:
        # (b, num_tokens, n * head_dim) -> (b, n, num_tokens, head_dim)
        queries = (
            self.W_query(x)
            .view(b, num_tokens, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        keys_new = (
            self.W_key(x)
            .view(b, num_tokens, self.num_kv_groups, self.head_dim)
            .transpose(1, 2)
        )
        values_new = (
            self.W_value(x)
            .view(b, num_tokens, self.num_kv_groups, self.head_dim)
            .transpose(1, 2)
        )

        ####################################################
        # KV cache
        if use_cache:
            if self.cache_k is None:
                self.cache_k, self.cache_v = keys_new, values_new
            else:
                # The token axis is dim 2 after the transpose above (dim 1 is heads).
                self.cache_k = torch.cat([self.cache_k, keys_new], dim=2)
                self.cache_v = torch.cat([self.cache_v, values_new], dim=2)
            keys, values = self.cache_k, self.cache_v
        else:
            keys, values = keys_new, values_new
        ####################################################

        # Expand each KV head to the consecutive query heads of its group:
        # (b, num_kv_groups, T, head_dim) -> (b, num_heads, T, head_dim)
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        attn_scores = queries @ keys.transpose(2, 3)  # (b, num_heads, num_tokens, T)

        ####################################################
        # Causal mask: queries sit at absolute positions [start, start + Q),
        # keys at [0, K). A query may not attend to any later key.
        num_tokens_Q = queries.shape[-2]
        num_tokens_K = keys.shape[-2]
        device = queries.device
        if use_cache:
            q_positions = torch.arange(
                self.ptr_current_pos,
                self.ptr_current_pos + num_tokens_Q,
                device=device,
                dtype=torch.long,
            )
            self.ptr_current_pos += num_tokens_Q
        else:
            q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long)
            # NOTE(review): this resets the pointer but leaves any cache intact;
            # call reset_cache() before starting a new cached sequence.
            self.ptr_current_pos = 0
        k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long)
        mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0)  # (Q, K)
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        ####################################################

        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads back: self.d_out == self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

    def reset_cache(self):
        """Drop the KV cache and rewind the position pointer to 0."""
        self.cache_k, self.cache_v = None, None
        self.ptr_current_pos = 0