In [None]:
import torch
import torch.nn as nn
import math

# Encoder Block

In [None]:
class InputEmbedding(nn.Module):
  """Token-embedding lookup whose output is scaled by sqrt(d_model).

  Args:
    d_model: size of each embedding vector.
    vocab_size: number of rows in the embedding table.
  """
  def __init__(self, d_model:int, vocab_size:int):
    super().__init__()
    self.d_model = d_model
    self.vocab_size = vocab_size
    self.embedding = nn.Embedding(vocab_size, d_model)
  def forward(self, x):
    """Return the embeddings of the token ids in `x`, multiplied by sqrt(d_model).

    `forward` must accept `self` and the input tensor: `nn.Module.__call__`
    forwards its arguments here, so a zero-argument `forward` fails on every call.
    """
    return self.embedding(x) * math.sqrt(self.d_model)

In [None]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to the input, then applies dropout.

  Even columns of the table hold sin(position * frequency) and odd columns hold
  cos(position * frequency). The table is built once for `seq_len` positions and
  stored as a buffer named 'pe' with shape (1, seq_len, d_model).

  Args:
    d_model: width of each encoding vector (may be even or odd).
    seq_len: number of positions precomputed in the table.
    dropout: dropout probability applied to the sum in `forward`.
  """
  def __init__(self, d_model:int, seq_len:int,dropout:float):
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.dropout = nn.Dropout(dropout)
    #Create a matrix (seq_len,d_model)
    pe = torch.zeros(self.seq_len, self.d_model)
    #Create a vector of (seq_len,1)
    position = torch.arange(0,seq_len,dtype=torch.float).unsqueeze(1)
    #Denominator part: one frequency per even column -> shape (ceil(d_model/2),)
    division = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
    #(seq_len,1) * (ceil(d_model/2),) -> (seq_len,ceil(d_model/2))
    angles = position * division
    pe[:,0::2] = torch.sin(angles)
    #For odd d_model there is one fewer odd column than even columns, so trim
    #the angles to floor(d_model/2) columns; for even d_model this is a no-op.
    pe[:,1::2] = torch.cos(angles[:,:d_model//2])
    pe = pe.unsqueeze(0)
    self.register_buffer('pe',pe)
  def forward(self,x):
    """Add the first x.shape[1] rows of the table to `x` and apply dropout."""
    #'pe' is a buffer, not a parameter, so it never requires grad
    x = x + self.pe[:,:x.shape[1],:]
    return self.dropout(x)

In [None]:
class LayerNorm(nn.Module):
  """Layer normalisation over the last dimension with a scalar scale and shift.

  Computes alpha * (x - mean) / sqrt(var + eps) + beta, where mean and var are
  taken over the last dimension of `x`.

  Args:
    eps: small constant added to the variance for numerical stability.
  """
  def __init__(self,eps:float=10**-6)-> None:
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(1)) #Multiplied
    self.beta = nn.Parameter(torch.zeros(1)) # Added
  def forward(self,x):
    """Normalise `x` along its last dimension and apply the learned scale/shift."""
    mean = x.mean(dim=-1,keepdim=True)
    #Biased (population) variance, matching torch.nn.LayerNorm; the unbiased
    #default would yield NaN when the last dimension has size 1.
    var = x.var(dim=-1,keepdim=True,unbiased=False)
    #The shift parameter is registered as `beta` in __init__ (there is no `bias`).
    return self.alpha*(x-mean)/torch.sqrt(var+self.eps) + self.beta

In [None]:
class FeedForwardBlock(nn.Module):
  """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output feature size.
    d_ff: hidden feature size between the two linear layers.
    dropout: dropout probability applied after the ReLU.
  """
  def __init__(self,d_model:int,d_ff:int,dropout:float):
    super().__init__()
    # Expansion layer (W1, B1)
    self.linear_1 = nn.Linear(in_features=d_model,out_features=d_ff)
    self.dropout = nn.Dropout(p=dropout)
    # Contraction layer (W2, B2)
    self.linear_2 = nn.Linear(in_features=d_ff,out_features=d_model)
  def forward(self,x):
    """Apply the expansion, ReLU, dropout and contraction to `x` in that order."""
    hidden = self.linear_1(x)
    activated = torch.relu(hidden)
    regularised = self.dropout(activated)
    return self.linear_2(regularised)

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects q/k/v with separate linear layers, splits the feature dimension into
  `h` heads of size `d_k = d_model // h`, runs attention per head, merges the
  heads back and applies an output projection.

  Args:
    d_model: feature size of the inputs and outputs.
    h: number of heads; must divide `d_model`.
    dropout: dropout probability applied to the attention weights.
  """
  def __init__(self,d_model:int,h:int,dropout:float) -> None:
    super().__init__()
    self.d_model = d_model
    self.h = h
    assert d_model%h == 0, "d_model is not divisible by h"

    # Per-head feature size
    self.d_k = d_model//h
    # Input projections for query, key and value
    self.w_q = nn.Linear(d_model,d_model)
    self.w_k = nn.Linear(d_model,d_model)
    self.w_v = nn.Linear(d_model,d_model)
    # Output projection applied after the heads are merged
    self.w_o = nn.Linear(d_model,d_model)
    self.dropout = nn.Dropout(dropout)

  @staticmethod
  def attention(query,key,value,mask,dropout:nn.Dropout):
    """Scaled dot-product attention.

    Scores are query @ key^T divided by sqrt of the last dimension of `query`.
    Where `mask == 0` the scores are overwritten with -1e9 before the softmax.
    `mask` and `dropout` may each be None, in which case that step is skipped.

    Returns:
      (weights @ value, weights), where weights are the post-softmax
      (and post-dropout) attention weights.
    """
    scale = math.sqrt(query.shape[-1])
    # (..., seq_q, d_k) x (..., d_k, seq_k) -> (..., seq_q, seq_k)
    weights = torch.matmul(query, key.transpose(-2,-1)) / scale
    if mask is not None:
      weights.masked_fill_(mask==0,-1e9)
    weights = torch.softmax(weights, dim=-1)
    if dropout is not None:
      weights = dropout(weights)
    context = torch.matmul(weights, value)
    return context, weights

  def _split_heads(self,tensor):
    """(batch,seq_len,d_model) -> (batch,h,seq_len,d_k)."""
    per_head = tensor.view(tensor.shape[0],tensor.shape[1],self.h,self.d_k)
    return per_head.transpose(1,2)

  def forward(self,q,k,v,mask):
    """Run multi-head attention and store the weights on `self.attention_scores`."""
    # Project, then split each projection into heads
    heads_q = self._split_heads(self.w_q(q))
    heads_k = self._split_heads(self.w_k(k))
    heads_v = self._split_heads(self.w_v(v))
    # context: (batch,h,seq_len,d_k)
    context, self.attention_scores = MultiHeadAttention.attention(heads_q,heads_k,heads_v,mask,self.dropout)
    # (batch,h,seq_len,d_k) -> (batch,seq_len,h,d_k) -> (batch,seq_len,d_model)
    batch = context.shape[0]
    merged = context.transpose(1,2).contiguous().view(batch,-1,self.h*self.d_k)
    return self.w_o(merged)

