<a href="https://colab.research.google.com/github/efealioksuz/Transformer-Base-Model/blob/main/BasicTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import math
import copy

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are projected with learned linear maps, split into `n_heads`
  chunks of width `d_model // n_heads`, attended to independently, then
  concatenated and projected back to `d_model`.

  Args:
    d_model: Width of the last dimension of query/key/value.
    n_heads: Number of attention heads; must divide `d_model` evenly.

  Raises:
    ValueError: If `d_model` is not divisible by `n_heads`.
  """

  def __init__(self, d_model, n_heads=8):
    super(MultiHeadAttention, self).__init__()

    # Without this check the reshape in forward() could silently infer a
    # wrong sequence length through its -1 dimension.
    if d_model % n_heads != 0:
      raise ValueError(
          "d_model ({}) must be divisible by n_heads ({})".format(d_model, n_heads))

    self.n_heads = n_heads
    self.d_key = d_model // n_heads

    # Each projection is its own freshly initialised layer. Deep-copying a
    # single layer would start Q, K, V and the output projection with the
    # exact same weights.
    self.W_Q = nn.Linear(d_model, d_model)
    self.W_K = nn.Linear(d_model, d_model)
    self.W_V = nn.Linear(d_model, d_model)
    self.W_O = nn.Linear(d_model, d_model)

  def Attention(self, query, key, value, mask=None):
    """Return softmax(Q K^T / sqrt(d)) V over the last two dimensions.

    Args:
      query, key, value: Tensors whose last two dims are (length, width).
      mask: Optional tensor broadcastable to the score matrix. Positions
        where `mask == 0` are excluded from attention.
    """
    d = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d)
    if mask is not None:
      # -inf scores become zero weight after the softmax.
      scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = scores.softmax(-1)
    return torch.matmul(weights, value)

  def _split_heads(self, x, batch_size):
    """Reshape (batch, length, d_model) -> (batch, n_heads, length, d_key)."""
    return x.reshape(batch_size, -1, self.n_heads, self.d_key).transpose(1, 2)

  def forward(self, query, key, value, mask=None):
    """Apply multi-head attention.

    Args:
      query, key, value: Tensors with the batch in dim 0 and `d_model`
        features in the last dim.
      mask: Optional mask; a head dimension is inserted at dim 1 before it
        is applied, so it must broadcast against
        (batch, n_heads, query_len, key_len) afterwards.

    Returns:
      Tensor of shape (batch, query_len, n_heads * d_key).
    """
    batch_size = query.shape[0]

    parallel_queries = self._split_heads(self.W_Q(query), batch_size)
    parallel_keys = self._split_heads(self.W_K(key), batch_size)
    parallel_values = self._split_heads(self.W_V(value), batch_size)

    if mask is not None:
      # Add a singleton head axis so one mask is shared by every head.
      mask = mask.unsqueeze(1)
    x = self.Attention(parallel_queries, parallel_keys, parallel_values, mask)

    # Undo the head split: (batch, heads, len, d_key) -> (batch, len, d_model).
    x = x.transpose(1, 2)
    x = x.reshape(batch_size, -1, self.n_heads * self.d_key)

    return self.W_O(x)

In [None]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and a feed-forward network.

  Each sub-layer is followed by dropout, a residual connection and layer
  normalisation. All learnable sub-modules are created once in the
  constructor so that their weights are registered as parameters, persist
  between calls, and follow `train()` / `eval()`.

  Args:
    d_model: Feature width of the input and output.
    attention: Callable invoked as `attention(X, X, X)`.
    d_FFN: Hidden width of the feed-forward network.
    dropout: Dropout probability applied to each sub-layer output.
  """

  def __init__(self, d_model, attention, d_FFN=2048, dropout=0.1):
    super(EncoderLayer, self).__init__()
    self.d_model = d_model
    self.MultiHeadAttention = attention
    self.d_FFN = d_FFN
    # Kept as the plain probability for backward compatibility; the actual
    # dropout modules live in the *_dropout attributes below.
    self.dropout = dropout

    self.attention_dropout = nn.Dropout(dropout)
    self.attention_norm = nn.LayerNorm(d_model)

    self.ffn = nn.Sequential(
        nn.Linear(d_model, d_FFN),
        nn.ReLU(),
        nn.Linear(d_FFN, d_model),
    )
    self.ffn_dropout = nn.Dropout(dropout)
    self.ffn_norm = nn.LayerNorm(d_model)

  def AttentionLayer(self, X):
    """Self-attention sub-layer: LayerNorm(X + Dropout(Attention(X, X, X)))."""
    attention = self.MultiHeadAttention(X, X, X)
    attention = self.attention_dropout(attention)
    return self.attention_norm(X + attention)

  def FeedForwardLayer(self, X):
    """Feed-forward sub-layer: LayerNorm(X + Dropout(FFN(X)))."""
    output = self.ffn_dropout(self.ffn(X))
    return self.ffn_norm(X + output)

  def forward(self, X):
    """Run the attention sub-layer followed by the feed-forward sub-layer."""
    X = self.AttentionLayer(X)
    X = self.FeedForwardLayer(X)
    return X

