In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``head_num`` heads of size
    ``dimension // head_num``, attended independently, concatenated again and
    passed through the output projection ``w_o``.

    Args:
        dimension: Size of the last axis of the inputs (and of the output).
            Must be divisible by ``head_num``.
        head_num: Number of attention heads.
        dropout: Dropout probability applied to the projected q, k and v.
    """

    def __init__(self, dimension, head_num, dropout=0.1):
        super().__init__()
        assert dimension % head_num == 0, "dimension must be divisible by head_num"
        self.d_model = dimension
        self.n_head = head_num

        # The per-head size of q, k and v is the model dimension divided by
        # the number of heads.
        self.d_q = dimension // head_num
        self.d_k = dimension // head_num
        self.d_v = dimension // head_num

        self.w_q = nn.Linear(dimension, dimension)
        self.w_k = nn.Linear(dimension, dimension)
        self.w_v = nn.Linear(dimension, dimension)
        self.w_o = nn.Linear(dimension, dimension)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Return ``softmax(q @ k^T / sqrt(d_k) [+ mask * -1e9]) @ v``.

        Args:
            q, k, v: Tensors whose last two axes are (sequence, features).
            mask: Optional tensor broadcastable to the score matrix. Non-zero
                entries are pushed towards ``-1e9`` before the softmax, so
                those key positions receive (almost) no attention. Used e.g.
                in a decoder to stop a position from attending to later ones.

        Returns:
            The attention-weighted combination of ``v``.
        """
        # math.sqrt on a Python int avoids allocating a CPU tensor per call.
        attention_score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))
        if mask is not None:
            # Out-of-place so dtype promotion / broadcasting cannot fail.
            attention_score = attention_score + mask * -1e9
        attention_distribution = torch.softmax(attention_score, dim=-1)
        return torch.matmul(attention_distribution, v)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: Tensors with batch on axis 0 and ``dimension`` features on
                the last axis (assumed (batch, seq, dimension)).
            mask: Optional mask; a head axis is inserted at position 1 so it
                is shared by every head (assumed (batch, seq_q, seq_k)).

        Returns:
            Tensor of shape (batch, seq_q, dimension).
        """
        batch_size = q.size(0)

        # Project, then reshape to (batch, head, seq, d_head).
        q = self.w_q(q).view(batch_size, -1, self.n_head, self.d_q).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_head, self.d_v).transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        q = self.dropout(q)
        k = self.dropout(k)
        v = self.dropout(v)

        # Previously this called the non-existent ``self.MultiHeadAttention``.
        output = self.scaled_dot_product_attention(q, k, v, mask)

        # Merge the heads back: (batch, head, seq, d_v) -> (batch, seq, dimension).
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_head * self.d_v
        )
        # Mix information across heads with the output projection.
        return self.w_o(output)
        
        

In [None]:
class AddLayerNorm(nn.Module):
    """Residual "Add & Norm" step: normalise the sum of two tensors.

    The module has no learnable parameters; normalisation is a plain
    standardisation over the last axis.
    """

    def __init__(self):
        super().__init__()

    def layer_norm(self, x, eps=1e-6):
        """Standardise ``x`` over its last axis.

        Args:
            x: Tensor to normalise.
            eps: Small constant added to the standard deviation to avoid
                division by zero.

        Returns:
            ``(x - mean) / (std + eps)`` where mean and std are taken over the
            last axis (``Tensor.std`` default, i.e. the sample standard
            deviation).
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)

        return (x - mean) / (std + eps)

    def forward(self, input, residual):
        """Return ``layer_norm(input + residual)``.

        The two tensors are added first and the sum is normalised, so the
        result has (approximately) zero mean and unit standard deviation over
        the last axis. Previously only ``input`` was normalised and
        ``residual`` was added afterwards, leaving the output un-normalised.

        Args:
            input: One branch of the residual connection.
            residual: The other branch; must be broadcastable with ``input``.
        """
        return self.layer_norm(input + residual)

In [None]:
class Encoder(nn.Module):
    """One encoder block: attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer's input and output are combined by an ``AddLayerNorm``
    module.

    Args:
        dimension: Feature size passed to the attention and feed-forward
            sub-layers.
        head: Number of attention heads.

    Note:
        ``FeedForward`` is not defined in this cell; it must already exist in
        the notebook namespace when an ``Encoder`` is constructed.
    """

    def __init__(self, dimension=512, head=8):
        super().__init__()
        self.multihead = MultiHeadAttention(dimension, head)
        self.residual_layer1 = AddLayerNorm()
        self.feed_forward = FeedForward(dimension)
        self.residual_layer2 = AddLayerNorm()

    def forward(self, q, k, v, mask=None):
        """Apply the encoder block.

        Args:
            q, k, v: Inputs forwarded to the attention sub-layer; ``q`` is also
                used as the residual branch of the first Add & Norm step.
            mask: Optional attention mask forwarded unchanged to the attention
                sub-layer (e.g. to hide padding positions). Defaults to
                ``None``, which reproduces the previous unmasked behaviour.

        Returns:
            The output of the second Add & Norm step.
        """
        multihead_output = self.multihead(q, k, v, mask)
        layer1_output = self.residual_layer1(q, multihead_output)

        feed_forward_output = self.feed_forward(layer1_output)
        return self.residual_layer2(layer1_output, feed_forward_output)