In [2]:
import math
from typing import Optional

import torch
import torch.nn as nn

# Self Attention V1

$$
\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

In [3]:

class SelfAttV1(nn.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q K^T / sqrt(hidden_dim)) V`` where Q, K and V are
    three separate learned linear projections of the same input.
    """

    def __init__(self, hidden_dim: int):
        """Create the query, key and value projection layers.

        Args:
            hidden_dim: Size of the last dimension of the input; every
                projection maps ``hidden_dim -> hidden_dim``.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X: torch.Tensor, use_einsum: bool = False) -> torch.Tensor:
        """Apply self-attention to ``X``.

        Args:
            X: Input of shape (b, s, h) with ``h == hidden_dim``.
            use_einsum: If True, compute the score matrix with
                ``torch.einsum``; otherwise with a batched matmul. Both
                paths compute the same scores.

        Returns:
            Tensor of shape (b, s, h).
        """
        Q = self.q_proj(X)  # (b, s, h)
        # Keys and values need their own projections; reusing q_proj for all
        # three would make Q, K and V identical and leave k_proj/v_proj unused.
        K = self.k_proj(X)  # (b, s, h)
        V = self.v_proj(X)  # (b, s, h)

        scale = math.sqrt(self.hidden_dim)
        if use_einsum:
            att_value = torch.einsum('bih, bjh -> bij', Q, K) / scale  # (b, s, s)
        else:
            att_value = Q @ K.transpose(-1, -2) / scale  # (b, s, s)

        # Normalise over the key axis so each query's weights sum to 1.
        att_weight = torch.softmax(att_value, dim=-1)

        return att_weight @ V  # (b, s, h)


In [4]:
# Smoke test for SelfAttV1 using the default (matmul) score path.
X = torch.randn(4, 2, 2)  # (batch, seq_len, hidden_dim)
net = SelfAttV1(hidden_dim=X.size(-1))
output = net(X)
# Show the output shape first, then the values.
print(output.shape, output, sep="\n")

In [5]:
# Same module and input as the previous cell, but with the einsum score path.
output = net(X, use_einsum=True)
print(output.shape, output, sep="\n")

# Self Attention V2
- 合并 q、k、v 的矩阵乘法，提高效率。不过在大模型时代，q、k、v 的投影矩阵通常是分开写的：模型本身很大，往往会使用张量并行、流水线并行等方式，所以分开写问题也不大（而且分开写更清晰）。
- 加入attention mask
- output也加一个线性层

In [6]:
class SelfAttV2(nn.Module):
    """Single-head self-attention with a fused QKV projection.

    Differences from a three-projection implementation: Q, K and V come from
    one ``hidden_dim -> 3 * hidden_dim`` linear layer, an optional attention
    mask is supported, and the result goes through an output projection.
    """

    def __init__(self, hidden_dim: int):
        """Create the fused QKV projection and the output projection.

        Args:
            hidden_dim: Size of the last dimension of the input and output.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(
        self, X: torch.Tensor, att_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply masked self-attention to ``X``.

        Args:
            X: Input of shape (b, s, h) with ``h == hidden_dim``.
            att_mask: Optional mask broadcastable to (b, s, s). Positions
                equal to 0 are excluded from attention. ``None`` (the
                default) disables masking. A query row whose keys are all
                masked yields NaN, because softmax sees only ``-inf``.

        Returns:
            Tensor of shape (b, s, h).
        """
        QKV = self.qkv_proj(X)  # (b, s, h * 3)

        # Cut the fused projection into three (b, s, h) chunks.
        Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)

        att_value = torch.einsum('bih, bjh -> bij', Q, K) / math.sqrt(self.hidden_dim)
        if att_mask is not None:
            # -inf scores become exactly 0 after softmax. Done out-of-place so
            # no tensor is modified behind autograd's back.
            att_value = att_value.masked_fill(att_mask == 0, float("-inf"))

        att_weight = torch.softmax(att_value, dim=-1)  # (b, s, s)
        return self.o_proj(att_weight @ V)

In [7]:
# Demo: SelfAttV2 on random data with a per-sample padding mask.
X = torch.randn(3, 4, 2)  # (batch, seq_len, hidden_dim)
net = SelfAttV2(hidden_dim=X.size(-1))
# One row per sample: 1 = key position may be attended to, 0 = masked out.
att_mask = torch.Tensor([
    [1, 1, 1, 1],
    [1, 1, 1, 0],
    [1, 0, 0, 0],
]).unsqueeze(dim=1).repeat(1, 4, 1)  # (3, 4) -> (3, 1, 4) -> (3, 4, 4)
print(att_mask.shape)
output = net(X, att_mask)
print(output.shape, output, sep="\n")




# Self Attention V3
- 加入dropout。很奇怪的一点是，在BERT里面，这里的dropout用在了att_weight上面，不是output上面