In [None]:
class Encoder(nn.Module):
  """Stack of `N` independent deep copies of a template layer.

  Args:
    layer: Template module; it is stored as `encoder_layer` and copied
      `N` times into `encoder_layers`.
    N: Number of copies to stack.
  """

  def __init__(self, layer, N=6):
    super().__init__()
    self.encoder_layer = layer

    clones = []
    for _ in range(N):
      clones.append(copy.deepcopy(layer))
    self.encoder_layers = nn.ModuleList(clones)

  def forward(self, X):
    """Feed `X` through every stacked layer in order and return the result."""
    hidden = X
    for block in self.encoder_layers:
      hidden = block(hidden)
    return hidden

In [None]:
class DecoderLayer(nn.Module):
  """One decoder block: masked self-attention, cross-attention, feed-forward.

  Each sub-layer is followed by dropout, a residual connection and layer
  normalisation. All learnable sub-modules are created once in the
  constructor so that their weights are registered as parameters, persist
  between calls, and follow `train()` / `eval()`.

  Args:
    d_model: Feature width of the input and output.
    attention: Callable invoked as `attention(query, key, value[, mask])`,
      where positions with `mask == 0` are hidden. It is used for the masked
      self-attention; a deep copy of it is used for cross-attention so the
      two sub-layers do not share weights.
    d_FFN: Hidden width of the feed-forward network.
    dropout: Dropout probability applied to each sub-layer output.
  """

  def __init__(self, d_model, attention, d_FFN=2048, dropout=0.1):
    super(DecoderLayer, self).__init__()
    self.d_model = d_model
    self.MultiHeadAttention = attention
    self.cross_attention = copy.deepcopy(attention)
    self.d_FFN = d_FFN
    # Kept as the plain probability for backward compatibility; the actual
    # dropout modules live in the *_dropout attributes below.
    self.dropout = dropout

    self.self_attention_dropout = nn.Dropout(dropout)
    self.self_attention_norm = nn.LayerNorm(d_model)

    self.cross_attention_dropout = nn.Dropout(dropout)
    self.cross_attention_norm = nn.LayerNorm(d_model)

    self.ffn = nn.Sequential(
        nn.Linear(d_model, d_FFN),
        nn.ReLU(),
        nn.Linear(d_FFN, d_model),
    )
    self.ffn_dropout = nn.Dropout(dropout)
    self.ffn_norm = nn.LayerNorm(d_model)

  def MaskedAttentionLayer(self, X):
    """Causal self-attention sub-layer.

    Builds a lower-triangular mask of ones over `X.shape[1]` positions so
    that position t may attend to positions <= t only (zeros mark the
    hidden, future positions).
    """
    d_stc = X.shape[1]
    # Ones on and below the diagonal = visible; created on X's device so the
    # layer also works off-CPU.
    mask = torch.tril(torch.ones([1, d_stc, d_stc], device=X.device))
    masked_attention = self.MultiHeadAttention(X, X, X, mask)
    masked_attention = self.self_attention_dropout(masked_attention)
    return self.self_attention_norm(X + masked_attention)

  def AttentionLayer(self, X, M):
    """Cross-attention sub-layer: queries from `X`, keys and values from `M`."""
    attention = self.cross_attention(X, M, M)
    attention = self.cross_attention_dropout(attention)
    return self.cross_attention_norm(X + attention)

  def FeedForwardLayer(self, X):
    """Feed-forward sub-layer: LayerNorm(X + Dropout(FFN(X)))."""
    output = self.ffn_dropout(self.ffn(X))
    return self.ffn_norm(X + output)

  def forward(self, X, M):
    """Run masked self-attention, cross-attention over `M`, then the FFN."""
    X = self.MaskedAttentionLayer(X)
    X = self.AttentionLayer(X, M)
    X = self.FeedForwardLayer(X)
    return X

