# Attention is All You Need

In [1]:
#importing libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

import copy

In [2]:
def clone_layer(layer, n):
    """Return an ``nn.ModuleList`` holding ``n`` independent deep copies of ``layer``.

    Each copy owns its own parameters, but because they are deep copies they
    all start out with the same values as ``layer``.
    """
    replicas = nn.ModuleList()
    for _ in range(n):
        replicas.append(copy.deepcopy(layer))
    return replicas

In [3]:
def generate_target_mask(n):
    """Build the causal (look-ahead) mask for a target sequence of length ``n``.

    Returns a float tensor of shape ``(1, n, n)`` where entry ``[0, i, j]`` is
    1 when query position ``i`` may attend to key position ``j`` (``j <= i``)
    and 0 otherwise. The leading singleton dimension lets the mask broadcast
    over the batch.

    The main diagonal is kept: every position must be able to attend to
    itself, otherwise the first row has no visible key at all.
    """
    return torch.tril(torch.ones(1, n, n))

### Attention Mechanism

<img style="float: left;" src="artifacts/attention.png" height=480 width=480>
<img style="float: right;" src="artifacts/scaled dot-product attention.png" height=280 width=280>

In [4]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head with its own Q/K/V projections.

    ``q``, ``k`` and ``v`` are projected from ``d_model`` features to ``d_k``,
    ``d_k`` and ``d_v`` features respectively; the result has ``d_v`` features
    in its last dimension.
    """

    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.d_k = d_k

        # Projections are created in query -> key -> value order.
        self.query = nn.Linear(d_model, d_k)
        self.key = nn.Linear(d_model, d_k)
        self.value = nn.Linear(d_model, d_v)

    def forward(self, q, k, v, mask=None):
        """Return softmax(Q K^T / sqrt(d_k)) V.

        Where ``mask == 0`` the logits are overwritten with -1e9 before the
        softmax, so those key positions receive (numerically) zero weight.
        """
        queries, keys, values = self.query(q), self.key(k), self.value(v)

        scale = self.d_k ** 0.5
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        if mask is not None:
            logits.masked_fill_(mask == 0, -1e9)

        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, values)

<img style="float: left;" src="artifacts/Multi Head attention form.png" height=480 width=480>
<img style="float: right;" src="artifacts/Multi head attention.png" height=280 width=280>

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: ``heads`` parallel ``AttentionHead`` modules whose
    outputs are concatenated and linearly projected back to ``d_model``.

    Args:
        heads: number of attention heads.
        d_model: feature size of the inputs and of the returned tensor.

    Each head works with ``d_k = d_v = d_model // heads`` features.
    """

    def __init__(self, heads, d_model):
        super().__init__()

        self.h = heads
        self.d_model = d_model
        self.d_k = d_model // heads
        self.d_v = d_model // heads

        # Build every head separately so each one gets its own random
        # initialisation; deep-copying a single head would start all heads
        # with identical weights. The value projection uses d_v, not d_k.
        self.attns = nn.ModuleList(
            [AttentionHead(d_model, self.d_k, self.d_v) for _ in range(self.h)]
        )

        self.multi_head_out = nn.Linear(heads * self.d_v, d_model)

    def forward(self, q, k, v, mask=None):
        """Run every head on ``(q, k, v, mask)``, concatenate the results along
        the last dimension and apply the output projection."""
        head_outputs = [attn(q, k, v, mask) for attn in self.attns]  # (head_1, ..., head_h)

        concatenated = torch.cat(head_outputs, dim=-1)               # Concat(heads)

        return self.multi_head_out(concatenated)                     # Linear(Concat(heads))

### Encoder
<img src="artifacts/Encoder.png" height=280 width=280>

