In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

"""
Setting some hyperparameters
"""
# size of the query, key, value and z vectors 
# in the attention layer (64 was used in paper)
ATTENTION_OUTPUT_DIM = 64
WORD_EMBEDDING_DIM = 512
device = "cuda" if torch.cuda.is_available() else "cpu"

def positional_encoding(seq_len, input_dim, device):
    """ FROM https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51
    TODO understand this.
    """
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(input_dim, dtype=torch.float, device=device).reshape(1, 1, -1)
    phase = pos / 10000 ** (dim // input_dim)

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

class TwoLayerNN(nn.Module):
    """Position-wise feed-forward network used inside each transformer block.

    Computes ``dense_2(relu(dense_1(x)))``: the hidden layer is ReLU-activated and
    the output layer is left linear, as in the paper.

    Args:
        input_dim: size of the last axis of ``x``; the output has the same size.
        hidden_dim: width of the hidden layer (2048 by default).
    """

    def __init__(self, input_dim, hidden_dim=2048):
        super(TwoLayerNN, self).__init__()
        # Expand to the hidden width, then project back to input_dim.
        self.dense_1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.dense_2 = nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def forward(self, x):
        hidden = F.relu(self.dense_1(x))
        return self.dense_2(hidden)
        
def scaled_dot_product_attention(queries, keys, values, mask=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        queries: tensor of shape (..., q_len, d_k).
        keys: tensor of shape (..., k_len, d_k).
        values: tensor of shape (..., k_len, d_v).
        mask: optional tensor broadcastable to (..., q_len, k_len). Positions where
            ``mask == 0`` receive (effectively) zero attention weight.

    Returns:
        Tensor of shape (..., q_len, d_v).
    """
    # Scale by the key/query feature size d_k (last axis), not the sequence length.
    d_k = keys.shape[-1]
    scores = queries @ keys.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        # masked_fill is out-of-place, so the result must be reassigned. The dtype's most
        # negative finite value (instead of -inf or -1e9) fits in fp16 and keeps fully
        # masked rows NaN-free.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    # Normalize over the key axis so each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    return weights @ values

class AttentionLayer(nn.Module):
    """A single attention head: bias-free linear projections followed by scaled
    dot-product attention.

    Args:
        input_dim: size of the last axis of the three inputs.
        output_dim: size of the projected query/key/value vectors (the head dimension).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W_q = nn.Linear(input_dim, output_dim, bias=False)
        self.W_k = nn.Linear(input_dim, output_dim, bias=False)
        self.W_v = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, queries_based, keys_based, values_based, mask=None):
        """Attend from ``queries_based`` over ``keys_based`` / ``values_based``.

        Queries are projected from ``queries_based``. Projecting them from
        ``values_based`` is only equivalent for self-attention, where all three
        inputs are the same tensor; for cross-attention it attends from the wrong sequence.

        Args:
            queries_based: input the queries are projected from.
            keys_based: input the keys are projected from.
            values_based: input the values are projected from.
            mask: optional attention mask forwarded to scaled_dot_product_attention.

        Returns:
            The attention output, one row per query.
        """
        queries = self.W_q(queries_based)
        keys = self.W_k(keys_based)
        values = self.W_v(values_based)
        return scaled_dot_product_attention(queries, keys, values, mask)
    
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head attention: ``num_heads`` independent AttentionLayers whose outputs
    are concatenated on the feature axis and projected back by ``W_o``."""

    def __init__(self, input_dim, num_heads):
        """
        input_dim : dim of the input; the output has the same dimension
        num_heads : number of attention heads to concatenate; each head projects
            queries, keys and values to ``input_dim // num_heads`` dimensions
        """
        super().__init__()
        self.head_dim = input_dim // num_heads
        self.num_heads = num_heads
        self.attention_layers = nn.ModuleList(
            [AttentionLayer(input_dim, self.head_dim) for _ in range(num_heads)]
        )
        # Maps the concatenated heads back to input_dim (this also absorbs any
        # remainder when num_heads does not divide input_dim).
        self.W_o = nn.Linear(self.head_dim * num_heads, input_dim, bias=False)

    def forward(self, queries_based, keys_based, values_based, mask=None):
        """Run every head, concatenate along the last axis and apply ``W_o``.

        Returns the projected output (last axis of size ``input_dim``). Previously the
        projection was computed but the unprojected concatenation was returned,
        which left ``W_o`` untrained.
        """
        heads = [
            attention_layer(queries_based, keys_based, values_based, mask)
            for attention_layer in self.attention_layers
        ]
        z = torch.cat(heads, dim=-1)
        return self.W_o(z)

class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer uses dropout, a residual connection and LayerNorm, applied as
    ``LayerNorm(x + dropout(sublayer(x)))`` (post-norm, as in the paper).

    The attention output has one row per query, so the attention residual is taken
    from ``queries_based``. For self-attention all three inputs are the same tensor;
    for cross-attention the queries carry the stream being updated.
    """

    def __init__(self, input_dim, num_heads=8, dropout_p=0.1, ff_hidden_dim=2048):
        super().__init__()
        self.attention = MultiHeadAttentionLayer(input_dim, num_heads)
        self.feed_forward = TwoLayerNN(input_dim, ff_hidden_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.norm_1 = nn.LayerNorm(input_dim)
        self.norm_2 = nn.LayerNorm(input_dim)

    def forward(self, queries_based, keys_based, values_based, mask=None):
        """Apply attention (queries over keys/values) then the feed-forward network.

        ``mask`` (optional) is forwarded to the attention layer.
        """
        attention = self.attention(queries_based, keys_based, values_based, mask)
        attention = self.dropout(attention)
        z1 = self.norm_1(attention + queries_based)
        z2 = self.dropout(self.feed_forward(z1))
        return self.norm_2(z1 + z2)


class EncoderBlock(nn.Module):
    """Encoder layer: self-attention over ``x`` followed by a feed-forward network."""

    def __init__(self, input_dim, num_heads=8, dropout_p=0.1, ff_hidden_dim=2048):
        super().__init__()
        self.transformer_block = TransformerBlock(
            input_dim, num_heads, dropout_p, ff_hidden_dim
        )

    def forward(self, x, mask=None):
        """Self-attention: ``x`` supplies queries, keys and values. ``mask`` is optional."""
        return self.transformer_block(x, x, x, mask)


class DecoderBlock(nn.Module):
    """Decoder layer with three sub-layers:

    1. masked self-attention over the decoder sequence (``attention`` + ``norm``),
    2. cross-attention with queries from the decoder and keys/values from the
       encoder output, followed by
    3. the feed-forward network (both 2 and 3 live in ``transformer_block``).
    """

    def __init__(self, input_dim, num_heads, dropout_p=0.1, ff_hidden_dim=2048):
        super().__init__()
        # Masked self-attention over the target sequence.
        self.attention = MultiHeadAttentionLayer(input_dim, num_heads)
        self.dropout = nn.Dropout(dropout_p)
        self.norm = nn.LayerNorm(input_dim)
        # Cross-attention over the encoder output + feed-forward.
        self.transformer_block = TransformerBlock(
            input_dim, num_heads, dropout_p, ff_hidden_dim
        )

    def forward(self, x_encoder, x_decoder, mask):
        """
        Args:
            x_encoder: encoder output, used as keys/values of the cross-attention.
            x_decoder: decoder stream, used as queries throughout.
            mask: applied to the decoder self-attention scores, e.g. a lower-triangular
                causal mask so position i only sees positions <= i.

        Returns:
            Updated decoder stream, one row per decoder position. Source and target
            lengths no longer need to match.
        """
        attention = self.attention(x_decoder, x_decoder, x_decoder, mask)
        attention = self.dropout(attention)
        z1 = self.norm(attention + x_decoder)
        return self.transformer_block(z1, x_encoder, x_encoder)

class Encoder(nn.Module):
    """Source-side stack: token embedding + sinusoidal positional encoding, dropout,
    then ``n_layers`` EncoderBlocks.

    Args:
        vocab_len: number of rows in the source embedding table.
        device: stored as ``self.device``. Positional encodings are built on the
            input's own device, so the module still works after ``.to(...)``.
        num_heads, embedding_dim, n_layers, dropout_p, ff_hidden_dim: layer settings,
            forwarded to every EncoderBlock.
    """

    def __init__(self,
                 vocab_len,
                 device,
                 num_heads=8,
                 embedding_dim=512,
                 n_layers=6,
                 dropout_p=0.1,
                 ff_hidden_dim=2048):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_len, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(embedding_dim, num_heads, dropout_p, ff_hidden_dim) for _ in range(n_layers)]
        )
        self.embedding_dim = embedding_dim
        self.device = device

    def forward(self, x):
        """Encode token ids ``x`` (axis 1 is the sequence) into contextual embeddings."""
        x = self.word_embedding(x)
        # Use the embeddings' device rather than self.device, which goes stale if the
        # model is moved after construction.
        pos_enc = positional_encoding(x.shape[1], self.embedding_dim, x.device)
        x = self.dropout(x + pos_enc)
        for block in self.encoder_blocks:
            x = block(x)
        return x

