In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

### Positional Encoding

In [35]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The encoding table is computed once in ``__init__`` and stored as the
    non-trainable buffer ``pe`` of shape ``[1, max_len, d_model]``. Even
    feature columns hold sines and odd columns hold cosines.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions to precompute; inputs passed to
            ``forward`` must not be longer than this along dimension 1.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        # For an even d_model the slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

### Attention

In [40]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(dim)) V``.

    Args:
        dropout: Probability used by the dropout layer that ``forward`` can
            apply to the attention weights.
    """

    def __init__(self, dropout = 0.2):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, query, key, value, mask = None, dropout = None):
        """Attend from ``query`` over ``key``/``value`` pairs.

        Args:
            query: [batchsize, query_len, dim]
            key:   [batchsize, key_len, dim]
            value: [batchsize, key_len, value_dim]
            mask: Optional tensor broadcastable to
                [batchsize, query_len, key_len]. Positions where
                ``mask == 0`` are excluded from attention.
            dropout: Flag. Any value other than ``None`` applies
                ``self.dropout`` to the attention weights.

        Returns:
            Tensor of shape [batchsize, query_len, value_dim].
        """
        dim = query.shape[-1]

        # [batchsize, query_len, key_len]
        score = torch.bmm(query, key.transpose(-2, -1)) / math.sqrt(dim)

        if mask is not None:
            # Masked positions get the most negative finite value of the score
            # dtype, so softmax gives them zero weight. The result must be
            # assigned back to `score`, because masked_fill is not in-place.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        attention_weight = self.softmax(score)

        if dropout is not None:
            attention_weight = self.dropout(attention_weight)

        return torch.bmm(attention_weight, value)



# Shape smoke test: 10 queries attend over 8 key/value pairs, batch of 32.
query, key, value = (torch.randn(32, seq_len, 256) for seq_len in (10, 8, 8))
Attention = ScaledDotProductAttention()

assert tuple(Attention(query, key, value, mask=None, dropout=None).shape) == (32, 10, 256)

In [41]:
class AdditiveAttention(nn.Module):
    """Additive attention scored by a one-hidden-layer network.

    The score of a (query, key) pair is ``w_v^T tanh(W_q q + W_k k)``.

    Args:
        query_size: Feature size of the queries.
        key_size: Feature size of the keys.
        hidden_size: Size of the shared projection space.
        dropout: Probability used by the dropout layer that ``forward`` can
            apply to the attention weights.
        bias: Whether the three linear layers use a bias term.
    """

    def __init__(self, query_size, key_size, hidden_size, dropout = 0.2 , bias = False):
        super(AdditiveAttention, self).__init__()

        self.query  = nn.Linear(query_size, hidden_size, bias = bias)
        self.key    = nn.Linear(key_size, hidden_size, bias = bias)
        self.value  = nn.Linear(hidden_size, 1, bias = bias)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, query, key, value, mask= None, dropout = None):
        """Attend from ``query`` over ``key``/``value`` pairs.

        Args:
            query: [batchsize, query_len, query_size]
            key:   [batchsize, key_len, key_size]
            value: [batchsize, key_len, value_size]
            mask: Optional tensor broadcastable to
                [batchsize, query_len, key_len]. Positions where
                ``mask == 0`` are excluded from attention.
            dropout: Flag. Any value other than ``None`` applies
                ``self.dropout`` to the attention weights.

        Returns:
            Tensor of shape [batchsize, query_len, value_size].
        """
        querys = self.query(query).unsqueeze(2)  # [batchsize, query_len, 1, hidden_size]
        keys   = self.key(key).unsqueeze(1)      # [batchsize, 1, key_len, hidden_size]
        output = torch.tanh(querys + keys)       # [batchsize, query_len, key_len, hidden_size]

        score = self.value(output).squeeze(-1)   # [batchsize, query_len, key_len]

        if mask is not None:
            # Use the most negative finite value so softmax gives masked keys
            # zero weight. A small positive fill such as 1e-9 would leave them
            # competing with the real scores.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        weight = self.softmax(score)             # [batchsize, query_len, key_len]

        if dropout is not None:
            weight = self.dropout(weight)

        return torch.bmm(weight, value)          # [batchsize, query_len, value_size]
        
# Shape smoke test: 10 queries attend over 8 key/value pairs, batch of 32.
query, key, value = (torch.randn(32, seq_len, 256) for seq_len in (10, 8, 8))
Attention = AdditiveAttention(query_size=256, key_size=256, hidden_size=512)

assert tuple(Attention(query, key, value, mask=None, dropout=None).shape) == (32, 10, 256)

In [92]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Queries, keys and values are projected to ``hidden_size``, split into
    ``num_heads`` heads of ``hidden_size // num_heads`` features, attended
    per head, merged back and passed through an output projection.

    Args:
        query_size: Feature size of the incoming queries.
        key_size: Feature size of the incoming keys.
        value_size: Feature size of the incoming values.
        hidden_size: Total projection size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention weights.
        bias: Whether the linear projections use a bias term.
    """

    def __init__(self,
                 query_size,
                 key_size,
                 value_size,
                 hidden_size,
                 num_heads,
                 dropout = 0.2,
                 bias=False):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads
        self.attention = ScaledDotProductAttention(dropout)

        self.query  = nn.Linear(query_size, hidden_size, bias = bias)
        self.key    = nn.Linear(key_size, hidden_size, bias = bias)
        self.value  = nn.Linear(value_size, hidden_size, bias = bias)
        self.output = nn.Linear(hidden_size, hidden_size, bias= bias )

    def _split_heads(self, x):
        """[batchsize, seq_len, hidden_size] -> [batchsize * num_heads, seq_len, head_dim]."""
        x = x.reshape(x.shape[0], x.shape[1], self.num_heads, self.head_dim)
        x = x.permute(0, 2, 1, 3)  # [batchsize, num_heads, seq_len, head_dim]
        return x.reshape(-1, x.shape[2], x.shape[3])

    def _merge_heads(self, x):
        """[batchsize * num_heads, seq_len, head_dim] -> [batchsize, seq_len, hidden_size]."""
        x = x.reshape(-1, self.num_heads, x.shape[1], x.shape[2])
        x = x.permute(0, 2, 1, 3)  # [batchsize, seq_len, num_heads, head_dim]
        return x.reshape(x.shape[0], x.shape[1], -1)

    def _expand_mask(self, mask, batch_size):
        """Repeat a per-sample 3-D mask once per head along dimension 0.

        Heads are folded into the batch dimension batch-major (row
        ``b * num_heads + h``), which is the order ``repeat_interleave``
        produces. Masks that are ``None``, not 3-D, or whose leading
        dimension is not ``batch_size`` are returned unchanged and rely on
        broadcasting.
        """
        if mask is None or mask.dim() != 3 or mask.shape[0] != batch_size:
            return mask
        return mask.repeat_interleave(self.num_heads, dim=0)

    def forward(self, query, key, value, mask):
        """Apply multi-head attention.

        Args:
            query: [batchsize, query_len, query_size]
            key:   [batchsize, key_len, key_size]
            value: [batchsize, key_len, value_size]
            mask: ``None``, or a mask handed to the inner attention. A mask of
                shape [batchsize, query_len or 1, key_len] is repeated for
                every head first.

        Returns:
            Tensor of shape [batchsize, query_len, hidden_size].
        """
        querys = self._split_heads(self.query(query))  # [batchsize * num_heads, query_len, head_dim]
        keys   = self._split_heads(self.key(key))      # [batchsize * num_heads, key_len, head_dim]
        values = self._split_heads(self.value(value))  # [batchsize * num_heads, key_len, head_dim]

        mask = self._expand_mask(mask, query.shape[0])

        # The inner attention only applies its dropout layer when the flag is
        # not None. Without the flag, the `dropout` given to __init__ would be
        # silently ignored.
        output = self.attention(querys, keys, values, mask, dropout=True)
        output = self._merge_heads(output)             # [batchsize, query_len, hidden_size]

        return self.output(output)


