In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math,copy,re
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import torchtext
import matplotlib.pyplot as plt
warnings.simplefilter("ignore")
print(torch.__version__)

2.0.1


In [2]:
# First, create the embedding layer that maps each token id to a dense vector of size `embedding_dimension`.
class Embeddings(nn.Module):
    """Thin wrapper around ``nn.Embedding`` that turns token ids into vectors.

    Args:
        vocabulary_size: number of rows in the lookup table.
        embedding_dimension: width of each embedding vector.
    """

    def __init__(self, vocabulary_size, embedding_dimension):
        super().__init__()
        # Learnable lookup table; indexed directly by the integer token ids.
        self.embed = nn.Embedding(vocabulary_size, embedding_dimension)

    def forward(self, X):
        """Look up the embedding of every id in ``X`` and return the result."""
        return self.embed(X)

In [3]:
# Second, create the positional encoding using the sinusoidal method.
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to scaled token embeddings.

    For position ``pos`` and even column ``i`` the table holds
    ``sin(pos / 10000 ** (i / embed_dim))``; column ``i + 1`` holds the cosine
    of the same angle, so each sine/cosine pair shares one frequency.

    Args:
        max_seq_len: number of positions to precompute.
        embed_model_dim: width of the embeddings the encoding is added to.
    """

    def __init__(self, max_seq_len, embed_model_dim):
        super(PositionalEmbedding, self).__init__()
        self.embed_dim = embed_model_dim

        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        even_columns = torch.arange(0, self.embed_dim, 2, dtype=torch.float32)
        # 1 / 10000 ** (i / embed_dim), computed in log space for stability.
        inverse_frequency = torch.exp(
            even_columns * (-math.log(10000.0) / self.embed_dim)
        )
        angles = positions * inverse_frequency

        pe = torch.zeros(max_seq_len, self.embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one cosine column fewer than sine
        # columns, so drop the unpaired trailing frequency.
        pe[:, 1::2] = torch.cos(angles[:, : self.embed_dim // 2])

        # Leading axis of size 1 lets the table broadcast over the batch.
        pe = pe.unsqueeze(0)
        # Buffer: follows .to()/.cuda() and is saved in state_dict, not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Scale ``x`` by ``sqrt(embed_dim)`` and add the first ``x.size(1)`` encodings.

        ``x`` is indexed as (batch, sequence, embed_dim): axis 1 selects how
        many positions of the table are used.
        """
        x = x * math.sqrt(self.embed_dim)
        seq_len = x.size(1)
        # Buffers never require grad, so no Variable wrapper is needed.
        return x + self.pe[:, :seq_len]

