In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
from torch import layer_norm

In [2]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with one Q/K/V projection set per head.

    Each head projects its inputs from ``embed_size`` down to
    ``head_dim = embed_size // heads`` features and computes
    ``softmax(Q K^T / sqrt(head_dim)) V``. The per-head results are concatenated
    back to ``embed_size`` features and passed through a bias-free linear
    layer (``fc_out``).

    Args:
        embed_size: model feature size; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert embed_size == self.head_dim * heads, "embedding size should be divisible by head dimension"

        # nn.ModuleList (not a plain Python list) so the per-head projections
        # are registered submodules: they show up in parameters()/state_dict()
        # and follow .to(device) / .double() calls.
        # head_weights[h] is [query_proj, key_proj, value_proj] for head h.
        self.head_weights = nn.ModuleList([
            nn.ModuleList([
                nn.Linear(self.embed_size, self.head_dim, bias=False)
                for _ in range(3)
            ])
            for _ in range(heads)
        ])

        self.fc_out = nn.Linear(embed_size, embed_size, bias=False)

    def forward(self, query, key, value, mask):
        '''
        query_shape ---> (N, query_len, embed_size)
        key_shape ---> (N, key_len, embed_size)
        value_shape ---> (N, value_len, embed_size)  (value_len must equal key_len)
        query_len, key_len, value_len ---> number of tokens in a given sample sentence
        mask ---> None, or a tensor broadcastable to (N, query_len, key_len);
                  positions where mask == 0 are excluded from attention
        returns ---> (N, query_len, embed_size)
        '''
        scale = math.sqrt(self.head_dim)
        head_outputs = []
        for query_proj, key_proj, value_proj in self.head_weights:
            q = query_proj(query)  # (N, query_len, head_dim)
            k = key_proj(key)      # (N, key_len, head_dim)
            v = value_proj(value)  # (N, value_len, head_dim)
            # Batched Q K^T: transpose only the last two dims (torch.dot is 1-D only).
            scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # (N, query_len, key_len)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-1e20"))
            # Normalise over the key axis so each query's weights sum to 1.
            attention = torch.softmax(scores, dim=-1)
            head_outputs.append(torch.matmul(attention, v))  # (N, query_len, head_dim)
        # Concatenating heads keeps dtype/device of the inputs:
        # (N, query_len, heads * head_dim) == (N, query_len, embed_size).
        return self.fc_out(torch.cat(head_outputs, dim=-1))

In [3]:
class add_and_norm_multihead(nn.Module):
    """Residual "add & norm" step for an attention sublayer.

    ``forward(multihead, prev_query)`` returns
    ``LayerNorm(multihead + prev_query)`` normalised over the last
    ``embed_size`` features.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.layer_norm_multihead = nn.LayerNorm(embed_size)

    def forward(self, multihead, prev_query):
        residual_sum = multihead + prev_query
        return self.layer_norm_multihead(residual_sum)

In [4]:
class add_and_norm_feedforward(nn.Module):
    """Residual "add & norm" step for the feed-forward sublayer.

    ``forward(feedforward_output, prev_output)`` returns
    ``LayerNorm(feedforward_output + prev_output)`` normalised over the last
    ``embed_size`` features.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.layer_norm_feedforward = nn.LayerNorm(embed_size)

    def forward(self, feedforward_output, prev_output):
        combined = feedforward_output + prev_output
        return self.layer_norm_feedforward(combined)

In [5]:
class feedforward(nn.Module):
    """Position-wise feed-forward sublayer: Dropout -> Linear -> ReLU -> Linear.

    Args:
        embed_size: input and output feature size.
        forward_expansion: width multiplier for the hidden layer; the hidden
            width is ``floor(embed_size * forward_expansion)``.
        dropout: dropout probability applied to the sublayer input.
    """

    def __init__(self, embed_size, forward_expansion, dropout):
        super(feedforward, self).__init__()
        self.embed_size = embed_size
        self.forward_expansion = forward_expansion
        hidden_size = math.floor(embed_size * forward_expansion)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, normalized_multihead):
        """Return the MLP output (same shape as the input)."""
        x = self.dropout(normalized_multihead)
        # Previously returned `x`, discarding the MLP result entirely.
        return self.feed_forward(x)

