In [2]:
from turtle import forward
import torch
import torch.nn as nn
from math import sqrt
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is linearly projected to queries, keys and values of width
    ``d_model``. Each projection is split into ``num_heads`` heads of width
    ``d_model // num_heads``; attention is computed independently per head,
    the heads are concatenated again and passed through a final linear layer.

    Args:
        dim_in: Size of the last dimension of the input tensor.
        d_model: Total width of the q/k/v projections and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        AssertionError: If ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, dim_in, d_model, num_heads=3):
        super().__init__()

        # Heads must split d_model evenly, otherwise the reshape in forward() fails.
        assert d_model % num_heads == 0, "d_model must be a multiple of num_heads"

        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads

        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)

        # Scale scores by 1/sqrt(d_k), where d_k is the per-head width.
        # (The previous `d_model // d_model` was always 1, i.e. no scaling.)
        self.scale = 1 / sqrt(d_model // num_heads)

        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, d_model)``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``dim_in``.
        """
        batch, n, dim_in = x.shape  # x: (batch, n, dim_in)
        assert dim_in == self.dim_in, "last dimension of x must equal dim_in"

        nh = self.num_heads
        dk = self.d_model // nh

        # Project, then split the feature axis into heads: (batch, nh, n, dk).
        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)

        # Pairwise similarity between positions, per head: (batch, nh, n, n).
        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale
        dist = torch.softmax(dist, dim=-1)

        # Weighted sum of the *value* projections (not the raw input x).
        att = torch.matmul(dist, v)  # (batch, nh, n, dk)
        # Merge heads back into one feature axis: (batch, n, d_model).
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)

        output = self.fc(att)

        return output

# Demo: feed one random sequence (4 positions, 2 features each) through the module.
x = torch.rand(1, 4, 2)
multi_head_att = MultiHeadSelfAttention(dim_in=x.size(-1), d_model=6, num_heads=3)
output = multi_head_att(x)

# Show the input followed by the module's output.
print(x, '\n', output)

tensor([[[0.1659, 0.0145],
         [0.3058, 0.5696],
         [0.6783, 0.3965],
         [0.0019, 0.6310]]]) 
 tensor([[[ 6.6615e-01,  2.7767e-03,  3.1097e-02, -3.9248e-01, -4.4692e-01,
          -3.3785e-01],
         [ 6.6362e-01, -1.6762e-03,  2.8390e-02, -3.9283e-01, -4.4331e-01,
          -3.3972e-01],
         [ 6.6263e-01, -3.0069e-03,  2.7275e-02, -3.9255e-01, -4.4136e-01,
          -3.3978e-01],
         [ 6.6470e-01, -8.4013e-05,  2.9593e-02, -3.9300e-01, -4.4528e-01,
          -3.3943e-01]]], grad_fn=<AddBackward0>)


In [7]:
# Compare the `@` operator with torch.matmul on a batched product.
x = torch.rand(1, 4, 2)

# Multiply x by its own transpose (last two axes swapped) in both spellings.
res1 = x @ x.transpose(-2, -1)
res2 = torch.matmul(x, x.transpose(-2, -1))

print(res1, '\n', res2)

tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]]) 
 tensor([[[0.4200, 0.3127, 0.5360, 0.5334],
         [0.3127, 0.3135, 0.4851, 0.5296],
         [0.5360, 0.4851, 0.7758, 0.8220],
         [0.5334, 0.5296, 0.8220, 0.8949]]])
