In [7]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import math
import copy

## Introduction

<img src="./transformer.png"/>

## Embedding

Embedding vectors will create a more semantic representation of each word.
Suppoese each embedding vector is of 512 dimension and suppose our vocab size is 100, then our embedding matrix will be of size 100x512. These marix will be learned on training and during inference each word will be mapped to corresponding 512 d vector. Suppose we have batch size of 32 and sequence length of 10(10 words). The the output will be 32x10x512.

In [8]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        """
        Args:
            vocab_size: size of the vocabulary
            emebed_dim: dimension of embedding
        """
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
    
    def forward(self, x):
        """
        Args:
            x: input vector
        Returns:
            out: embedding vector
        """
        out = self.embedding(x)
        return out

## Positional Embedding

Next step is to generate positional encoding. Inorder for the model to make sense of the sentence, it needs to know two things about the each word.

    1. what does the word mean?
    2. what is the position of the word in the sentence.
    
In "attention is all you need paper" author used the following functions to create positional encoding. On odd time steps a cosine function is used and in even time steps a sine function is used.

<img src="./positional_embedding.png" style="width:350px; height:110px"/>

Positinal embedding will generate a matrix of similar to embedding matrix. It will create a matrix of dimension sequence length x embedding dimension. For each token(word) in sequence, we will find the embedding vector which is of dimension 1 x 512 and it is added with the correspondng positional vector which is of dimension 1 x 512 to get 1 x 512 dim out for each word/token.



<img src="./embedding_layer.png" style="width:350px;" />

In [9]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_sequence_len, embed_dim):
        """ 
        Args:
            max_sequence_len: length of input sentence
            embed_dim: dimension of embedding vector
        """
        super(PositionalEmbedding,self).__init__()
        
        self.embed_dim = embed_dim
        self.max_sequence_len = max_sequence_len
        
        pe = torch.zeros(max_sequence_len, self.embed_dim)

        for pos in range(max_sequence_len):
            for i in range(0, self.embed_dim, 2):
                pe[pos, i] = math.sin(pos / 1000 ** ((2 * i) / self.embed_dim))
                pe[pos, i+1] = math.cos(pos / 1000 ** ( (2 * (i+1)) / self.embed_dim ))
                
        pe = pe.unsqueeze(0)
        # If you have parameters in your model, which should be saved and restored in the state_dict,
        # but not trained by the optimizer, you should register them as buffers.
        self.register_buffer('pe', pe)
        
    def forward(self,x):
        """ 
        Args:
            x: input vector
        Returns:
            x: output vector
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return x

## Self Attention

In [10]:
def clones(module, n):
    """ Products n identical layers
    Args:
        module: The layer
        n: count of layers
    Returns
        list of modules
    """
    return nn.ModuleList([copy.deepcopy(module) for i in range(n)])

In [25]:
def attention_fn(query, key, value, mask=None):
    """ compute the scaled dot product attention
    Args:
        query: query vector
        key: key vector
        value: value vector
    """
    single_head_dim = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2,-1)) / math.sqrt(single_head_dim)
   
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
        
    p_attn = scores.softmax(dim=-1)
    
    return torch.matmul(p_attn, value), p_attn 

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=512, n_heads=8):
        """ 
        Args:
            embed_dim: dimension of embedding vector
            n_heads: number of heads in multi-head attention
        """
        
        super(MultiHeadAttention,self).__init__()
        
        self.embed_dim = embed_dim #512
        self.n_heads = n_heads     #8
        self.single_head_dim = int(self.embed_dim / self.n_heads) # 512/8 = 64 each query,key,value vector will be of 64d
        self.linears = clones(nn.Linear(self.embed_dim, self.embed_dim), 4)
        self.attn = None
        
    def forward(self, query, key, value, mask=None):
        
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
            
        n_batches = query.size(0)
        # print("query shape in multihead attention", query.shape)
            
        query, key, value = [
            lin(x).view(n_batches, -1, self.n_heads, self.single_head_dim).transpose(1,2)
            for lin, x in zip(self.linears,(query, key, value))
        ]
        
        x, self.attn = attention_fn(query, key, value, mask)
        
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(n_batches, -1, self.n_heads * self.single_head_dim)
        )
        
        return self.linears[-1](x)

# Encoder

<img src="./encoder.png" style="height:450px; width:280px"/>

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, expansion_factor, n_heads):
        super(TransformerBlock, self).__init__()
        
        """
        Args:
           embed_dim: dimension of the embedding
           expansion_factor: factor which determines output dimension of linear layer
           n_heads: number of attention heads
        """
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, expansion_factor*embed_dim),
            nn.ReLU(),
            nn.Linear(expansion_factor*embed_dim,embed_dim)
        )
        
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)
    
    def forward(self, query, key, value):
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           norm2_out: output of transformer block
        
        """
        
        attention_out = self.attention(key,query,value)  #32x10x512
        attention_residual_out = attention_out + value  #32x10x512
        norm1_out = self.dropout1(self.norm1(attention_residual_out)) #32x10x512

        feed_fwd_out = self.feed_forward(norm1_out) #32x10x512 -> #32x10x2048 -> 32x10x512
        feed_fwd_residual_out = feed_fwd_out + norm1_out #32x10x512
        norm2_out = self.dropout2(self.norm2(feed_fwd_residual_out)) #32x10x512

        return norm2_out

