# In [None]:
import torch

class MulitiheadAttent(torch.nn.Module):
    """Multi-head self-attention over a ``(batch, seq_len, hidden_size)`` input.

    The input is projected to queries, keys and values, split into
    ``num_head`` heads of width ``hidden_size // num_head``, combined with
    scaled dot-product attention, merged back into ``hidden_size`` features
    and passed through an output projection.

    Args:
        num_head: Number of attention heads; must be positive.
        hidden_size: Model width; must be divisible by ``num_head``.

    Raises:
        ValueError: If ``num_head`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, num_head, hidden_size):
        super().__init__()
        if num_head <= 0 or hidden_size % num_head != 0:
            raise ValueError(
                f"num_head ({num_head}) must be positive and divide "
                f"hidden_size ({hidden_size})"
            )

        self.num_head = num_head
        self.hidden_size = hidden_size
        self.hidden_dim = hidden_size // num_head  # width of each head

        self.q_linear = torch.nn.Linear(hidden_size, hidden_size)
        self.k_linear = torch.nn.Linear(hidden_size, hidden_size)
        self.v_linear = torch.nn.Linear(hidden_size, hidden_size)
        self.out_linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_size)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_head, seq_len, seq_len)``. It is multiplied by a
                large negative constant and added to the attention scores, so
                entries equal to 1 (or ``True``) suppress the corresponding
                query/key pair and entries equal to 0 leave it unchanged.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size = x.size(0)

        # Each is (batch, num_head, seq_len, hidden_dim).
        q = self.split_head(self.q_linear(x))
        k = self.split_head(self.k_linear(x))
        v = self.split_head(self.v_linear(x))

        # (batch, num_head, seq_len, seq_len), scaled by sqrt(head width).
        attn_score = torch.matmul(q, k.transpose(-1, -2)) / self.hidden_dim ** 0.5

        if mask is not None:
            # A large negative bias drives masked scores to ~0 probability after
            # softmax. Done out-of-place so in-place broadcasting rules do not
            # constrain the mask's shape.
            attn_score = attn_score + mask * -1e9

        attn_prob = torch.softmax(attn_score, dim=-1)
        attn_out = torch.matmul(attn_prob, v)
        # Merge heads back: (batch, seq_len, num_head * hidden_dim).
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_size)
        return self.out_linear(attn_out)

    def split_head(self, x):
        """Reshape ``(batch, seq_len, hidden_size)`` into ``(batch, num_head, seq_len, hidden_dim)``."""
        batch_size = x.size(0)
        return x.view(batch_size, -1, self.num_head, self.hidden_dim).transpose(1, 2)