### 单头自注意力机制

In [4]:
import torch
import torch.nn as nn

from math import sqrt

In [45]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are bias-free linear projections of the same
    input, and the result is ``softmax(Q @ K^T / sqrt(dim_k)) @ V``.

    Args:
        dim_in: Size of the last dimension of the input tensor.
        dim_k: Output size of the query and key projections.
        dim_v: Output size of the value projection (and of the result).
    """

    def __init__(self, dim_in, dim_k, dim_v):
        super().__init__()
        self.dim_in, self.dim_k, self.dim_v = dim_in, dim_k, dim_v
        # Layers are created in q, k, v order so that seeded initialisation
        # draws random numbers in a fixed sequence.
        self.linear_q = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
        self.linear_k = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
        self.linear_v = nn.Linear(in_features=dim_in, out_features=dim_v, bias=False)
        # Scores are multiplied by 1 / sqrt(dim_k) before the softmax.
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        """Attend over the ``n`` positions of ``x``.

        Args:
            x: Tensor of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, dim_v)``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``dim_in``.
        """
        # The three-way unpack also rejects inputs that are not 3-D.
        _, _, feature_dim = x.shape
        assert feature_dim == self.dim_in

        q, k, v = self.linear_q(x), self.linear_k(x), self.linear_v(x)

        # (batch, n, n): scaled similarity of every query with every key.
        scores = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact
        # Normalise over the key axis so each query's weights sum to one.
        weights = scores.softmax(dim=-1)
        return torch.bmm(weights, v)

In [42]:
# Toy sizes for a step-by-step walk-through of single-head attention.
batch, n = 2, 3
dim_in, dim_k, dim_v = 4, 5, 6
# Scaling factor applied to the raw attention scores.
_norm_fact = 1 / sqrt(dim_k)
# Random input of shape (batch, n, dim_in).
x = torch.randn(batch, n, dim_in)
# Bias-free query / key / value projections.
linear_q = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
linear_k = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
linear_v = nn.Linear(in_features=dim_in, out_features=dim_v, bias=False)

In [43]:
# Project the input into queries, keys and values.
q, k, v = linear_q(x), linear_k(x), linear_v(x)
# Scaled query-key scores, normalised over the key axis: (batch, n, n).
dist = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * _norm_fact, dim=-1)
# Weighted sum of the values: (batch, n, dim_v).
att = torch.bmm(dist, v)
att


tensor([[[-0.1170, -0.0199, -0.4891,  0.1782, -0.0399,  0.1063],
         [-0.1161, -0.0436, -0.4582,  0.1723, -0.0123,  0.0957],
         [-0.1250, -0.0564, -0.4577,  0.1804,  0.0045,  0.0994]],

        [[ 0.4761, -0.1290, -0.3365, -0.3308, -0.0623, -0.3062],
         [ 0.4869, -0.0480, -0.2797, -0.3377, -0.0965, -0.2591],
         [ 0.5008,  0.0143, -0.2248, -0.3498, -0.1231, -0.2280]]],
       grad_fn=<BmmBackward0>)

### 多头自注意力机制

In [44]:
import torch
import torch.nn as nn

from math import sqrt

In [75]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The query, key and value projections are each split along the feature
    axis into ``num_heads`` equal slices. Attention is computed independently
    per slice and the per-head results are concatenated back together. No
    output projection is applied.

    Args:
        dim_in: Size of the last dimension of the input tensor.
        dim_k: Total query/key size across all heads.
        dim_v: Total value size across all heads (and of the result).
        num_heads: Number of attention heads; must divide both ``dim_k``
            and ``dim_v``.
    """

    def __init__(self, dim_in, dim_k, dim_v, num_heads=8):
        super().__init__()
        evenly_split = dim_k % num_heads == 0 and dim_v % num_heads == 0
        assert evenly_split, "dim_k and dim_v must be multiple of num_heads"

        self.dim_in, self.dim_k, self.dim_v = dim_in, dim_k, dim_v
        self.num_heads = num_heads
        # Layers are created in q, k, v order so that seeded initialisation
        # draws random numbers in a fixed sequence.
        self.linear_q = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
        self.linear_k = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
        self.linear_v = nn.Linear(in_features=dim_in, out_features=dim_v, bias=False)
        # Scaling uses the per-head key size, not the total dim_k.
        self._norm_fact = 1 / sqrt(dim_k // num_heads)

    def _split_heads(self, t, head_dim):
        """Reshape ``(batch, n, nh * head_dim)`` into ``(batch, nh, n, head_dim)``."""
        b, length, _ = t.shape
        return t.reshape(b, length, self.num_heads, head_dim).transpose(1, 2)

    def forward(self, x):
        """Attend over the ``n`` positions of ``x`` with every head.

        Args:
            x: Tensor of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, dim_v)`` holding the concatenated
            per-head outputs.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``dim_in``.
        """
        # The three-way unpack also rejects inputs that are not 3-D.
        batch, n, feature_dim = x.shape
        assert feature_dim == self.dim_in

        heads = self.num_heads
        q = self._split_heads(self.linear_q(x), self.dim_k // heads)  # (batch, nh, n, dk)
        k = self._split_heads(self.linear_k(x), self.dim_k // heads)  # (batch, nh, n, dk)
        v = self._split_heads(self.linear_v(x), self.dim_v // heads)  # (batch, nh, n, dv)

        # (batch, nh, n, n): per-head scaled query-key scores.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self._norm_fact
        weights = scores.softmax(dim=-1)

        per_head = torch.matmul(weights, v)  # (batch, nh, n, dv)
        # Move heads next to their features, then merge them: (batch, n, dim_v).
        return per_head.transpose(1, 2).reshape(batch, n, self.dim_v)

In [55]:
# Toy sizes for a step-by-step walk-through of multi-head attention.
batch, n = 2, 3
dim_in, dim_k, dim_v, nh = 4, 6, 8, 2
# Per-head query/key and value sizes.
dk, dv = dim_k // nh, dim_v // nh
# Scaling factor based on the per-head key size.
_norm_fact = 1 / sqrt(dk)
# Random input of shape (batch, n, dim_in).
x = torch.randn(batch, n, dim_in)
# Bias-free query / key / value projections covering all heads at once.
linear_q = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
linear_k = nn.Linear(in_features=dim_in, out_features=dim_k, bias=False)
linear_v = nn.Linear(in_features=dim_in, out_features=dim_v, bias=False)

In [72]:
# Project, then split the feature axis into heads: (batch, nh, n, dk or dv).
q = linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)
k = linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)
v = linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)
# Per-head scaled scores, normalised over the key axis: (batch, nh, n, n).
dist = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * _norm_fact, dim=-1)
# Weighted values per head, then heads merged back into one feature axis.
att = torch.matmul(dist, v)
att = att.transpose(1, 2).reshape(batch, n, dim_v)

In [78]:
# Build the module with the toy sizes and run the sample input through it.
MHA = MultiHeadSelfAttention(dim_in, dim_k, dim_v, num_heads=nh)
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks before and after dispatching to forward().
MHA(x)

tensor([[[ 0.7564, -0.4551, -0.2517, -0.0053, -0.3878,  0.0794,  0.5259,
           0.4184],
         [ 0.6843, -0.4034, -0.2421, -0.0107, -0.3764,  0.1330,  0.4599,
           0.4155],
         [ 0.9905, -0.6450, -0.4833, -0.2000, -0.4009,  0.1281,  0.5068,
           0.4333]],

        [[ 0.6283, -0.3416, -0.0734,  0.2018, -0.0785, -0.2899,  0.3741,
           0.2465],
         [ 0.6967, -0.3385,  0.0030,  0.3277, -0.0423, -0.1433,  0.3717,
           0.0984],
         [ 0.6710, -0.3422, -0.0310,  0.2727, -0.0586, -0.2714,  0.3703,
           0.2078]]], grad_fn=<UnsafeViewBackward0>)