class Decoder(nn.Module):
    """Target-side stack: token embedding + sinusoidal positional encoding, dropout,
    ``n_layers`` DecoderBlocks, and a projection to vocabulary probabilities.

    Args:
        vocab_len: size of the target vocabulary (embedding rows and output classes).
        device: stored as ``self.device``. Positional encodings are built on the
            input's own device, so the module still works after ``.to(...)``.
        num_heads, embedding_dim, n_layers, dropout_p, ff_hidden_dim: layer settings,
            forwarded to every DecoderBlock.
    """

    def __init__(self,
                 vocab_len,
                 device,
                 num_heads=8,
                 embedding_dim=512,
                 n_layers=6,
                 dropout_p=0.1,
                 ff_hidden_dim=2048):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_len, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(embedding_dim, num_heads, dropout_p, ff_hidden_dim) for _ in range(n_layers)]
        )
        self.embedding_dim = embedding_dim
        self.device = device
        self.fc_out = nn.Linear(embedding_dim, vocab_len)

    def forward(self, x_encoder, x_decoder, mask):
        """
        Args:
            x_encoder: encoder output, passed unchanged to every DecoderBlock.
            x_decoder: target token ids (axis 1 is the sequence).
            mask: attention mask forwarded to every DecoderBlock.

        Returns:
            Probabilities over the target vocabulary on the last axis, so each
            position sums to 1. NOTE: ``nn.CrossEntropyLoss`` expects raw logits,
            so a training loss should use ``fc_out``'s output, not these probabilities.
        """
        x = self.word_embedding(x_decoder)
        pos_enc = positional_encoding(x.shape[1], self.embedding_dim, x.device)
        x = self.dropout(x + pos_enc)
        for block in self.decoder_blocks:
            x = block(x_encoder, x, mask)
        logits = self.fc_out(x)
        # Normalize over the vocabulary axis (the old dim=1 normalized over positions).
        return F.softmax(logits, dim=-1)

