In [1]:
import torch
from torch import nn
import math
import einops

torch.manual_seed(42)

<torch._C.Generator at 0x25950427d30>

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        
        self.sqrt_v = math.sqrt(dim_out) 
        self.to_qkv = nn.Linear(dim_in, 3*dim_out)
        
    def forward(self, x):
        qkv = self.to_qkv(x)
        q, k, v = einops.rearrange(qkv, "h (k w) -> k h w", k=3)

        return torch.sigmoid(q @ k.T) @ v / self.sqrt_v

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim_in, dim_out, heads_num):
        super().__init__()
        
        self.heads_num = heads_num
        self.heads = [SelfAttention(dim_in, dim_out) for _ in range(heads_num)]
        self.epic_w = nn.Linear(heads_num * dim_out, dim_in)
        
    def forward(self, x):
        outs = [head(x) for head in self.heads]
        outs = einops.rearrange(outs, "head h w -> h (head w)", head=self.heads_num)
        x = self.epic_w(outs)
        
        return x

In [4]:
x = torch.rand([3, 6])
x

tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739]])

In [5]:
mha = MultiHeadAttention(6, 4, 2)

In [6]:
a = mha(x)

In [7]:
a

tensor([[ 0.2661,  0.1551, -0.2437, -0.3507, -0.1366, -0.3805],
        [ 0.2552,  0.1484, -0.2105, -0.3468, -0.1555, -0.3830],
        [ 0.2628,  0.1542, -0.2322, -0.3542, -0.1401, -0.3871]],
       grad_fn=<AddmmBackward>)