In [4]:
# Third, create the multi-head attention (MHA) layer.
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last axis of each input is split into ``num_heads`` slices of
    ``single_head_dimension`` features. The same bias-free linear projection
    (one each for keys, queries and values) is applied to every slice, and the
    per-head results are concatenated and passed through ``out``.

    Args:
        embedding_dimension: width of the inputs and of the output.
        num_heads: number of attention heads; must divide ``embedding_dimension``.

    Raises:
        ValueError: if ``embedding_dimension`` is not divisible by ``num_heads``.
    """

    def __init__(self, embedding_dimension=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        if embedding_dimension % num_heads != 0:
            # Otherwise the reshape in forward() cannot split the embedding.
            raise ValueError(
                "embedding_dimension must be divisible by num_heads"
            )
        self.embedding_dimension = embedding_dimension
        self.num_heads = num_heads
        self.single_head_dimension = embedding_dimension // num_heads
        self.Query_matrix = nn.Linear(self.single_head_dimension, self.single_head_dimension, bias=False)
        self.Key_matrix = nn.Linear(self.single_head_dimension, self.single_head_dimension, bias=False)
        self.Value_matrix = nn.Linear(self.single_head_dimension, self.single_head_dimension, bias=False)
        self.out = nn.Linear(self.num_heads * self.single_head_dimension, self.embedding_dimension)

    def forward(self, K, Q, V, mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            K: keys, (batch, key_len, embedding_dimension).
            Q: queries, (batch, query_len, embedding_dimension); ``query_len``
                may differ from ``key_len``.
            V: values, reshaped with the key length, so it must have the same
                sequence length as ``K``.
            mask: optional tensor broadcastable to
                (batch, num_heads, query_len, key_len); score entries where
                ``mask == 0`` are replaced by -1e20 before the softmax.

        Returns:
            Tensor of shape (batch, query_len, embedding_dimension).
        """
        batch_size = K.size(0)
        key_length = K.size(1)
        # The query keeps its own length; reusing the key length here would
        # break whenever the two sequences differ in length.
        query_length = Q.size(1)

        # Split the embedding axis into heads: (batch, len, heads, head_dim).
        K = K.reshape(batch_size, key_length, self.num_heads, self.single_head_dimension)
        Q = Q.reshape(batch_size, query_length, self.num_heads, self.single_head_dimension)
        V = V.reshape(batch_size, key_length, self.num_heads, self.single_head_dimension)

        # Project, then move heads in front of the sequence axis:
        # (batch, heads, len, head_dim).
        key = self.Key_matrix(K).transpose(1, 2)
        query = self.Query_matrix(Q).transpose(1, 2)
        value = self.Value_matrix(V).transpose(1, 2)

        # (batch, heads, query_len, head_dim) x (batch, heads, head_dim, key_len)
        dot_product = torch.matmul(query, key.transpose(-1, -2))

        if mask is not None:
            dot_product = dot_product.masked_fill(mask == 0, float("-1e20"))

        dot_product = dot_product / math.sqrt(self.single_head_dimension)
        attention_weights = F.softmax(dot_product, dim=-1)
        # (batch, heads, query_len, key_len) x (batch, heads, key_len, head_dim)
        attended = torch.matmul(attention_weights, value)

        # Merge the heads back: (batch, query_len, heads * head_dim).
        concated_output = attended.transpose(1, 2).contiguous().view(
            batch_size,
            query_length,
            self.single_head_dimension * self.num_heads,
        )

        return self.out(concated_output)

In [5]:
class EncoderBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(layer_norm(sublayer(x) + x))``.

    Args:
        embedding_dimension: width of the inputs and outputs.
        expansion_factor: the hidden feed-forward layer is
            ``expansion_factor * embedding_dimension`` wide.
        num_heads: number of attention heads.
    """

    def __init__(self, embedding_dimension, expansion_factor=4, num_heads=8):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(embedding_dimension, num_heads)

        self.norm1 = nn.LayerNorm(embedding_dimension)
        self.norm2 = nn.LayerNorm(embedding_dimension)

        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dimension, expansion_factor * embedding_dimension),
            nn.ReLU(),
            nn.Linear(expansion_factor * embedding_dimension, embedding_dimension),
        )

        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)

    def forward(self, key, query, value):
        """Run attention and the feed-forward network with residual connections.

        Args:
            key, query, value: tensors handed to the attention layer in that
                order. For self-attention all three are the same tensor.

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention_output = self.attention(key, query, value)
        # The attention output has one row per query position, so the residual
        # must come from the query. For self-attention query and value are the
        # same tensor, so nothing changes there.
        attention_residual_output = attention_output + query
        norm1_output = self.dropout1(self.norm1(attention_residual_output))

        feed_forward_output = self.feed_forward(norm1_output)
        feed_forward_residual_output = feed_forward_output + norm1_output
        norm2_output = self.dropout2(self.norm2(feed_forward_residual_output))

        return norm2_output

In [6]:
class Encoder(nn.Module):
    """Token embedding, positional encoding and a stack of ``EncoderBlock`` layers.

    Args:
        sequence_length: number of positions the positional encoding precomputes.
        vocabulary_size: size of the source vocabulary.
        embedding_dimension: model width.
        num_layers: number of stacked encoder blocks.
        expansion_factor: feed-forward expansion passed to every block.
        num_heads: attention heads passed to every block.
    """

    def __init__(self, sequence_length, vocabulary_size, embedding_dimension, num_layers=2, expansion_factor=4, num_heads=8):
        super(Encoder, self).__init__()

        self.embedding_layer = Embeddings(vocabulary_size, embedding_dimension)
        self.positional_encoder = PositionalEmbedding(sequence_length, embedding_dimension)

        self.layers = nn.ModuleList(
            [EncoderBlock(embedding_dimension, expansion_factor, num_heads) for _ in range(num_layers)]
        )

    def forward(self, X):
        """Encode the token ids ``X`` and return the last block's output.

        With zero layers the positionally-encoded embeddings are returned.
        """
        hidden = self.positional_encoder(self.embedding_layer(X))
        for layer in self.layers:
            # Each block consumes the previous block's output; feeding every
            # block the same input would discard all but the last one.
            hidden = layer(hidden, hidden, hidden)
        return hidden
    