In [None]:
class ResidualConnection(nn.Module):
  """Residual wrapper computing x + dropout(sublayer(norm(x))).

  Args:
    dropout: dropout probability applied to the sublayer output.
  """
  def __init__(self,dropout:float) -> None:
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.norm = LayerNorm()
  def forward(self,x,sublayer):
    """Normalise `x`, run `sublayer` on it, apply dropout, and add `x` back."""
    normalised = self.norm(x)
    branch = self.dropout(sublayer(normalised))
    return x + branch

In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention then feed-forward, each in a residual wrapper.

  Args:
    self_attention_block: attention module called as (q, k, v, mask).
    feed_forward_block: feed-forward module called on a single tensor.
    dropout: dropout probability passed to both residual connections.
  """
  def __init__(self,self_attention_block:MultiHeadAttention,feed_forward_block:FeedForwardBlock,dropout:float) -> None:
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.dropout = dropout
    self.residual1 = ResidualConnection(dropout)
    self.residual2 = ResidualConnection(dropout)
  def forward(self, x, mask):
    """Apply the attention sublayer and then the feed-forward sublayer to `x`."""
    def self_attend(normed):
      # Query, key and value are all the same (normalised) tensor
      return self.self_attention_block(normed, normed, normed, mask)

    attended = self.residual1(x, self_attend)
    return self.residual2(attended, self.feed_forward_block)

In [None]:
class Encoder(nn.Module):
  """Applies a list of encoder layers in order, followed by a final LayerNorm.

  Args:
    layers: modules each called as layer(x, mask).
  """
  def __init__(self,layers:nn.ModuleList) -> None:
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()
  def forward(self,x,mask):
    """Thread `x` through every layer with the same `mask`, then normalise."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden,mask)
    normalised = self.norm(hidden)
    return normalised