class Transformer(nn.Module):
    """Encoder-decoder transformer mapping source token ids to target-vocabulary
    probabilities.

    Args:
        vocab_len_src: source vocabulary size.
        vocab_len_trg: target vocabulary size.
        max_len: stored as ``self.max_len``. NOTE(review): it is not used by the
            visible code.
        device: forwarded to the encoder/decoder and stored as ``self.device``.
        num_heads, embedding_dim: shared by encoder and decoder.
        n_layers_*, dropout_p_*, ff_hidden_dim_*: per-side layer settings.
    """

    def __init__(self,
                 vocab_len_src,
                 vocab_len_trg,
                 max_len,
                 device,
                 num_heads=8,
                 embedding_dim=512,
                 n_layers_encoder=6,
                 n_layers_decoder=6,
                 dropout_p_encoder=0.1,
                 dropout_p_decoder=0.1,
                 ff_hidden_dim_encoder=2048,
                 ff_hidden_dim_decoder=2048):
        super().__init__()
        self.encoder = Encoder(vocab_len_src,
                               device,
                               num_heads,
                               embedding_dim,
                               n_layers_encoder,
                               dropout_p_encoder,
                               ff_hidden_dim_encoder)
        self.decoder = Decoder(vocab_len_trg,
                               device,
                               num_heads,
                               embedding_dim,
                               n_layers_decoder,
                               dropout_p_decoder,
                               ff_hidden_dim_decoder)
        self.device = device
        self.max_len = max_len

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` under a causal mask.

        Args:
            src: source token ids (axis 1 is the sequence).
            trg: target token ids (axis 1 is the sequence).

        Returns:
            The decoder's output for every target position.
        """
        trg_len = trg.shape[1]
        # Lower-triangular causal mask, shape (1, trg_len, trg_len): target position i
        # may attend only to positions <= i. It is built directly on trg's device
        # rather than moved to the notebook-global `device`.
        mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).unsqueeze(0)
        z = self.encoder(src)
        return self.decoder(z, trg, mask)

In [2]:
# Small configuration for a shape smoke test (not the paper's sizes).
vocab_src_len = 300
vocab_trg_len = 400
max_len = 60
batch_size = 300
ff_hidden_dim = 1024
n_layers_encoder = 3
n_layers_decoder = 3

# Seed before building the model so weight initialization is reproducible under
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

transformer = Transformer(
    vocab_src_len,
    vocab_trg_len,
    max_len,
    device,
    n_layers_encoder=n_layers_encoder,
    n_layers_decoder=n_layers_decoder,
    ff_hidden_dim_encoder=ff_hidden_dim,
    ff_hidden_dim_decoder=ff_hidden_dim,
).to(device)

In [3]:
# Random token ids standing in for a real batch. A dedicated seeded generator keeps the
# fake data reproducible no matter how often earlier cells were run. torch.randint
# yields int64 directly, so no float round-trip is needed.
data_gen = torch.Generator().manual_seed(0)
src_test = torch.randint(0, vocab_src_len, (batch_size, max_len), generator=data_gen).to(device)
trg_test = torch.randint(0, vocab_trg_len, (batch_size, max_len), generator=data_gen).to(device)

In [5]:
# Shape smoke test only: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    y_pred = transformer(src_test, trg_test)
assert y_pred.shape == (batch_size, max_len, vocab_trg_len), y_pred.shape
print(y_pred.shape)

torch.Size([300, 60, 400])