In [None]:
class Decoder(nn.Module):
  """Stack of `N` independent deep copies of a template layer.

  Args:
    layer: Template module; it is stored as `decoder_layer` and copied
      `N` times into `decoder_layers`.
    N: Number of copies to stack.
  """

  def __init__(self, layer, N=6):
    super().__init__()
    self.decoder_layer = layer

    clones = []
    for _ in range(N):
      clones.append(copy.deepcopy(layer))
    self.decoder_layers = nn.ModuleList(clones)

  def forward(self, X, M):
    """Feed `X` through every stacked layer in order, passing `M` to each."""
    hidden = X
    for block in self.decoder_layers:
      hidden = block(hidden, M)
    return hidden

In [None]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position information to a sequence of vectors.

  For position `pos` and feature index `i` the angle is
  `pos / 10000 ** (2 * (i // 2) / d_model)`; even feature indices receive
  its sine and odd indices its cosine.

  Args:
    d_model: Feature width of the inputs (size of their last dimension).
  """

  def __init__(self, d_model):
    super(PositionalEncoding, self).__init__()
    self.d_model = d_model

  def forward(self, X):
    """Return `X` plus the positional encoding.

    The sequence length is read from `X.shape[-2]`; the encoding is built
    on `X`'s device and broadcast over a leading batch dimension.
    """
    d_sentence = X.shape[-2]

    # Build the index grids on X's device so the addition below does not
    # mix devices.
    i = torch.arange(self.d_model, device=X.device).unsqueeze(0)
    pos = torch.arange(d_sentence, device=X.device).unsqueeze(1)

    # Feature pairs (2k, 2k+1) share the same frequency.
    angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / self.d_model)
    angles = pos * angle_rates

    angles[:, 0::2] = torch.sin(angles[:, 0::2])
    angles[:, 1::2] = torch.cos(angles[:, 1::2])

    # The encoding is a constant, not a learnable quantity.
    pos_encoding = angles.unsqueeze(0).requires_grad_(False)
    return X + pos_encoding

In [None]:
class Embeddings(nn.Module):
  """Embedding lookup whose output is scaled by sqrt(d_model).

  Args:
    d_model: Width of each embedding vector.
    vocab_size: Number of rows in the embedding table.
  """

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.d_model = d_model
    self.convert = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    """Look up the indices in `x` and multiply the vectors by sqrt(d_model)."""
    scale = math.sqrt(self.d_model)
    vectors = self.convert(x)
    return vectors * scale

In [None]:
class Generator(nn.Module):
  """Project feature vectors to vocabulary log-probabilities.

  Args:
    d_model: Feature width of the input.
    vocab_size: Number of output classes.
  """

  def __init__(self, d_model, vocab_size):
    super(Generator, self).__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    """Return log-softmax over the last dimension of the projected input."""
    # nn.LogSoftmax is a module class whose constructor takes only `dim`;
    # passing the tensor to it raises TypeError, so use the functional form.
    return torch.log_softmax(self.proj(x), dim=-1)

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder model assembled from the building blocks in this file.

  Args:
    d_model: Feature width shared by every component.
    input_vocab: Vocabulary size passed to the source-side embeddings.
    output_vocab: Vocabulary size passed to the target-side embeddings and
      to the generator.
  """

  def __init__(self, d_model, input_vocab, output_vocab):
    super().__init__()
    self.d_model = d_model
    self.input_vocab = input_vocab
    self.output_vocab = output_vocab

    # A single attention module serves as the template handed to both the
    # encoder-layer and decoder-layer prototypes.
    self.multiheadAttention = MultiHeadAttention(d_model)

    encoder_prototype = EncoderLayer(d_model, self.multiheadAttention)
    self.encoder = Encoder(encoder_prototype)

    decoder_prototype = DecoderLayer(d_model, self.multiheadAttention)
    self.decoder = Decoder(decoder_prototype)

    self.pos_encod = PositionalEncoding(d_model)
    self.input_embeddings = Embeddings(d_model, input_vocab)
    self.output_embeddings = Embeddings(d_model, output_vocab)
    self.generator = Generator(d_model, output_vocab)

  def forward(self, source, target):
    """Encode `source`, decode `target` against it, and apply the generator.

    Both inputs are embedded and positionally encoded first; the encoder
    output is handed to the decoder as its second argument.
    """
    memory = self.encoder(self.pos_encod(self.input_embeddings(source)))
    decoded = self.decoder(self.pos_encod(self.output_embeddings(target)), memory)
    return self.generator(decoded)