<a href="https://colab.research.google.com/github/omarramy-74/TransformerFromScratch/blob/main/Transformers_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Transformers from scratch**

In [15]:
import torch
import torch.nn as nn
import math
import numpy as np

In [16]:
class input_emb(nn.Module):
  """Token embedding layer whose output is scaled by sqrt(d_model).

  Args:
    d_model: Size of each embedding vector.
    vocab_size: Number of entries in the embedding table.
  """

  def __init__(self,d_model:int,vocab_size:int):
    # nn.Module must be initialised before any submodule can be assigned.
    super().__init__()
    self.d_model=d_model
    self.vocab_size=vocab_size
    self.emb = nn.Embedding(vocab_size,d_model)

  def forward(self,x):
    """Return the embeddings of index tensor ``x`` multiplied by sqrt(d_model)."""
    return self.emb(x)*math.sqrt(self.d_model)

In [17]:
class positional_encoding(nn.Module):
  """Adds fixed sinusoidal position information to a batch of embeddings.

  A table of shape (1, seq, d_model) is precomputed and stored as the
  non-trainable buffer ``mat``; even columns hold sines and odd columns
  hold cosines of ``position * 10000^(-2i/d_model)``.

  Args:
    d_model: Size of the last dimension of the inputs.
    seq: Number of positions to precompute.
    dropout: Dropout probability applied after the addition.
  """

  def __init__(self,d_model:int,seq:int,dropout:float):
    # nn.Module must be initialised before submodules/buffers are registered.
    super().__init__()
    self.d_model=d_model
    self.seq=seq
    self.dropout=nn.Dropout(dropout)
    mat = torch.zeros((seq,d_model))
    pos = torch.arange(0,seq,dtype=torch.float).unsqueeze(1)
    # div already holds the inverse frequencies 10000^(-2i/d_model), so the
    # angle is pos * div (dividing would invert the frequencies).
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    angles = pos*div
    mat[:,0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even column.
    mat[:,1::2] = torch.cos(angles[:,:d_model//2])
    mat = mat.unsqueeze(0)
    self.register_buffer('mat',mat)

  def forward(self,x):
    """Add the first ``x.shape[1]`` rows of the table to ``x`` and apply dropout."""
    # Buffers never require grad, so the slice can be added directly.
    x = x+self.mat[:,:x.shape[1],:]
    return self.dropout(x)

In [18]:
class LayerNorm(nn.Module):
  """Normalises the last dimension and applies a learned scale and shift.

  Args:
    features: Size of the last dimension; ``gamma`` and ``beta`` have this length.
    epsilon: Small constant added to the standard deviation to avoid
      division by zero.
  """

  def __init__(self,features,epsilon = 10**-6):
    super().__init__()
    self.epsilon = epsilon
    self.features = features
    self.gamma = nn.Parameter(torch.ones(features))
    self.beta = nn.Parameter(torch.zeros(features))

  def forward(self,x):
    """Return ``gamma * (x - mean) / (std + epsilon) + beta`` over the last dim."""
    mean = x.mean(-1,keepdim=True)
    std = x.std(-1,keepdim=True)
    # Normalise the argument itself; nothing is cached on the module.
    return self.gamma * ((x - mean)/(std + self.epsilon)) + self.beta

In [19]:
class FF(nn.Module):
  """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: Size of the input and output features.
    dff: Size of the hidden layer.
    dropout: Dropout probability applied after the activation.
  """

  def __init__(self,d_model,dff,dropout):
    super().__init__()
    self.linear1 = nn.Linear(d_model,dff)
    self.linear2 = nn.Linear(dff,d_model)
    self.relu = nn.ReLU()
    # forward() applies dropout, so the layer has to be created here.
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    """Apply the feed-forward block to the last dimension of ``x``."""
    x = self.linear1(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.linear2(x)
    return x

In [20]:
class multiheadatt(nn.Module):
  """Multi-head scaled dot-product attention.

  Args:
    d_model: Feature size of the inputs and of the output.
    h: Number of attention heads; must divide ``d_model``.
    dropout: Dropout probability applied to the attention weights.

  Attributes set by ``forward``:
    att_scores: Attention weights of the most recent call,
      shape (batch, h, query_len, key_len).
  """

  def __init__(self,d_model,h,dropout):
    super().__init__()
    self.d_model = d_model
    self.h = h
    self.dropout = nn.Dropout(dropout)
    assert d_model%h == 0, "d_model must be divisible by h"
    # Per-head feature size. forward() reads d_k; dim is kept as an alias
    # for code that used the original attribute name.
    self.d_k = d_model//h
    self.dim = self.d_k
    self.w_q = nn.Linear(d_model,d_model,bias=False)
    self.w_k = nn.Linear(d_model,d_model,bias=False)
    self.w_v = nn.Linear(d_model,d_model,bias=False)
    self.w_o = nn.Linear(d_model,d_model,bias=False)

  @staticmethod
  def attention(q,k,v,mask,dropout:nn.Dropout):
    """Compute softmax(q k^T / sqrt(d_k)) v.

    Args:
      q, k, v: Tensors whose last two dimensions are (length, d_k).
      mask: Optional tensor broadcastable to the score matrix. Positions
        where ``mask == 0`` are excluded from attention (convention adopted
        here); ``None`` disables masking.
      dropout: Optional dropout module applied to the attention weights.

    Returns:
      Tuple ``(output, attention_weights)``.
    """
    d_k = q.shape[-1]
    att_scores = (q@k.transpose(-2,-1))/math.sqrt(d_k)
    if mask is not None:
      # Use the dtype's most negative finite value so masked positions get
      # zero weight without producing NaN for fully masked rows.
      att_scores = att_scores.masked_fill(mask == 0, torch.finfo(att_scores.dtype).min)
    att_scores = att_scores.softmax(dim=-1)
    if dropout is not None:
      att_scores = dropout(att_scores)
    return (att_scores@v),att_scores

  def forward(self,q,k,v,mask):
    """Project q/k/v, attend per head, merge the heads and project the result."""
    q = self.w_q(q)
    k = self.w_k(k)
    v = self.w_v(v)
    # (batch, len, d_model) -> (batch, h, len, d_k)
    q = q.view(q.shape[0], q.shape[1], self.h, self.d_k).transpose(1, 2)
    k = k.view(k.shape[0], k.shape[1], self.h, self.d_k).transpose(1, 2)
    v = v.view(v.shape[0], v.shape[1], self.h, self.d_k).transpose(1,2)
    x,self.att_scores = multiheadatt.attention(q,k,v,mask,self.dropout)
    # (batch, h, len, d_k) -> (batch, len, h * d_k)
    x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
    return self.w_o(x)

In [21]:
class resuidal_conn(nn.Module):
  """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

  Args:
    features: Size of the last dimension, passed to ``LayerNorm``.
    dropout: Dropout probability applied to the sublayer output.
  """

  def __init__(self,features,dropout):
    # nn.Module must be initialised before submodules can be assigned.
    super().__init__()
    self.features = features
    self.dropout = nn.Dropout(dropout)
    self.norm = LayerNorm(features)

  def forward(self,x,sublayer):
    """Normalise ``x``, run ``sublayer`` on it, apply dropout and add ``x`` back.

    Args:
      x: Input tensor.
      sublayer: Callable taking the normalised tensor and returning a tensor
        that can be added to ``x``.
    """
    return x + self.dropout(sublayer(self.norm(x)))

In [22]:
class Encoder_B(nn.Module):
  """Single encoder block: self-attention followed by a feed-forward block.

  Each of the two sublayers is wrapped in its own ``resuidal_conn``.

  Args:
    features: Feature size passed to the residual wrappers.
    self_attention_block: Attention module called as ``(q, k, v, mask)``.
    feed_forward_block: Module applied to the attention output.
    dropout: Dropout probability for the residual wrappers.
  """

  def __init__(self,features: int, self_attention_block: multiheadatt, feed_forward_block: FF, dropout: float):
    # super must be *called*; ``super.__init__()`` raises TypeError.
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList([resuidal_conn(features, dropout) for _ in range(2)])

  def forward(self,x,src_mask):
    """Run self-attention (query = key = value = x) and then the feed-forward block."""
    x= self.residual_connections[0](x,lambda x: self.self_attention_block(x,x,x,src_mask))
    x= self.residual_connections[1](x,self.feed_forward_block)
    return x

In [23]:
class Encoder(nn.Module):
  """Stack of encoder layers followed by a final ``LayerNorm``.

  Args:
    features: Feature size passed to the final ``LayerNorm``.
    layers: Modules applied in order, each called as ``layer(x, mask)``.
  """

  def __init__(self,features:int,layers:nn.ModuleList):
    # nn.Module must be initialised before submodules can be assigned.
    super().__init__()
    self.features = features
    self.layers = layers
    self.norm = LayerNorm(features)

  def forward(self,x,mask):
    """Pass ``x`` through every layer with ``mask`` and normalise the result."""
    for layer in self.layers:
      x = layer(x,mask)
    return self.norm(x)

In [24]:
class projection(nn.Module):
  """Linear map from ``d_model`` features to ``vocab_size`` outputs.

  Args:
    d_model: Size of the input features.
    vocab_size: Size of the output dimension.
  """

  def __init__(self,d_model,vocab_size):
    # nn.Module must be initialised before submodules can be assigned.
    super().__init__()
    self.linear = nn.Linear(d_model,vocab_size)

  def forward(self,x):
    """Apply the linear layer to the last dimension of ``x``."""
    return self.linear(x)

In [25]:
class DecoderBlock(nn.Module):
  """Single decoder block made of three residual-wrapped sublayers.

  The sublayers run in this order: self-attention over ``x``, cross-attention
  from ``x`` to the encoder output, then the feed-forward block.

  Args:
    features: Feature size passed to each ``resuidal_conn``.
    self_attention_block: Attention module used with query = key = value.
    cross_attention_block: Attention module whose keys and values are the
      encoder output.
    feed_forward_block: Module applied last.
    dropout: Dropout probability for the residual wrappers.
  """

  def __init__(self, features: int, self_attention_block: multiheadatt, cross_attention_block: multiheadatt, feed_forward_block: FF, dropout: float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    wrappers = [resuidal_conn(features, dropout) for _ in range(3)]
    self.residual_connections = nn.ModuleList(wrappers)

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    """Apply self-attention, cross-attention and feed-forward to ``x`` in turn."""

    def self_attend(t):
      # Target-side attention uses the target mask.
      return self.self_attention_block(t, t, t, tgt_mask)

    def cross_attend(t):
      # Keys and values come from the encoder, hence the source mask.
      return self.cross_attention_block(t, encoder_output, encoder_output, src_mask)

    after_self = self.residual_connections[0](x, self_attend)
    after_cross = self.residual_connections[1](after_self, cross_attend)
    return self.residual_connections[2](after_cross, self.feed_forward_block)

In [26]:
class Decoder(nn.Module):
  """Stack of decoder layers followed by a final ``LayerNorm``.

  Args:
    features: Feature size passed to the final ``LayerNorm``.
    layers: Modules applied in order, each called as
      ``layer(x, encoder_output, src_mask, tgt_mask)``.
  """

  def __init__(self, features: int, layers: nn.ModuleList) -> None:
    super().__init__()
    self.layers = layers
    # torch.layer_norm is a functional op and cannot be built from a feature
    # count; use the same LayerNorm module that Encoder uses.
    self.norm = LayerNorm(features)

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    """Pass ``x`` through every layer and normalise the result."""
    for layer in self.layers:
      x = layer(x, encoder_output, src_mask, tgt_mask)
    return self.norm(x)

In [27]:
class transformer(nn.Module):
  """Container wiring embeddings, positional encodings, encoder, decoder and projection.

  The model is driven through ``encode``, ``decode`` and ``project``; no
  ``forward`` method is defined.

  Args:
    encoder: Module called as ``encoder(src, src_mask)``.
    decoder: Module called as ``decoder(tgt, encoder_output, src_mask, tgt_mask)``.
    src_embed: Embedding applied to source inputs.
    tgt_embed: Embedding applied to target inputs.
    src_pos: Positional encoding applied to embedded source inputs.
    tgt_pos: Positional encoding applied to embedded target inputs.
    projection_layer: Module applied by ``project``.
  """

  def __init__(self,encoder:Encoder,decoder:Decoder,src_embed:input_emb,tgt_embed:input_emb,src_pos:positional_encoding,tgt_pos:positional_encoding,projection_layer:projection):
    # nn.Module must be initialised before submodules can be assigned.
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projection_layer = projection_layer

  def encode(self,src,src_mask):
    """Embed ``src``, add positional encoding and run the encoder."""
    src = self.src_embed(src)
    src = self.src_pos(src)
    return self.encoder(src,src_mask)

  def decode(self,encoder_output,src_mask,tgt,tgt_mask):
    """Embed ``tgt``, add positional encoding and run the decoder."""
    tgt = self.tgt_embed(tgt)
    tgt = self.tgt_pos(tgt)
    return self.decoder(tgt,encoder_output,src_mask,tgt_mask)

  def project(self,x):
    """Apply the projection layer to ``x``."""
    return self.projection_layer(x)

In [28]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq: int, tgt_seq: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048):
  """Assemble a ``transformer`` from freshly constructed sub-modules.

  This is a factory function. The original class form could never be
  instantiated, because its ``__init__`` took no ``self`` and returned a value.

  Args:
    src_vocab_size: Vocabulary size for the source embedding.
    tgt_vocab_size: Vocabulary size for the target embedding and projection.
    src_seq: Number of positions precomputed for the source positional encoding.
    tgt_seq: Number of positions precomputed for the target positional encoding.
    d_model: Feature size used throughout the model.
    N: Number of encoder blocks and of decoder blocks.
    h: Number of attention heads.
    dropout: Dropout probability passed to every sub-module that takes one.
    d_ff: Hidden size of each feed-forward block.

  Returns:
    A ``transformer`` instance.
  """
  src_embed = input_emb(d_model,src_vocab_size)
  tgt_embed = input_emb(d_model,tgt_vocab_size)
  src_pos = positional_encoding(d_model,src_seq,dropout)
  tgt_pos = positional_encoding(d_model,tgt_seq,dropout)

  encoder_blocks = []
  for _ in range(N):
    encoder_self_attention_block = multiheadatt(d_model,h,dropout)
    feed_forward_block = FF(d_model,d_ff,dropout)
    # Encoder_B requires the dropout rate as its fourth argument.
    encoder_blocks.append(Encoder_B(d_model,encoder_self_attention_block,feed_forward_block,dropout))

  decoder_blocks = []
  for _ in range(N):
    decoder_self_attention_block = multiheadatt(d_model,h,dropout)
    decoder_cross_attention_block = multiheadatt(d_model,h,dropout)
    feed_forward_block = FF(d_model,d_ff,dropout)
    decoder_blocks.append(DecoderBlock(d_model,decoder_self_attention_block,decoder_cross_attention_block,feed_forward_block,dropout))

  encoder = Encoder(d_model,nn.ModuleList(encoder_blocks))
  decoder = Decoder(d_model,nn.ModuleList(decoder_blocks))
  projection_layer = projection(d_model,tgt_vocab_size)
  return transformer(encoder,decoder,src_embed,tgt_embed,src_pos,tgt_pos,projection_layer)