In [8]:
class SelfAttV3(nn.Module):
    """Single-head self-attention with dropout on the attention weights.

    Uses a fused QKV projection, an optional attention mask and an output
    projection; dropout is applied to the softmax weights, not the output.
    """

    def __init__(self, hidden_dim: int, att_drop_p: float = 0.1):
        """Create the projections and the attention-weight dropout.

        Args:
            hidden_dim: Size of the last dimension of the input and output.
            att_drop_p: Dropout probability applied to the attention weights.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.att_drop = nn.Dropout(att_drop_p)

    def forward(
        self, X: torch.Tensor, att_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply masked self-attention with weight dropout to ``X``.

        Args:
            X: Input of shape (b, s, h) with ``h == hidden_dim``.
            att_mask: Optional mask broadcastable to (b, s, s). Positions
                equal to 0 are excluded from attention. ``None`` (the
                default) disables masking. A query row whose keys are all
                masked yields NaN, because softmax sees only ``-inf``.

        Returns:
            Tensor of shape (b, s, h).
        """
        QKV = self.qkv_proj(X)  # (b, s, h * 3)
        Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)

        # Scale the dot products by sqrt(d_k), as the attention formula
        # requires; unscaled scores grow with hidden_dim and saturate softmax.
        att_value = torch.einsum('bih, bjh -> bij', Q, K) / math.sqrt(self.hidden_dim)
        if att_mask is not None:
            # -inf scores become exactly 0 after softmax.
            att_value = att_value.masked_fill(att_mask == 0, float('-inf'))

        att_weight = self.att_drop(torch.softmax(att_value, dim=-1))  # (b, s, s)
        return self.o_proj(att_weight @ V)

# Demo: SelfAttV3 on random data with a per-sample padding mask.
X = torch.randn(3, 4, 2)  # (batch, seq_len, hidden_dim)
net = SelfAttV3(hidden_dim=X.size(-1))
# One row per sample: 1 = key position may be attended to, 0 = masked out.
att_mask = torch.Tensor([
    [1, 1, 1, 1],
    [1, 1, 1, 0],
    [1, 0, 0, 0],
]).unsqueeze(dim=1).repeat(1, 4, 1)  # (3, 4) -> (3, 1, 4) -> (3, 4, 4)
output = net(X, att_mask)
print(output.shape, output, sep="\n")

# Multi-head Self-Attention (MHA) 多头自注意力

$$
\text{head}_i = \text{Attention}(Q_i, K_i, V_i)
$$
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O
$$

In [13]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The hidden dimension is split evenly into ``num_heads`` heads; each head
    runs scaled dot-product attention independently, and the concatenated
    head outputs go through a final output projection.
    """

    def __init__(self, hidden_dim: int, num_heads: int, att_drop_p: float = 0.1):
        """Create the projections and the attention-weight dropout.

        Args:
            hidden_dim: Size of the last dimension of the input and output.
            num_heads: Number of attention heads; must divide ``hidden_dim``.
            att_drop_p: Dropout probability applied to the attention weights.

        Raises:
            AssertionError: If ``hidden_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.att_drop = nn.Dropout(att_drop_p)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape (b, s, hidden_dim) into (b, num_heads, s, head_dim)."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self, X: torch.Tensor, att_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply multi-head self-attention to ``X``.

        Args:
            X: Input of shape (b, s, hidden_dim).
            att_mask: Optional mask broadcastable to (b, num_heads, s, s).
                Positions equal to 0 are excluded from attention. ``None``
                (the default) disables masking. A query row whose keys are
                all masked yields NaN, because softmax sees only ``-inf``.

        Returns:
            Tensor of shape (b, s, hidden_dim).
        """
        batch_size, seq_len, _ = X.shape
        Q, K, V = torch.split(self.qkv_proj(X), self.hidden_dim, dim=-1)  # each (b, s, h)

        q_state = self._split_heads(Q)  # (b, num_heads, s, head_dim)
        k_state = self._split_heads(K)
        v_state = self._split_heads(V)

        # Per-head scores, scaled by sqrt(head_dim) because each dot product
        # runs over head_dim components, not hidden_dim.
        att_value = (q_state @ k_state.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if att_mask is not None:
            # -inf scores become exactly 0 after softmax.
            att_value = att_value.masked_fill(att_mask == 0, float('-inf'))
        att_weight = self.att_drop(torch.softmax(att_value, dim=-1))  # (b, num_heads, s, s)

        o_state = att_weight @ v_state  # (b, num_heads, s, head_dim)
        # Move heads back next to head_dim and merge them; contiguous() is
        # needed because view() cannot merge the transposed dimensions.
        O = o_state.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
        return self.o_proj(O)

In [15]:
# Padding-style mask, one row per sample: 1 = key may be attended to, 0 = masked.
# NOTE: the middle row masks every key, so that sample's scores are all -inf
# and its output values come out as NaN; only the output shape is shown below.
attention_mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )
    .unsqueeze(1)  # (3, 2) -> (3, 1, 2)
    .unsqueeze(2)  # (3, 1, 2) -> (3, 1, 1, 2)
    .expand(3, 8, 2, 2)  # broadcast to (b, num_heads, s, s)
)

x = torch.rand(3, 2, 128)  # b=3, s=2, hidden_dim=128
net = MultiHeadAttention(128, 8)  # hidden_dim=128, num_heads=8 -> head_dim = 128 // 8 = 16
net(x, attention_mask).shape