# Shape smoke test: self-attention-sized inputs, batch of 2, 100 positions.
query, key, value = (torch.randn(2, 100, 512) for _ in range(3))
attention = MultiHeadAttention(query_size=512, key_size=512, value_size=512,
                               hidden_size=512, num_heads=8)

assert tuple(attention(query, key, value, mask=None).shape) == (2, 100, 512)


### Transformer block components (feed-forward, Add & Norm)

In [93]:
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``, mapping
    ``hidden_size -> ff_dim -> hidden_size``.
    """

    def __init__(self, hidden_size, ff_dim, dropout = 0.2):
        super().__init__()
        self.linear1 = nn.Linear(in_features=hidden_size, out_features=ff_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=ff_dim, out_features=hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        expanded = self.linear1(x)
        activated = self.relu(expanded)
        return self.linear2(self.dropout(activated))

# Shape smoke test: the feed-forward network must preserve the input shape.
x = torch.randn((2, 100, 512))
model = PositionWiseFFN(hidden_size=512, ff_dim=2048)
assert tuple(model(x).shape) == (2, 100, 512)

In [94]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(dropout(layer_output) + x)`` over the last
    dimension, which must have ``size`` features.
    """

    def __init__(self, size, dropout = 0.2):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape=size)

    def forward(self, x, layer_output):
        """Add the (dropped-out) sub-layer output to its input and normalize."""
        residual_sum = self.dropout(layer_output) + x
        return self.ln(residual_sum)
    
