In [31]:
import torch
from torch import nn
import math

In [32]:
# Seed the global torch RNG so the random input below (and any module weights
# initialised afterwards) come out the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

d_model = 512  # width of the last (feature) axis of the input
n_head = 8     # number of attention heads; must divide d_model evenly

# Dummy input of shape (128, 32, d_model); the last axis is tied to d_model so
# the two cannot drift apart when the hyperparameter is changed.
x = torch.rand(128, 32, d_model)

In [33]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``n_head`` heads of size
    ``d_model // n_head``, attended per head, merged back to ``d_model`` and
    passed through a final linear layer.

    Args:
        d_model: Size of the last axis of the inputs and of the output.
        n_head: Number of attention heads. Must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        # Fail fast here instead of with an opaque reshape error inside forward().
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are merged. The attribute
        # name is kept as-is so existing state_dict keys stay valid.
        self.w_contact = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """Apply attention of ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape ``(batch, q_len, d_model)``.
            k: Key tensor of shape ``(batch, kv_len, d_model)``.
            v: Value tensor of shape ``(batch, kv_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, q_len, kv_len)``; positions where
                ``mask == 0`` are suppressed before the softmax.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch, q_len, _ = q.shape
        # Keys/values may have a different sequence length than the queries
        # (cross-attention), so they must be reshaped with their own length.
        kv_len = k.shape[1]
        n_d = self.d_model // self.n_head

        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # (batch, len, d_model) -> (batch, n_head, len, n_d)
        q = q.view(batch, q_len, self.n_head, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, kv_len, self.n_head, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, kv_len, self.n_head, n_d).permute(0, 2, 1, 3)

        # Scaled dot-product scores: (batch, n_head, q_len, kv_len)
        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        if mask is not None:
            # A large negative fill drives the softmax weight of masked
            # positions to (effectively) zero.
            score = score.masked_fill(mask == 0, -10000)

        context = self.softmax(score) @ v

        # Merge heads: (batch, n_head, q_len, n_d) -> (batch, q_len, d_model).
        # permute() yields a non-contiguous tensor, hence contiguous() before view().
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, q_len, self.d_model)
        return self.w_contact(context)
    
# Instantiate the attention layer from the notebook-level hyperparameters.
attention = MultiHeadAttention(d_model=d_model, n_head=n_head)

In [34]:
# Self-attention: the same tensor is passed as query, key and value.
out = attention(q=x, k=x, v=x)
print(out)

tensor([[[ 0.1082, -0.2425, -0.1160,  ..., -0.0229, -0.6229,  0.0131],
         [ 0.1086, -0.2435, -0.1159,  ..., -0.0228, -0.6220,  0.0138],
         [ 0.1093, -0.2435, -0.1167,  ..., -0.0219, -0.6226,  0.0131],
         ...,
         [ 0.1085, -0.2436, -0.1158,  ..., -0.0233, -0.6225,  0.0135],
         [ 0.1085, -0.2428, -0.1159,  ..., -0.0227, -0.6227,  0.0135],
         [ 0.1085, -0.2434, -0.1167,  ..., -0.0219, -0.6229,  0.0130]],

        [[ 0.0436, -0.2910, -0.1107,  ..., -0.0446, -0.5965, -0.0157],
         [ 0.0435, -0.2916, -0.1114,  ..., -0.0446, -0.5968, -0.0169],
         [ 0.0442, -0.2900, -0.1104,  ..., -0.0444, -0.5957, -0.0167],
         ...,
         [ 0.0438, -0.2906, -0.1108,  ..., -0.0443, -0.5971, -0.0163],
         [ 0.0438, -0.2907, -0.1110,  ..., -0.0444, -0.5970, -0.0175],
         [ 0.0434, -0.2906, -0.1110,  ..., -0.0450, -0.5971, -0.0166]],

        [[ 0.0795, -0.2862, -0.1042,  ..., -0.0468, -0.6145,  0.0090],
         [ 0.0792, -0.2858, -0.1037,  ..., -0