In [17]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time

import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context='talk')


In [18]:
def attention(Q, K, V, mask=None, dropout=None):
    """Scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` where ``d_k`` is the size of the
    last dimension of ``Q``.

    Args:
        Q, K, V: query / key / value tensors; the product is taken over the
            last two dimensions, leading dimensions are batch-like.
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` receive a score of -1e9 before the softmax.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(Q.size(-1))
    logits = Q @ K.transpose(-2, -1) / scale
    if mask is not None:
        # Large negative value so masked positions get ~0 weight after softmax.
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = logits.softmax(dim=-1)
    weights = weights if dropout is None else dropout(weights)
    return weights @ V, weights
    

In [19]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain and bias.

    ``y = a_2 * (x - mean) / (std + eps) + b_2``, where ``mean`` and ``std`` are
    taken over the last dimension of ``x``.

    Note: this uses ``Tensor.std`` (unbiased / Bessel-corrected by default) and
    adds ``eps`` to the standard deviation, which differs slightly from
    ``torch.nn.LayerNorm`` (biased variance, ``eps`` inside the square root).

    Args:
        features: size of the last dimension (length of the gain/bias vectors).
        eps: small constant added to the std to avoid division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        # Gain starts at 1 and bias at 0, so the layer initially only normalizes.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [None]:
class Encoder(nn.Module):
    """Stack of ``N`` deep-copied encoder layers followed by a final LayerNorm.

    Args:
        layer: template layer; it is deep-copied ``N`` times and must expose a
            ``size`` attribute used to size the final normalization.
        N: number of stacked layers.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(N))
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Run ``x`` through every layer (each receives the same ``mask``), then normalize."""
        hidden = x
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, mask)
        return self.norm(hidden)

In [21]:
class SubLayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: feature size passed to the internal LayerNorm.
        dropout: either a dropout probability (wrapped in ``nn.Dropout``) or an
            already-built callable such as an ``nn.Dropout`` module, used as-is.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        # forward() calls self.dropout, so a bare probability must be turned
        # into a module; callables are kept for backward compatibility.
        self.dropout = dropout if callable(dropout) else nn.Dropout(p=dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the residual."""
        return x + self.dropout(sublayer(self.norm(x)))
        

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: a self-attention sublayer followed by a feed-forward sublayer.

    Both steps are wrapped in a ``SubLayerConnection`` (pre-norm residual + dropout).

    Args:
        size: model width; stored as ``self.size`` and passed to each SubLayerConnection.
        self_attn: callable ``(query, key, value, mask)``.
        feed_forward: callable applied after the attention step.
        dropout: forwarded to each SubLayerConnection.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer1 = SubLayerConnection(size, dropout)
        self.sublayer2 = SubLayerConnection(size, dropout)
        self.size = size

    def forward(self, x, mask):
        def attend(h):
            # Self-attention: query, key and value are all the same input.
            return self.self_attn(h, h, h, mask)

        attended = self.sublayer1(x, attend)
        return self.sublayer2(attended, self.feed_forward)
        

In [23]:
class Generator(nn.Module):
    """Linear projection to vocabulary size followed by log-softmax over the last dim.

    Args:
        d_model: input feature size.
        vocab: output (vocabulary) size.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

In [24]:
class Decoder(nn.Module):
    """Stack of ``N`` deep-copied decoder layers followed by a final LayerNorm.

    Args:
        layer: template layer; deep-copied ``N`` times, must expose ``size``.
        N: number of stacked layers.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(N))
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Pass ``x`` through each layer with the same memory and masks, then normalize."""
        out = x
        for dec_layer in self.layers:
            out = dec_layer(out, memory, src_mask, tgt_mask)
        return self.norm(out)

In [27]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, source attention over memory, then feed-forward.

    Each step is wrapped in a ``SubLayerConnection`` (pre-norm residual + dropout).

    Args:
        size: model width; stored as ``self.size`` (the Decoder reads
            ``layer.size`` to size its final norm) and passed to each sublayer.
        self_attn: callable ``(query, key, value, mask)`` attending over the target.
        src_attn: callable with the same signature attending over ``memory``.
        feed_forward: callable applied after both attention steps.
        dropout: forwarded to each SubLayerConnection.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer1 = SubLayerConnection(size, dropout)
        self.sublayer2 = SubLayerConnection(size, dropout)
        self.sublayer3 = SubLayerConnection(size, dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        # 1) Self-attention over the decoder input, restricted by tgt_mask.
        x = self.sublayer1(x, lambda a: self.self_attn(a, a, a, tgt_mask))
        # 2) Source attention: queries from the decoder, keys/values from memory.
        #    This must use src_attn (its own weights), not self_attn.
        x = self.sublayer2(x, lambda a: self.src_attn(a, memory, memory, src_mask))
        # 3) Feed-forward.
        return self.sublayer3(x, self.feed_forward)

In [28]:
def subsequent_mask(size):
    """Causal mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is True when ``j <= i`` (position ``i`` may attend to
    ``j``) and False for future positions ``j > i``.
    """
    future = torch.triu(torch.ones((1, size, size), dtype=torch.uint8), diagonal=1)
    return future == 0


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected with one Linear layer each for queries, keys and values,
    split into ``h`` heads of width ``d_k = d_model // h``, attended, merged back,
    and passed through a final output projection.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: model width.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by the number of heads h"
        self.d_k = d_model // h
        self.h = h
        # Exactly four projections (Q, K, V, output). Each one covers all heads
        # at once, so the count does not depend on h.
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.attn = None  # attention weights from the most recent forward pass
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, Q, K, V, mask=None):
        if mask is not None:
            # Insert a head axis so the same mask broadcasts over every head.
            mask = mask.unsqueeze(1)
        nbatches = Q.size(0)
        # Project, then split d_model into (h, d_k) and move heads before the sequence axis.
        Q, K, V = [
            proj(t).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for proj, t in zip(self.linears, (Q, K, V))
        ]

        x, self.attn = attention(Q, K, V, mask=mask, dropout=self.dropout)

        # Merge heads back into a single d_model-wide feature axis.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
            

In [29]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network: ``w_2(dropout(relu(w_1(x))))``.

    Args:
        d_model: input and output feature size.
        d_ff: hidden (inner) feature size.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

In [None]:
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper.

    NOTE: the constructor takes ``decoder`` *before* ``encoder``; prefer keyword
    arguments when building one.

    ``forward`` returns the decoder output; it does not apply ``generator``,
    which callers use separately to obtain output distributions.
    """

    def __init__(self, decoder, encoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.decoder = decoder
        self.encoder = encoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def encode(self, src, src_mask):
        """Embed ``src`` and run the encoder over it."""
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, tgt, tgt_mask, memory, src_mask):
        """Embed ``tgt`` and run the decoder against the encoder ``memory``."""
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        memory = self.encode(src, src_mask)
        return self.decode(tgt, tgt_mask, memory, src_mask)
        

In [None]:
class Embedding(nn.Module):
    """Token embedding lookup whose output is multiplied by ``sqrt(d_model)``.

    Args:
        d_model: embedding width.
        vocab: number of token ids in the lookup table.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.look_up = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        vectors = self.look_up(x)
        return vectors * scale
    