In [5]:
import numpy as np
import pandas as pd
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
def get_sinusoid_encoding_tabel(n_seq, d_hidn):
    """Build the fixed sinusoidal position-encoding table.

    Entry (pos, i) starts as the angle ``pos / 10000 ** (2 * (i // 2) / d_hidn)``.
    Even columns are then replaced by ``sin`` of that angle and odd columns by ``cos``.

    Args:
        n_seq: number of positions (table rows).
        d_hidn: encoding width (table columns).

    Returns:
        float64 ndarray of shape ``(n_seq, d_hidn)``. ``n_seq == 0`` yields an
        empty ``(0, d_hidn)`` table instead of raising.
    """
    # One divisor per column; columns 2k and 2k+1 share the same frequency.
    divisors = [np.power(10000, 2 * (i_hidn // 2) / d_hidn) for i_hidn in range(d_hidn)]
    angles = [[position / divisor for divisor in divisors] for position in range(n_seq)]
    # reshape keeps the table 2-D even when n_seq == 0: np.array([]) alone is 1-D,
    # and the column slicing below would then raise IndexError.
    sinusoid_tabel = np.array(angles, dtype=np.float64).reshape(n_seq, d_hidn)
    sinusoid_tabel[:, 0::2] = np.sin(sinusoid_tabel[:, 0::2])
    sinusoid_tabel[:, 1::2] = np.cos(sinusoid_tabel[:, 1::2])

    return sinusoid_tabel

In [7]:
# attention pad mask

def get_attn_pad_mask(seq_q, seq_k, i_pad):
    """Mark padded key positions, broadcast across every query position.

    Args:
        seq_q: 2-D query token tensor ``(batch_q, len_q)``; only ``len_q`` is used.
        seq_k: 2-D key token tensor ``(batch_k, len_k)``.
        i_pad: value identifying the padding token.

    Returns:
        bool tensor ``(batch_k, len_q, len_k)`` (an expanded view, not a copy).
        An entry is True where the key token equals ``i_pad``.
    """
    _, len_q = seq_q.size()
    # The batch dimension is taken from seq_k, matching the previous behaviour
    # where the second assignment to batch_size won.
    batch_size, len_k = seq_k.size()
    # Compare on the tensor itself instead of `.data`: `.data` is discouraged
    # because it silently bypasses autograd. eq() returns a non-differentiable
    # bool tensor either way.
    pad_attn_mask = seq_k.eq(i_pad).unsqueeze(1)  # (batch_k, 1, len_k)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

In [8]:
# attention decoder mask

def get_attn_decoder_mask(seq):
    """Build a causal ("no peeking ahead") mask for decoder self-attention.

    Args:
        seq: 2-D token tensor ``(batch, seq_len)``; only its shape, dtype and
            device are used.

    Returns:
        Tensor ``(batch, seq_len, seq_len)`` with the same dtype as ``seq``.
        Entries strictly above the diagonal are 1 (future positions); all
        others are 0. It is NOT a bool tensor. NOTE(review): a boolean
        ``masked_fill`` needs a bool mask, so callers presumably convert or
        combine it first. Confirm at the call site.
    """
    n_batch, n_len = seq.size(0), seq.size(1)
    square = torch.ones_like(seq)[:, :, None].expand(n_batch, n_len, n_len)
    return torch.triu(square, diagonal=1)

In [9]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T * scale) V``.

    Args:
        config: object exposing ``dropout`` (probability applied to the
            attention weights) and ``d_head`` (width used for the
            ``1 / sqrt(d_head)`` scale).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(config.dropout)
        self.scale = 1/(self.config.d_head ** 0.5)

    def forward(self, Q, K, V, attn_mask=None):
        """Attend from queries Q over keys K, mixing values V.

        Args:
            Q: queries ``(..., len_q, d)``.
            K: keys ``(..., len_k, d)``.
            V: values ``(..., len_k, d_v)``.
            attn_mask: bool tensor broadcastable to ``(..., len_q, len_k)``.
                True marks positions to exclude. ``None`` disables masking.

        Returns:
            ``(context, attn_prob)``: context ``(..., len_q, d_v)`` and the
            post-dropout attention weights ``(..., len_q, len_k)``.
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            # -1e9 is not representable in float16, and masked_fill raises on
            # overflow. Clamp to the dtype's most negative finite value instead;
            # for float32/float64/bfloat16 the fill stays -1e9 as before.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(attn_mask, fill_value)

        # Functional softmax avoids constructing a new nn.Softmax every call.
        attn_prob = F.softmax(scores, dim=-1)
        attn_prob = self.dropout(attn_prob)

        context = torch.matmul(attn_prob, V)

        return context, attn_prob

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        """Create the Q/K/V projections, the attention core and the output projection.

        Args:
            config: object exposing ``d_hidn`` (model width), ``n_head``,
                ``d_head`` (per-head width) and ``dropout``.
        """
        super().__init__()
        self.config = config

        d_model = config.d_hidn
        d_proj = config.n_head * config.d_head  # all heads concatenated

        # Registration order is kept W_Q, W_K, W_V, ..., linear so that seeded
        # parameter initialisation and state_dict key order are unchanged.
        self.W_Q = nn.Linear(d_model, d_proj)
        self.W_K = nn.Linear(d_model, d_proj)
        self.W_V = nn.Linear(d_model, d_proj)
        self.scaled_dot_attn = ScaledDotProductAttention(config)
        self.linear = nn.Linear(d_proj, d_model)
        self.dropout = nn.Dropout(config.dropout)
        
    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)
        