In [1]:
import numpy as np
import torch
import torch.nn as nn

In [8]:
class SelfAttentionMultiHead(nn.Module):
    """Multi-head scaled dot-product self-attention (no output projection).

    The input is linearly projected to queries/keys of width ``dim_k`` and
    values of width ``dim_v``. Each projection is split along its last
    dimension into ``num_heads`` equal, contiguous heads. Attention is
    computed independently per head, and the per-head results are
    concatenated back to width ``dim_v``.

    Args:
        dim_input: size of the last dimension of the input tensor.
        dim_k: total query/key width across all heads; must be divisible
            by ``num_heads``.
        dim_v: total value/output width across all heads; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        AssertionError: if ``dim_k`` or ``dim_v`` is not divisible by
            ``num_heads``.
    """

    def __init__(self, dim_input, dim_k, dim_v, num_heads):
        super(SelfAttentionMultiHead, self).__init__()
        assert dim_k % num_heads == 0, "dim_k must be divisible by num_heads"
        assert dim_v % num_heads == 0, "dim_v must be divisible by num_heads"

        self.q = nn.Linear(dim_input, dim_k)
        self.k = nn.Linear(dim_input, dim_k)
        self.v = nn.Linear(dim_input, dim_v)

        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        # Scaled dot-product attention divides by sqrt of the *per-head*
        # key width, not the total key width summed over all heads.
        self._norm_factor = (dim_k // num_heads) ** -0.5

    def _split_heads(self, t):
        """Reshape ``(..., seq, heads * d)`` into ``(..., heads, seq, d)``."""
        *lead, seq_len, width = t.shape
        # Split only the feature dim into heads; batch/sequence axes are
        # never mixed, then move the head axis in front of the sequence axis.
        t = t.reshape(*lead, seq_len, self.num_heads, width // self.num_heads)
        return t.transpose(-3, -2)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape ``(..., seq_len, dim_input)``, typically
                ``(batch, seq_len, dim_input)``; an unbatched
                ``(seq_len, dim_input)`` tensor is also accepted.

        Returns:
            Tensor of shape ``(..., seq_len, dim_v)``.
        """
        Q = self._split_heads(self.q(x))  # (..., h, seq, dim_k // h)
        K = self._split_heads(self.k(x))  # (..., h, seq, dim_k // h)
        V = self._split_heads(self.v(x))  # (..., h, seq, dim_v // h)

        scores = torch.matmul(Q, K.transpose(-2, -1)) * self._norm_factor
        atten = torch.softmax(scores, dim=-1)  # (..., h, seq, seq)

        out = torch.matmul(atten, V)  # (..., h, seq, dim_v // h)
        # Put heads back next to the feature axis and concatenate them.
        out = out.transpose(-3, -2)  # (..., seq, h, dim_v // h)
        return out.reshape(*out.shape[:-2], self.dim_v)

In [9]:
# Seed so the random input and the layer's weight init are reproducible.
torch.manual_seed(0)

X = torch.randn(100, 15, 16)  # (batch, seq_len, dim_input)
print(X.size())

self_attention = SelfAttentionMultiHead(16, 16, 12, 4)
res = self_attention(X)
print(res.size())
# Output keeps batch and sequence axes; feature width becomes dim_v.
assert res.shape == (100, 15, 12)

torch.Size([100, 15, 16])
torch.Size([100, 15, 12])