In [7]:
class DecoderBlock(nn.Module):
    """Masked self-attention, then cross-attention and a feed-forward network.

    The cross-attention and feed-forward sub-layers (and their norms and
    dropouts) are owned by the wrapped ``encoder_block``.

    Args:
        embedding_dimension: model width.
        expansion_factor: feed-forward expansion of the wrapped block.
        num_heads: attention heads for both attention layers.
    """

    def __init__(self, embedding_dimension, expansion_factor=4, num_heads=8):
        super(DecoderBlock, self).__init__()

        # Honour the caller's head count instead of a hard-coded 8.
        self.attention = MultiHeadAttention(embedding_dimension, num_heads=num_heads)
        self.norm = nn.LayerNorm(embedding_dimension)
        self.dropout = nn.Dropout(0.2)
        self.encoder_block = EncoderBlock(embedding_dimension, expansion_factor, num_heads)

    def forward(self, key, query, value, mask):
        """Run one decoder layer.

        Args:
            key: keys for the cross-attention step.
            query: the decoder stream; it attends over itself under ``mask``
                and then queries ``key``/``value``.
            value: values for the cross-attention step.
            mask: mask for the self-attention step; entries equal to 0 are
                hidden from attention.

        Returns:
            Tensor with the same shape as ``query``.
        """
        # Masked self-attention: the decoder stream is key, query and value.
        self_attention = self.attention(query, query, query, mask=mask)
        query = self.dropout(self.norm(self_attention + query))

        # Apply the wrapped block's sub-layers directly so that both residual
        # connections follow the decoder stream (the attention output has one
        # row per query position).
        block = self.encoder_block
        cross_attention = block.attention(key, query, value)
        hidden = block.dropout1(block.norm1(cross_attention + query))

        feed_forward_output = block.feed_forward(hidden)
        output = block.dropout2(block.norm2(feed_forward_output + hidden))

        return output