In [6]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and a position-wise feed-forward
    network, each followed by dropout, a residual connection and LayerNorm.

    Args:
        d_model: feature size of the input and output.
        attn_heads: number of heads in the self-attention sub-layer.
        d_ff: hidden size of the feed-forward network.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, attn_heads, d_ff, dropout=0.1):
        super(EncoderBlock, self).__init__()

        self.multi_head_attn = MultiHeadAttention(attn_heads, d_model)
        self.layer_norm1 = nn.LayerNorm(d_model)

        self.layer_norm2 = nn.LayerNorm(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the encoder layer to ``x``; ``mask`` is forwarded to the
        self-attention sub-layer."""
        attn_out = self.dropout(self.multi_head_attn(x, x, x, mask))  # Multi-head attention
        add_norm1 = self.layer_norm1(x + attn_out)                    # Add & Norm

        ff_out = self.dropout(self.ffn(add_norm1))                    # Feed forward
        # Each sub-layer has its own LayerNorm: the feed-forward residual
        # must go through layer_norm2, not layer_norm1 again.
        add_norm2 = self.layer_norm2(add_norm1 + ff_out)              # Add & Norm

        return add_norm2

In [7]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` ``EncoderBlock`` layers applied in sequence.

    Args:
        num_layers: number of encoder blocks in the stack.
        d_model, attn_heads, d_ff, dropout: passed to ``EncoderBlock``.
    """

    def __init__(self, num_layers, d_model, attn_heads, d_ff, dropout=0.1):
        super(Encoder, self).__init__()

        self.encoders = clone_layer(EncoderBlock(d_model, attn_heads, d_ff, dropout), num_layers)

    def forward(self, x, mask=None):
        """Feed ``x`` through every block in order and return the final output.

        ``mask`` is optional (defaults to ``None``, i.e. no masking), matching
        ``EncoderBlock.forward``; the same mask is given to every block.
        """
        for layer in self.encoders:
            x = layer(x, mask)
        return x

### Decoder
<img src="artifacts/Decoder.png" height=280 width=280>

In [8]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention and
    a position-wise feed-forward network, each followed by dropout, a residual
    connection and its own LayerNorm.

    Args:
        d_model: feature size of the input and output.
        attn_heads: number of heads in both attention sub-layers.
        d_ff: hidden size of the feed-forward network.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, attn_heads, d_ff, dropout=0.1):
        super(DecoderBlock, self).__init__()

        self.masked_multi_head_attn = MultiHeadAttention(attn_heads, d_model)
        self.multi_head_attn = MultiHeadAttention(attn_heads, d_model)

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x, encd_out, mask=None):
        """Apply the decoder layer.

        Args:
            x: decoder-side input.
            encd_out: encoder output, used as keys and values of the
                encoder-decoder attention (which is called without a mask).
            mask: optional mask for the decoder self-attention.
        """
        masked_attn_out = self.dropout1(self.masked_multi_head_attn(x, x, x, mask))    # Masked multi-head attention
        add_norm1 = self.layer_norm1(x + masked_attn_out)                              # Add & Norm

        attn_out = self.dropout2(self.multi_head_attn(add_norm1, encd_out, encd_out))  # Encoder-decoder attention
        add_norm2 = self.layer_norm2(add_norm1 + attn_out)                             # Add & Norm

        ff_out = self.dropout3(self.ffn(add_norm2))                                    # Feed forward
        # The feed-forward residual gets its own LayerNorm (layer_norm3);
        # reusing layer_norm1 here would leave layer_norm3 untrained.
        add_norm3 = self.layer_norm3(add_norm2 + ff_out)                               # Add & Norm

        return add_norm3

In [9]:
class Decoder(nn.Module):
    """A stack of ``num_layers`` ``DecoderBlock`` layers run one after another."""

    def __init__(self, num_layers, d_model, attn_heads, d_ff, dropout=0.1):
        super().__init__()
        template = DecoderBlock(d_model, attn_heads, d_ff, dropout)
        self.decoders = clone_layer(template, num_layers)

    def forward(self, x, encdr_out, mask=None):
        """Pass ``x`` through each decoder block in order; every block receives
        the same encoder output and the same mask."""
        hidden = x
        for block in self.decoders:
            hidden = block(hidden, encdr_out, mask)
        return hidden

### Embeddings