# Decoder Block


In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention, feed-forward.

  Each sublayer runs inside its own residual connection.

  Args:
    self_attention_block: attention module used with the target mask.
    cross_attention_block: attention module whose keys/values are the encoder output.
    feed_forward_block: feed-forward module called on a single tensor.
    dropout: dropout probability passed to the three residual connections.
  """
  def __init__(self,self_attention_block:MultiHeadAttention,cross_attention_block:MultiHeadAttention,feed_forward_block:FeedForwardBlock,dropout:float) -> None:
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    self.dropout = dropout
    self.residual1 = ResidualConnection(dropout)
    self.residual2 = ResidualConnection(dropout)
    self.residual3 = ResidualConnection(dropout)
  def forward(self,x,encoder_output,src_mask,tgt_mask):
    """Apply the three sublayers to `x` in order and return the result."""
    def self_attend(normed):
      # Query, key and value are all the (normalised) decoder state
      return self.self_attention_block(normed,normed,normed,tgt_mask)

    def cross_attend(normed):
      # Query is the decoder state; keys and values are the encoder output
      return self.cross_attention_block(normed,encoder_output,encoder_output,src_mask)

    after_self = self.residual1(x,self_attend)
    after_cross = self.residual2(after_self,cross_attend)
    return self.residual3(after_cross,self.feed_forward_block)

In [None]:
class Decoder(nn.Module):
  """Applies a list of decoder layers in order, followed by a final LayerNorm.

  Args:
    layers: modules each called as layer(x, encoder_output, src_mask, tgt_mask).
  """
  def __init__(self,layers:nn.ModuleList)-> None:
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()
  def forward(self,x,encoder_output,src_mask,tgt_mask):
    """Thread `x` through every layer with the same encoder output and masks, then normalise."""
    # Iterate the layers stored on the instance; a bare `layers` is not
    # defined inside this method.
    for layer in self.layers:
      x = layer(x,encoder_output,src_mask,tgt_mask)
    return self.norm(x)

In [None]:
class ProjectionLayer(nn.Module):
  """Linear projection to vocabulary size followed by log-softmax.

  Args:
    d_model: input feature size.
    vocab_size: output feature size.
  """
  def __init__(self,d_model:int,vocab_size:int) -> None:
    super().__init__()
    self.proj = nn.Linear(in_features=d_model,out_features=vocab_size)
  def forward(self,x):
    """Map (..., d_model) to (..., vocab_size) log-probabilities over the last dim."""
    logits = self.proj(x)
    return logits.log_softmax(dim=-1)

# Transformer Block

In [None]:
class Transformer(nn.Module):
  """Container tying together the embeddings, positional encodings, encoder,
  decoder and projection layer.

  It exposes three separate steps (`encode`, `decode`, `project`) rather than
  a single `forward`.
  """
  def __init__(self,encoder: Encoder,decoder:Decoder,src_embedding: InputEmbedding, tgt_embedding: InputEmbedding,src_pos: PositionalEncoding,tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
    super().__init__()
    # Core stacks
    self.encoder = encoder
    self.decoder = decoder
    # Source-side and target-side input pipelines
    self.src_embedding = src_embedding
    self.tgt_embedding = tgt_embedding
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    # Output head
    self.projection_layer = projection_layer
  def encode(self,src,src_mask):
    """Embed `src`, add positional encodings, and run the encoder."""
    embedded = self.src_embedding(src)
    positioned = self.src_pos(embedded)
    return self.encoder(positioned,src_mask)
  def decode(self,encoder_output,src_mask,tgt,tgt_mask):
    """Embed `tgt`, add positional encodings, and run the decoder against `encoder_output`."""
    embedded = self.tgt_embedding(tgt)
    positioned = self.tgt_pos(embedded)
    return self.decoder(positioned,encoder_output,src_mask,tgt_mask)
  def project(self,x):
    """Apply the projection layer to `x`."""
    projected = self.projection_layer(x)
    return projected

# The Main Function

In [None]:
def build_transformer(src_vocab_size:int,tgt_vocab_size:int,src_seq_len:int,tgt_seq_len:int,d_model:int=512,N:int=6,h:int=8,dropout:float=0.1,d_ff:int=2048) -> Transformer:
  """Assemble a Transformer and Xavier-initialise its weight matrices.

  Args:
    src_vocab_size: vocabulary size for the source embedding.
    tgt_vocab_size: vocabulary size for the target embedding and projection layer.
    src_seq_len: number of positions in the source positional encoding.
    tgt_seq_len: number of positions in the target positional encoding.
    d_model: feature size used throughout the model.
    N: number of encoder blocks and number of decoder blocks.
    h: number of attention heads in every attention block.
    dropout: dropout probability used in every sub-module.
    d_ff: hidden size of each feed-forward block.

  Returns:
    The constructed Transformer; every parameter with more than one dimension
    has been re-initialised with `nn.init.xavier_uniform_`.
  """
  #Input Embedding
  src_embedding = InputEmbedding(d_model,src_vocab_size)
  tgt_embedding = InputEmbedding(d_model,tgt_vocab_size)
  #Positional Embedding
  src_pos = PositionalEncoding(d_model,src_seq_len,dropout)
  tgt_pos = PositionalEncoding(d_model,tgt_seq_len,dropout)
  #Encoder block
  encoder_blocks = []
  for _ in range(N):
    encoder_self_attention_block = MultiHeadAttention(d_model,h,dropout)
    feed_forward_block = FeedForwardBlock(d_model,d_ff,dropout)
    encoder_blocks.append(EncoderBlock(
        encoder_self_attention_block,
        feed_forward_block,
        dropout))
  #Decoder blocks
  decoder_blocks = []
  for _ in range(N):
    decoder_self_attention_block = MultiHeadAttention(d_model,h,dropout)
    decoder_cross_attention_block = MultiHeadAttention(d_model,h,dropout)
    feed_forward_block = FeedForwardBlock(d_model,d_ff,dropout)
    decoder_blocks.append(DecoderBlock(
        decoder_self_attention_block,
        decoder_cross_attention_block,
        feed_forward_block,
        dropout))
  #Create Encoder and Decoder
  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))
  #Create the projection layer
  projection_layer = ProjectionLayer(d_model,tgt_vocab_size)
  #Create the transformer
  transformer = Transformer(encoder,decoder,src_embedding, tgt_embedding,src_pos,tgt_pos, projection_layer)
  #Initialize the parameters: only tensors with more than one dimension
  #(weight matrices and embedding tables); 1-D parameters keep their defaults
  for p in transformer.parameters():
    if p.dim()>1:
      nn.init.xavier_uniform_(p)

  return transformer