In [6]:
class encoder_block(nn.Module):
    """One encoder layer.

    Pipeline: self-attention -> add & norm -> feed-forward -> add & norm -> dropout.

    Args:
        embed_size: model feature size.
        forward_expansion: hidden-width multiplier for the feed-forward sublayer.
        dropout: dropout probability (feed-forward input and block output).
        heads: number of attention heads.
        mask: optional attention mask stored on the block and passed to the
            attention sublayer on every forward call.
    """

    def __init__(self, embed_size, forward_expansion, dropout, heads, mask=None):
        super(encoder_block, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.self_attention = SelfAttention(embed_size, heads)
        self.add_and_norm_multihead = add_and_norm_multihead(embed_size)
        self.add_and_norm_feedforward = add_and_norm_feedforward(embed_size)
        self.feed_forward = feedforward(embed_size, forward_expansion, dropout)
        self.mask = mask

    def forward(self, query, key=None, value=None):
        """Run the block.

        Args:
            query: input sequence; also used as the residual for the first add & norm.
            key: attention keys; defaults to ``query`` (self-attention), so the
                block can be called as ``block(x)``.
            value: attention values; defaults to ``key``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        if key is None:
            key = query
        if value is None:
            value = key
        # Call submodules (not .forward) so module hooks run.
        multihead = self.self_attention(query, key, value, self.mask)
        normalized_multihead = self.add_and_norm_multihead(multihead, query)
        feed_forward_output = self.feed_forward(normalized_multihead)
        normalized_feed_forward_output = self.add_and_norm_feedforward(feed_forward_output, normalized_multihead)
        return self.dropout(normalized_feed_forward_output)

In [None]:
class decoder_block(nn.Module):
    """One decoder layer.

    Pipeline: masked self-attention -> add & norm -> cross-attention over
    ``encoding`` -> add & norm -> feed-forward -> add & norm -> dropout.

    Args:
        embed_size: model feature size.
        forward_expansion: hidden-width multiplier for the feed-forward sublayer.
        dropout: dropout probability (feed-forward input and block output).
        heads: number of attention heads.
        encoding: encoder output. It is projected by ``key_weights`` /
            ``value_weights`` and used as keys/values for cross-attention, so
            its last dim must be ``embed_size``.
        mask: optional default self-attention mask (defaults to ``None``, so
            callers such as a classification decoder may omit it).
    """

    def __init__(self, embed_size, forward_expansion, dropout, heads, encoding, mask=None):
        super(decoder_block, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.self_attention_1 = SelfAttention(embed_size, heads)
        self.self_attention_2 = SelfAttention(embed_size, heads)
        self.add_and_norm_multihead = add_and_norm_multihead(embed_size)
        # Each residual sublayer gets its own LayerNorm; previously the
        # self-attention norm was reused for cross-attention too.
        self.add_and_norm_cross_attention = add_and_norm_multihead(embed_size)
        self.add_and_norm_feedforward = add_and_norm_feedforward(embed_size)
        self.feed_forward = feedforward(embed_size, forward_expansion, dropout)
        self.mask = mask
        self.encoding = encoding
        self.key_weights = nn.Linear(embed_size, embed_size)
        self.value_weights = nn.Linear(embed_size, embed_size)

    def forward(self, query, key=None, value=None, mask=None):
        """Run the block.

        Args:
            query: decoder input; also the residual for the first add & norm.
            key: self-attention keys; defaults to ``query``.
            value: self-attention values; defaults to ``key``.
            mask: self-attention mask for this call; when ``None`` the mask
                given at construction (``self.mask``) is used.

        Returns:
            Tensor with the same shape as ``query``.
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if mask is None:
            mask = self.mask
        multihead_1 = self.self_attention_1(query, key, value, mask)
        normalized_multihead_1 = self.add_and_norm_multihead(multihead_1, query)
        # Cross-attention: queries come from the decoder, keys/values from the encoder output.
        key_2 = self.key_weights(self.encoding)
        value_2 = self.value_weights(self.encoding)
        multihead_2 = self.self_attention_2(normalized_multihead_1, key_2, value_2, None)
        normalized_multihead_2 = self.add_and_norm_cross_attention(multihead_2, normalized_multihead_1)
        feed_forward_output = self.feed_forward(normalized_multihead_2)
        normalized_feed_forward_output = self.add_and_norm_feedforward(feed_forward_output, normalized_multihead_2)
        return self.dropout(normalized_feed_forward_output)

In [None]:
class encoder(nn.Module):
    """Stack of ``layers`` encoder blocks applied to a stored embedding tensor.

    Args:
        embedding: input tensor, shape (batch_size, max_len, embed_size);
            ``embed_size`` is read from ``embedding.shape[2]``.
        layers: number of stacked encoder blocks.
        heads: attention heads per block.
        forward_expansion: feed-forward hidden-width multiplier.
        dropout: dropout probability.
    """

    def __init__(self, embedding, layers=6, heads=8, forward_expansion=1.6, dropout=0.4):
        # embedding shape ---> (batch_size, max_len, embed_size)
        super(encoder, self).__init__()
        self.embedding = embedding
        self.embed_size = embedding.shape[2]
        self.layers = layers
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.dropout = dropout
        self.enc_blocks = nn.ModuleList([encoder_block(self.embed_size, self.forward_expansion, self.dropout, self.heads) for _ in range(self.layers)])

    def forward(self, embedding=None):
        """Encode ``embedding`` (defaults to the tensor given at construction).

        Returns:
            Tensor with the same shape as the input embedding.
        """
        x = self.embedding if embedding is None else embedding
        for block in self.enc_blocks:
            # Self-attention: query, key and value are all the running output.
            x = block(x, x, x)
        return x

In [None]:
class decoder(nn.Module):
    """Stack of ``layers`` decoder blocks attending over a fixed encoder output.

    Args:
        input: decoder input tensor, shape (batch_size, max_len, embed_size).
        encoding: encoder output, shape (batch_size, max_len, embed_size);
            ``embed_size`` is read from ``encoding.shape[2]``.
        mask: self-attention mask handed to every decoder block at construction.
        layers: number of stacked decoder blocks.
        heads: attention heads per block.
        forward_expansion: feed-forward hidden-width multiplier.
        dropout: dropout probability.
    """

    def __init__(self, input, encoding, mask, layers=6, heads=8, forward_expansion=1.6, dropout=0.4):
        # encoding shape ---> (batch_size, max_len, embed_size)
        # input shape ---> (batch_size, max_len, embed_size) {if n words are generated, in case of NLG, (max_len - n) words are masked}
        # Was super(encoder, self): wrong class, raises TypeError on construction.
        super(decoder, self).__init__()
        self.input = input
        self.encoding = encoding
        self.mask = mask
        self.embed_size = encoding.shape[2]
        self.layers = layers
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.dropout = dropout
        self.dec_blocks = nn.ModuleList([decoder_block(self.embed_size, forward_expansion, dropout, heads, encoding, mask) for _ in range(self.layers)])

    def forward(self, input=None):
        """Decode ``input`` (defaults to the tensor given at construction).

        Returns:
            Tensor with the same shape as the decoder input.
        """
        x = self.input if input is None else input
        for block in self.dec_blocks:
            # Each block already holds the mask from construction; no per-call
            # mask is passed here.
            x = block(x, x, x, None)
        return x

In [None]:
class decoder_classification(nn.Module):
    """Unmasked stack of ``layers`` decoder blocks attending over an encoder output.

    Args:
        input: decoder input tensor, shape (batch_size, max_len, embed_size).
        encoding: encoder output, shape (batch_size, max_len, embed_size);
            ``embed_size`` is read from ``encoding.shape[2]``.
        layers: number of stacked decoder blocks.
        heads: attention heads per block.
        forward_expansion: feed-forward hidden-width multiplier.
        dropout: dropout probability.
    """

    def __init__(self, input, encoding, layers=6, heads=8, forward_expansion=1.6, dropout=0.4):
        # encoding shape ---> (batch_size, max_len, embed_size)
        # input shape ---> (batch_size, max_len, embed_size) {if n words are generated, in case of NLG, (max_len - n) words are masked}
        # Was super(encoder, self): wrong class, raises TypeError on construction.
        super(decoder_classification, self).__init__()
        self.input = input
        self.encoding = encoding
        self.embed_size = encoding.shape[2]
        self.layers = layers
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.dropout = dropout
        # No self-attention mask for classification; pass None explicitly
        # because decoder_block's constructor takes a mask argument.
        self.dec_blocks = nn.ModuleList([decoder_block(self.embed_size, forward_expansion, dropout, heads, encoding, None) for _ in range(self.layers)])

    def forward(self, input=None):
        """Decode ``input`` (defaults to the tensor given at construction).

        Returns:
            Tensor with the same shape as the decoder input.
        """
        x = self.input if input is None else input
        for block in self.dec_blocks:
            x = block(x, x, x, None)
        return x

In [None]:
class transformer(nn.Module):
    def __init__(self, input_embedding_batch, )