In [14]:
class TransformerEncoder(nn.Module):
    
    """
    Args:
        seq_len : length of input sequence
        embed_dim: dimension of embedding
        num_layers: number of encoder layers
        expansion_factor: factor which determines number of linear layers in feed forward layer
        n_heads: number of heads in multihead attention
        
    Returns:
        out: output of the encoder
    """
    
    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, expansion_factor=4, n_heads=8, ):
        super(TransformerEncoder,self).__init__()
    
        self.embedding_layer = Embedding(vocab_size, embed_dim)
        self.positional_encode = PositionalEmbedding(seq_len, embed_dim)
        
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, expansion_factor, n_heads) for i in range(num_layers)])
    
    def forward(self,x):
        embedding_out = self.embedding_layer(x)
        positional_embed_out = self.positional_encode(embedding_out)
        
        for layer in self.layers:
            encoder_out = layer(positional_embed_out,positional_embed_out,positional_embed_out)

        return encoder_out 

# Decoder

<img src="./decoder.webp" />

In [15]:
class DecoderBlock(nn.Module):
    def __init__(self,embed_dim, expansion_factor=4, n_heads=8):
        super(DecoderBlock, self).__init__()
        
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           mask: mask to be given for multi head attention 
        Returns:
           out: output of transformer block
    
        """
        
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.2)
        self.transformer_block = TransformerBlock(embed_dim, expansion_factor, n_heads)
        
    def forward(self, key, query, x, mask):
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           mask: mask to be given for multi head attention 
        Returns:
           output: output of transformer block
    
        """
        attention_out = self.attention(x, x, x, mask)
        print('here')
        value = self.dropout( self.norm( attention_out + x ) )
        output = self.transformer_block(key, query, value)
    
        return output

In [16]:
class TransformerDecoder(nn.Module):
    def __init__(self,target_vocab_size, embed_dim, seq_len, num_layers=2, expansion_factor=4, n_heads =8):
        super(TransformerDecoder, self).__init__()
        """  
        Args:
           target_vocab_size: vocabulary size of taget
           embed_dim: dimension of embedding
           seq_len : length of input sequence
           num_layers: number of encoder layers
           expansion_factor: factor which determines number of linear layers in feed forward layer
           n_heads: number of heads in multihead attention
        
        """
        self.embedding_layer = Embedding(target_vocab_size, embed_dim)
        self.positional_encode = PositionalEmbedding(seq_len, embed_dim)
        
        self.layers = nn.ModuleList([
            DecoderBlock(embed_dim, expansion_factor, n_heads)
            for i in range(num_layers)
        ])
        
        self.fc_layer = nn.Linear(embed_dim, target_vocab_size)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x, encoder_out, mask):
        """
        Args:
            x: input vector from target
            encoder_out : output from encoder layer
            trg_mask: mask for decoder self attention
        Returns:
            output: output vector
        """
        
        x = self.embedding_layer(x)
        x = self.positional_encode(x)
        x = self.dropout(x)
        
        for layer in self.layers:
            x = layer(encoder_out, x, encoder_out, mask)
        
        output = F.softmax(self.fc_layer(x))
        
        return output

# Transformer

In [26]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, src_vocab_size, target_vocab_size, seq_len, num_layers=2, expansion_factor=4, n_heads=8):
        super(Transformer,self).__init__()
        
        self.target_vocab_size = target_vocab_size
        self.encoder = TransformerEncoder(seq_len,src_vocab_size, embed_dim, num_layers,expansion_factor, n_heads)
        self.decoder = TransformerDecoder(target_vocab_size, embed_dim,seq_len, num_layers,expansion_factor, n_heads)
        
    def make_trg_mask(self, trg):
        """
        Args:
            trg: target sequence
        Returns:
            trg_mask: target mask
        """
        n_batches, trg_len = trg.shape
        
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            n_batches,trg_len, trg_len
        )
        return trg_mask    

    def forward(self, src, trg):
        """
        Args:
            src: input to encoder 
            trg: input to decoder
        out:
            out: final vector which returns probabilities of each target word
        """
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src)
        print(src.shape)

        outputs = self.decoder(trg, enc_out, trg_mask)
        return outputs



## Test Code

In [23]:
src_vocab_size = 1000
target_vocab_size = 1000
num_layers = 6
seq_len= 12
n_batches = 32


# let 0 be sos token and 1 be eos token
src = torch.randint(src_vocab_size, (n_batches, seq_len))
target = torch.randint(target_vocab_size, (n_batches, seq_len))

print(src.shape,target.shape)
model = Transformer(embed_dim=512, src_vocab_size=src_vocab_size, 
                    target_vocab_size=target_vocab_size, seq_len=seq_len,
                    num_layers=num_layers, expansion_factor=4, n_heads=8)
# model

torch.Size([32, 12]) torch.Size([32, 12])


In [24]:
out = model(src, target)
out.shape

p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12])
value torch.Size([32, 8, 12, 64])
p attention torch.Size([32, 8, 12, 12]

  output = F.softmax(self.fc_layer(x))


torch.Size([32, 12, 1000])

In [709]:
def decode(self, src, trg):
        
    """
    Args:
        src: input to encoder 
        trg: input to decoder
    out:
        out_labels : returns final prediction of sequence
    """
    trg_mask = make_trg_mask(trg)
    encoder_output = self.encoder(src)
    out_labels = []
    n_batches, seq_len = src.shape[0], src.shape[1]
    output = trg
    
    for i in range(seq_len):
        output = self.decoder(output, encoder_output, trg_mask)
        output = output[:,-1,:]
        
        output = output.argmax(-1)
        out_labels.append(output.item())
        output = torch.unsqueeze(output,axis=0)
        
        return out_labels
    