In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import math

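$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_h}} + M\right)V\\
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\\
\text{MHA}(Q, K, V) = \text{concat}(\text{head}_1, \dots, \text{head}_h)W^O
$$

where $d_h = d/h$ is the per-head dimension and $M$ is an additive mask ($0$ for allowed positions, $-\infty$ for blocked ones).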

In [None]:
class MHA(nn.Module):
    """Multi-head self-attention.

    Projects the input to queries, keys and values, splits each into
    ``head_num`` heads of size ``hidden_dim // head_num``, applies scaled
    dot-product attention per head, then concatenates the heads and applies
    an output projection.

    Args:
        hidden_dim: Model width; must be divisible by ``head_num``.
        head_num: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the Q/K/V/output projections use a bias term.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim, head_num, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        # Fail at construction time instead of with an opaque `view` error in forward.
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        # 1/sqrt(d_h) scaling of the dot-product logits, as in the formula above.
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=bias)

        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        """Reshape (batch, seq, hidden_dim) -> (batch, head_num, seq, head_dim)."""
        return t.view(batch, seq, self.head_num, self.head_dim).transpose(1, 2)

    def forward(self, X: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply multi-head self-attention.

        Args:
            X: Input of shape (batch, seq, hidden_dim).
            attn_mask: Optional mask broadcastable to (batch, head_num, seq, seq).
                A bool mask marks positions to BLOCK with True. Any other dtype
                is added to the scaled scores (use -inf to block).
                A query row whose keys are all blocked yields NaN.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch, seq, dim = X.shape

        # (q, k, v): (batch, head_num, seq, head_dim)
        q = self._split_heads(self.q_proj(X), batch, seq)
        k = self._split_heads(self.k_proj(X), batch, seq)
        v = self._split_heads(self.v_proj(X), batch, seq)

        # scores: (batch, head_num, seq, seq)
        scores = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Blocked logits must be -inf so softmax gives them exactly zero
                # weight; +inf would make softmax return NaN.
                scores = scores.masked_fill(attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask

        attn = F.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)

        y = attn @ v  # (batch, head_num, seq, head_dim)
        # Merge heads back: (batch, seq, head_num * head_dim) == (batch, seq, dim)
        y = y.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.out_proj(y)  # (batch, seq, dim)

In [18]:
# Fixed seed: the weights, the input and the dropout mask (module is in
# training mode by default) are all random, so output is only reproducible with it.
torch.manual_seed(0)
attn_net = MHA(hidden_dim=64, head_num=8, dropout=0.1)
X = torch.rand(2, 16, 64)  # (batch, seq, hidden_dim)
out = attn_net(X)
assert out.shape == X.shape  # self-attention preserves (batch, seq, hidden_dim)
out

torch.Size([2, 8, 16, 8])
torch.Size([2, 16, 64])


tensor([[[-0.3285, -0.1272,  0.0492,  ..., -0.2410,  0.0295,  0.0787],
         [-0.2989, -0.1286,  0.0395,  ..., -0.2015,  0.0412,  0.0803],
         [-0.3355, -0.1289,  0.0515,  ..., -0.2269,  0.0385,  0.0600],
         ...,
         [-0.2958, -0.1206,  0.0352,  ..., -0.1935,  0.0444,  0.0738],
         [-0.2990, -0.1188,  0.0360,  ..., -0.1923,  0.0324,  0.0736],
         [-0.3254, -0.1354,  0.0525,  ..., -0.2445,  0.0353,  0.0620]],

        [[-0.2784,  0.0226,  0.0355,  ..., -0.1707,  0.0225,  0.0777],
         [-0.2904,  0.0168,  0.0490,  ..., -0.1814,  0.0368,  0.0805],
         [-0.2576,  0.0254,  0.0551,  ..., -0.1796,  0.0498,  0.0844],
         ...,
         [-0.2810,  0.0388,  0.0203,  ..., -0.1973,  0.0444,  0.0841],
         [-0.2598,  0.0082,  0.0361,  ..., -0.1557,  0.0344,  0.0813],
         [-0.2764,  0.0253,  0.0513,  ..., -0.1908,  0.0420,  0.1081]]],
       grad_fn=<ViewBackward0>)

In [19]:
import numpy as np

In [None]:
def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along ``axis``.

    Args:
        x: Input array of any shape.
        axis: Axis to normalise over (default: last axis).

    Returns:
        Array of the same shape as ``x`` whose entries along ``axis`` sum to 1.
    """
    # Subtracting the per-slice max avoids overflow in exp without changing the result.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    # Normalise per slice along `axis`, not over the whole array.
    return exp / np.sum(exp, axis=axis, keepdims=True)

def mha(
    x: np.ndarray,
    w_qkv: np.ndarray,
    w_out: np.ndarray,
    n_heads: int,
) -> np.ndarray:
    """Causal multi-head self-attention in NumPy.

    Args:
        x: Input of shape (B, T, D).
        w_qkv: Fused projection of shape (D, 3*D); its output columns are
            split into equal thirds as [Q | K | V].
        w_out: Output projection of shape (D, D).
        n_heads: Number of heads; must divide D.

    Returns:
        Array of shape (B, T, D). Position t attends only to positions <= t.

    Raises:
        ValueError: If D is not divisible by ``n_heads``.
    """
    B, T, D = x.shape
    if n_heads <= 0 or D % n_heads != 0:
        raise ValueError(f"D={D} must be divisible by n_heads={n_heads}")
    dh = D // n_heads

    def split_heads(t: np.ndarray) -> np.ndarray:
        # (B, T, D) -> (B, n_heads, T, dh)
        return t.reshape(B, T, n_heads, dh).transpose(0, 2, 1, 3)

    q, k, v = (split_heads(t) for t in np.split(x @ w_qkv, 3, axis=-1))

    # scores: (B, n_heads, T, T)
    s = (q @ k.transpose(0, 1, 3, 2)) / np.sqrt(dh)

    # True strictly above the diagonal = future positions, which must be blocked.
    causal_mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    s = np.where(causal_mask, -1e9, s)

    # Stable softmax over the key axis, computed inline on the MASKED scores.
    s = s - s.max(axis=-1, keepdims=True)
    weights = np.exp(s)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    ctx = weights @ v  # (B, n_heads, T, dh)
    ctx = ctx.transpose(0, 2, 1, 3).reshape(B, T, D)
    return ctx @ w_out