# Multi-head Self-Attention (MHA) 多头自注意力

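$$
\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i
$$
$$
\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)\, W^O
$$

where $Q_i = X W_i^Q$, $K_i = X W_i^K$, $V_i = X W_i^V$ are the per-head projections of the input $X$, and $d_k$ is the per-head dimension.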

In [2]:
import math
from typing import Optional

import torch
import torch.nn as nn

In [3]:

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        hidden_dim: width of the input and output features (last dimension).
        num_heads: number of attention heads; must divide ``hidden_dim``.
        att_drop_p: dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim: int, num_heads: int, att_drop_p: float = 0.1):
        super().__init__()
        assert hidden_dim % num_heads == 0, (
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
        )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # One Linear produces Q, K and V concatenated along the last dimension.
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.att_drop = nn.Dropout(att_drop_p)

    def forward(self, X: torch.Tensor, att_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply multi-head self-attention to ``X``.

        Args:
            X: input of shape (b, s, hidden_dim).
            att_mask: optional mask broadcastable to (b, num_heads, s, s).
                Positions where it equals 0 are excluded from attention;
                ``None`` (the default) disables masking.

        Returns:
            Tensor of shape (b, s, hidden_dim).
        """
        batch_size, seq_len, _ = X.shape
        # Each of Q, K, V: (b, s, hidden_dim).
        Q, K, V = torch.split(self.qkv_proj(X), self.hidden_dim, dim=-1)

        # (b, s, hidden_dim) -> (b, s, num_heads, head_dim) -> (b, num_heads, s, head_dim)
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q_state = Q.view(head_shape).transpose(1, 2)
        k_state = K.view(head_shape).transpose(1, 2)
        v_state = V.view(head_shape).transpose(1, 2)

        # Scaled dot-product scores: (b, num_heads, s, s)
        att_value = (q_state @ k_state.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if att_mask is not None:
            # Mask with the dtype's most negative finite value instead of -inf.
            # With -inf, a query row whose keys are all masked would softmax to
            # NaN (0/0). With this value such a row becomes uniform instead.
            # Rows with at least one unmasked key still give ~0 weight to the
            # masked positions.
            att_value = att_value.masked_fill(att_mask == 0, torch.finfo(att_value.dtype).min)
        att_weight = self.att_drop(torch.softmax(att_value, dim=-1))  # (b, num_heads, s, s)
        o_state = att_weight @ v_state  # (b, num_heads, s, head_dim)
        # Merge heads: (b, num_heads, s, head_dim) -> (b, s, num_heads * head_dim)
        O = o_state.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(O)

In [4]:

# Demo: 3 sequences of length 2. Each row of the (b, s) tensor marks which key
# positions may be attended to (1 = keep, 0 = mask out). The two unsqueezes give
# shape (b, 1, 1, s), which is then expanded to (b, num_heads, s, s) = (3, 8, 2, 2).
# Row 1 ([0, 0]) masks every key for that batch element.
attention_mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )
    .unsqueeze(1)
    .unsqueeze(2)
    .expand(3, 8, 2, 2)
)

x = torch.rand(3, 2, 128)  # b=3, s=2, hidden_dim=128
net = MultiHeadAttention(128, 8)  # hidden_dim=128, num_heads=8 -> head_dim = 128 // 8 = 16
net(x, attention_mask).shape  # expected: torch.Size([3, 2, 128])