<img style="float: left;" src="artifacts/positional encoding.png" height=380 width=380>
<img style="float: right;" src="artifacts/Encoding.png" height=380 width=380>

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    For position ``pos`` and feature pair index ``i``:

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size of the inputs (even or odd).
        dropout: dropout probability applied to the summed output.
        max_len: number of positions for which the table is precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5_000):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)

        # 1 / 10000^(2i / d_model), evaluated in log space.
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        angles = position * inv_freq  # (max_len, ceil(d_model / 2))

        pos_encd = torch.zeros(max_len, d_model)
        pos_encd[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pos_encd[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer is stored in the state dict and moves with the module across
        # devices, but is not a trainable parameter.
        self.register_buffer("pos_encd", pos_encd.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``dropout(x + PE[:, :x.shape[1]])``; dimension 1 of ``x`` is
        treated as the position axis."""
        return self.dropout(x + self.pos_encd[:, : x.shape[1]])

In [11]:
class InputEmbedding(nn.Module):
    """Token embedding lookup scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: size of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super(InputEmbedding, self).__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the embeddings for token ids ``x`` and scale them.

        The embeddings are multiplied (not divided) by ``sqrt(d_model)``, as
        in the original Transformer, so they are not dwarfed by the
        positional encodings added afterwards.
        """
        return self.embed(x) * math.sqrt(self.d_model)

### Complete Transformer
<img src="artifacts/Transformer.png" height=380 width=380>

In [12]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on already-embedded sequences.

    ``src`` and ``trgt`` are summed directly with positional encodings of
    width ``d_model``, so their last dimension must be ``d_model``. The output
    holds log-probabilities over ``trgt_vocab`` classes in its last dimension.
    """

    def __init__(self, d_model, trgt_vocab, nheads=8,num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Encoder and decoder stacks.
        self.encoders = Encoder(num_encoder_layers, d_model, nheads, d_ff, dropout)
        self.decoders = Decoder(num_decoder_layers, d_model, nheads, d_ff, dropout)

        # Separate positional-encoding modules for the two input streams.
        self.encoder_pos = PositionalEncoding(d_model, dropout)
        self.decoder_pos = PositionalEncoding(d_model, dropout)

        # Final projection from model features to target-vocabulary logits.
        self.out = nn.Linear(d_model, trgt_vocab)

    def forward(self, src, trgt, trgt_mask, src_mask=None):
        """Encode ``src``, decode ``trgt`` against it, and return the
        log-softmax of the projected decoder output.

        ``src_mask`` goes to the encoder stack and ``trgt_mask`` to the
        decoder stack.
        """
        positioned_src = self.encoder_pos(src)
        positioned_trgt = self.decoder_pos(trgt)

        memory = self.encoders(positioned_src, src_mask)
        decoded = self.decoders(positioned_trgt, memory, trgt_mask)

        logits = self.out(decoded)
        return F.log_softmax(logits, dim=-1)

In [13]:
# Feature width shared by the demo inputs and the model below; 16 divides
# evenly across the default 8 attention heads (2 features per head).
d_model = 16

In [14]:
# Random tensors standing in for already-embedded sequences: 5 samples each,
# with 10 source positions and 12 target positions, d_model features wide.
src = torch.randn(5, 10, d_model)
trgt = torch.randn(5, 12, d_model)
# Mask for the decoder self-attention, sized to the 12 target positions.
trgt_mask = generate_target_mask(12)

In [15]:
# Target vocabulary of 20 classes; every other hyperparameter keeps the
# defaults declared on Transformer.__init__.
trnsfrmr = Transformer(d_model, 20)

In [16]:
# Forward pass without a source mask (src_mask defaults to None).
out = trnsfrmr(src, trgt, trgt_mask)
# Last expression of the cell, so the notebook displays the output shape.
out.shape

Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size([5, 10, 10])
value shape:  torch.Size([5, 10, 2])
Score shape:  torch.Size(

torch.Size([5, 12, 20])

<b>Note</b>: the `source` and `target` embedding sizes must both be equal to `d_model`. These embeddings can be trained using `InputEmbedding`, or we can use pretrained embeddings.

Reference: [Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)