# Shape smoke test: Add & Norm must preserve the input shape.
addnorm = AddNorm(size=512, dropout=0.1)
layer_output, x = (torch.randn(2, 100, 512) for _ in range(2))

assert tuple(addnorm(x=x, layer_output=layer_output).shape) == (2, 100, 512)

### Encoder

In [111]:
class TransformerBlock(nn.Module):
    """Attention sub-layer and feed-forward sub-layer, each wrapped in Add & Norm.

    Args:
        query_size: Feature size of the queries. The first residual connection
            adds ``query`` to the attention output, so it must match
            ``hidden_size``.
        key_size: Feature size of the keys.
        value_size: Feature size of the values.
        hidden_size: Model width used by the attention and Add & Norm layers.
        ff_dim: Inner width of the position-wise feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout probability shared by every sub-layer of the block.
        use_bias: Whether the attention projections use a bias term.
    """

    def __init__(self,
                 query_size,
                 key_size,
                 value_size,
                 hidden_size,
                 ff_dim,
                 num_heads,
                 dropout = 0.2,
                 use_bias = False):

        super(TransformerBlock, self).__init__()

        self.attention = MultiHeadAttention(query_size, key_size, value_size, hidden_size, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(hidden_size, dropout)
        # Forward `dropout` so the feed-forward network honours the block's
        # setting instead of always using its own default.
        self.ffn = PositionWiseFFN(hidden_size, ff_dim, dropout)
        self.addnorm2 = AddNorm(hidden_size, dropout)

    def forward(self, query, key, value, mask):
        """Run attention and the feed-forward network with residual connections.

        Args:
            query: [batchsize, query_len, query_size]
            key:   [batchsize, key_len, key_size]
            value: [batchsize, key_len, value_size]
            mask: ``None`` or an attention mask forwarded to the attention layer.

        Returns:
            Tensor of shape [batchsize, query_len, hidden_size].
        """
        x1 = self.addnorm1(query , self.attention(query, key, value, mask))
        return self.addnorm2(x1, self.ffn(x1))
    

# Shape smoke test: self-attention over a batch of 2 sequences of 100 positions.
model = TransformerBlock(query_size=512, key_size=512, value_size=512,
                         hidden_size=512, ff_dim=2048, num_heads=8)
x = torch.randn(2, 100, 512)
assert tuple(model(query=x, key=x, value=x, mask=None).shape) == (2, 100, 512)

In [112]:
class Encoder(nn.Module):
    """Token embedding, positional encoding and a stack of ``TransformerBlock``s.

    Args:
        vocab_size: Number of rows in the embedding table.
        hidden_size: Model width.
        ff_dim: Inner width of each block's feed-forward network.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        dropout: Dropout probability passed to the positional encoding and
            to every block.
        use_bias: Whether the attention projections use a bias term.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 ff_dim,
                 num_heads,
                 num_layers,
                 dropout = 0.2,
                 use_bias = False):
        super().__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = PositionalEncoding(hidden_size, dropout)

        # Blocks are registered under the names "block0", "block1", ...
        self.blocks = nn.Sequential()
        for layer_idx in range(num_layers):
            layer = TransformerBlock(
                query_size=hidden_size,
                key_size=hidden_size,
                value_size=hidden_size,
                hidden_size=hidden_size,
                ff_dim=ff_dim,
                num_heads=num_heads,
                dropout=dropout,
                use_bias=use_bias,
            )
            self.blocks.add_module("block" + str(layer_idx), layer)

    def forward(self, x, mask):
        """Encode token ids ``x`` ([batchsize, seq_len]) into [batchsize, seq_len, hidden_size]."""
        # Embeddings are scaled by sqrt(hidden_size) before positions are added.
        scale = np.sqrt(self.hidden_size)
        hidden = self.pos_encoding(self.embedding(x) * scale)

        # Every block performs self-attention: query, key and value coincide.
        for block in self.blocks:
            hidden = block(hidden, hidden, hidden, mask)

        return hidden


# Shape smoke test: 8-layer encoder on a batch of 2 sequences of 100 token ids.
encoder = Encoder(vocab_size=1024, hidden_size=512, ff_dim=2048, num_heads=8, num_layers=8)
x = torch.ones(2, 100, dtype=torch.long)
assert tuple(encoder(x, mask=None).shape) == (2, 100, 512)

### Decoder

In [25]:
class Generator(nn.Module):
    """Linear projection to vocabulary size followed by log-softmax."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab)

    def forward(self, x):
        """Map ``[..., d_model]`` features to ``[..., vocab]`` log-probabilities."""
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

In [117]:
class DecoderBlock(nn.Module):
    """Decoder layer: self-attention with Add & Norm, then a ``TransformerBlock``.

    The inner ``TransformerBlock`` attends from the self-attended target
    states over the ``key``/``value`` tensors given to ``forward`` and applies
    the feed-forward sub-layer.
    """

    def __init__(self,
                 query_size,
                 key_size,
                 value_size,
                 hidden_size,
                 ff_dim,
                 num_heads,
                 dropout  = 0.1,
                 use_bias = False):
        super().__init__()

        self.attention = MultiHeadAttention(
            query_size, key_size, value_size, hidden_size, num_heads, dropout, use_bias
        )
        self.addnorm1 = AddNorm(hidden_size, dropout)
        self.TransformerBlock = TransformerBlock(
            query_size, key_size, value_size, hidden_size, ff_dim, num_heads, dropout, use_bias
        )

    def forward(self, query, key, value, trg_mask, src_mask):
        """Run one decoder layer.

        Args:
            query: Target-side states; used as query, key and value of the
                self-attention.
            key: Keys for the second attention stage.
            value: Values for the second attention stage.
            trg_mask: Mask passed to the self-attention.
            src_mask: Mask passed to the second attention stage.
        """
        self_attended = self.attention(query, query, query, trg_mask)
        target_states = self.addnorm1(query, self_attended)
        return self.TransformerBlock(target_states, key, value, src_mask)
    
# Shape smoke test: one decoder layer, with the same tensor standing in for
# the target states and for the key/value inputs.
model = DecoderBlock(query_size=512, key_size=512, value_size=512,
                     hidden_size=512, ff_dim=2048, num_heads=8)
x = torch.randn(2, 100, 512)
assert tuple(model(x, x, x, trg_mask=None, src_mask=None).shape) == (2, 100, 512)
        

In [121]:
class Decoder(nn.Module):
    """Token embedding, positional encoding, a stack of ``DecoderBlock``s and a ``Generator``.

    Args:
        vocab_size: Size of the embedding table and of the output distribution.
        hidden_size: Model width.
        ff_dim: Inner width of each block's feed-forward network.
        num_heads: Attention heads per block.
        num_layers: Number of stacked decoder blocks.
        dropout: Dropout probability passed to the positional encoding and
            to every block.
        use_bias: Whether the attention projections use a bias term.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 ff_dim,
                 num_heads,
                 num_layers,
                 dropout = 0.2,
                 use_bias = False):

        super().__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = PositionalEncoding(hidden_size, dropout)
        self.generator = Generator(hidden_size, vocab_size)

        # Blocks are registered under the names "block0", "block1", ...
        self.blocks = nn.Sequential()
        for layer_idx in range(num_layers):
            layer = DecoderBlock(
                query_size=hidden_size,
                key_size=hidden_size,
                value_size=hidden_size,
                hidden_size=hidden_size,
                ff_dim=ff_dim,
                num_heads=num_heads,
                dropout=dropout,
                use_bias=use_bias,
            )
            self.blocks.add_module("block" + str(layer_idx), layer)

    def forward(self, x , enc_out, trg_mask, src_mask):
        """Decode token ids ``x`` against ``enc_out``.

        Returns the ``Generator`` output, of shape
        [batchsize, seq_len, vocab_size].
        """
        # Embeddings are scaled by sqrt(hidden_size) before positions are added.
        scale = np.sqrt(self.hidden_size)
        hidden = self.pos_encoding(self.embedding(x) * scale)

        # Each block receives enc_out as both key and value.
        for block in self.blocks:
            hidden = block(hidden, enc_out, enc_out, trg_mask, src_mask)

        return self.generator(hidden)
    
# Shape smoke test: 8-layer decoder; the output has one score per vocabulary entry.
decoder = Decoder(vocab_size=1024, hidden_size=512, ff_dim=2048, num_heads=8, num_layers=8)
x = torch.ones(2, 100, dtype=torch.long)
enc_out = torch.randn((2, 100, 512))
assert tuple(decoder(x, enc_out, trg_mask=None, src_mask=None).shape) == (2, 100, 1024)