In [8]:
class Decoder(nn.Module):
    """Target embedding, positional encoding, decoder blocks and output projection.

    Args:
        target_vocab_size: size of the target vocabulary (and of the output
            distribution).
        embedding_dimension: model width.
        sequence_length: number of positions the positional encoding precomputes.
        num_layers: number of stacked decoder blocks.
        expansion_factor: feed-forward expansion passed to every block.
        num_heads: attention heads passed to every block.
    """

    def __init__(self, target_vocab_size, embedding_dimension, sequence_length, num_layers=2, expansion_factor=4, num_heads=8):
        super(Decoder, self).__init__()

        self.word_embedding = nn.Embedding(target_vocab_size, embedding_dimension)
        # Attribute name kept as-is (including its spelling) so that existing
        # state_dict keys continue to match.
        self.positonal_embedding = PositionalEmbedding(sequence_length, embedding_dimension)

        # Forward the constructor arguments instead of hard-coded 4 and 8.
        self.layers = nn.ModuleList([
            DecoderBlock(embedding_dimension, expansion_factor=expansion_factor, num_heads=num_heads)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(embedding_dimension, target_vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, X, encoder_output, mask):
        """Return per-position probabilities over the target vocabulary.

        Args:
            X: target token ids.
            encoder_output: passed to every block as both key and value.
            mask: passed to every block's self-attention step.

        Returns:
            Probabilities normalised over the last (vocabulary) axis.
        """
        x = self.word_embedding(X)
        x = self.positonal_embedding(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(encoder_output, x, encoder_output, mask)

        # dim must be explicit: without it softmax on a 3-D tensor normalises
        # over axis 0, i.e. across the batch instead of the vocabulary.
        output = F.softmax(self.fc_out(x), dim=-1)

        return output
        

In [9]:
class Transformer(nn.Module):
    """Encoder-decoder model built from ``Encoder`` and ``Decoder``.

    Args:
        embedding_dimension: model width shared by encoder and decoder.
        source_vocab_size: vocabulary size given to the encoder.
        target_vocab_size: vocabulary size given to the decoder.
        sequence_length: number of positions both positional encodings precompute.
        num_layers: number of blocks in the encoder and in the decoder.
        expansion_factor: feed-forward expansion of every block.
        num_heads: attention heads of every block.
    """

    def __init__(self, embedding_dimension, source_vocab_size, target_vocab_size,
                 sequence_length, num_layers=2, expansion_factor=4, num_heads=8):
        super(Transformer, self).__init__()

        self.target_vocab_size = target_vocab_size
        # Remembered so decode() can stop before running past the positions
        # the positional encoding was built for.
        self.sequence_length = sequence_length

        self.encoder = Encoder(
            sequence_length=sequence_length,
            vocabulary_size=source_vocab_size,
            embedding_dimension=embedding_dimension,
            num_layers=num_layers,
            expansion_factor=expansion_factor,
            num_heads=num_heads
        )
        self.decoder = Decoder(
            target_vocab_size=target_vocab_size,
            embedding_dimension=embedding_dimension,
            sequence_length=sequence_length,
            num_layers=num_layers,
            expansion_factor=expansion_factor,
            num_heads=num_heads
        )

    def make_trg_mask(self, trg):
        """Build a lower-triangular mask of shape (batch, 1, trg_len, trg_len).

        Entry ``[b, 0, i, j]`` is 1 when ``j <= i`` and 0 otherwise. The mask
        is created on the same device as ``trg``.
        """
        batch_size, trg_len = trg.shape
        trg_mask = torch.tril(
            torch.ones((trg_len, trg_len), device=trg.device)
        ).expand(batch_size, 1, trg_len, trg_len)

        return trg_mask

    def decode(self, source, target):
        """Greedily generate up to ``source.shape[1]`` tokens after ``target``.

        At each step the whole sequence generated so far is fed back to the
        decoder with a mask rebuilt for its current length, and the most
        probable token at the last position is appended.

        Each predicted token is read with ``.item()``, so the batch must
        contain exactly one sequence. Generation stops early once the sequence
        is longer than ``sequence_length``.

        NOTE(review): during generation the decoder input length differs from
        the source length, so the attention layers must accept queries and
        keys of different lengths - confirm before relying on this method.

        Returns:
            List of generated token ids (Python ints), excluding ``target``.
        """
        out_labels = []
        generated = target

        # Inference only: no gradients are needed for greedy search.
        with torch.no_grad():
            encoder_output = self.encoder(source)

            for _ in range(source.shape[1]):
                if generated.size(1) > self.sequence_length:
                    break
                step_mask = self.make_trg_mask(generated)
                probabilities = self.decoder(generated, encoder_output, step_mask)

                next_token = probabilities[:, -1, :].argmax(-1)
                out_labels.append(next_token.item())
                # Append along the sequence axis so later steps see the history.
                generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

        return out_labels

    def forward(self, source, target):
        """Encode ``source`` and return the decoder output for ``target``."""
        trg_mask = self.make_trg_mask(target)
        enc_out = self.encoder(source)

        outputs = self.decoder(target, enc_out, trg_mask)
        return outputs

In [10]:
# Toy configuration: both vocabularies hold 11 token ids (0-10) and every
# sequence below is 12 tokens long.
src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length = 12

# Two source/target sequences of seq_length token ids each.
src = torch.tensor([
    [0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1],
])
target = torch.tensor([
    [0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
    [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1],
])

print(src.shape, target.shape)

model = Transformer(
    embedding_dimension=512,
    source_vocab_size=src_vocab_size,
    target_vocab_size=target_vocab_size,
    sequence_length=seq_length,
    num_layers=num_layers,
    expansion_factor=4,
    num_heads=8,
)
# Bare last expression so the notebook displays the module tree.
model

torch.Size([2, 12]) torch.Size([2, 12])


Transformer(
  (encoder): Encoder(
    (embedding_layer): Embeddings(
      (embed): Embedding(11, 512)
    )
    (positional_encoder): PositionalEmbedding()
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (attention): MultiHeadAttention(
          (Query_matrix): Linear(in_features=64, out_features=64, bias=False)
          (Key_matrix): Linear(in_features=64, out_features=64, bias=False)
          (Value_matrix): Linear(in_features=64, out_features=64, bias=False)
          (out): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2,