<a href="https://colab.research.google.com/github/namanraiyani/TransformerFromScratch/blob/main/TransformerFromScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import math
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

Input Embedding

In [8]:
class InputEmbeddings(nn.Module):
  """Look up token embeddings and scale them by sqrt(embedding_dim).

  Args:
    embedding_dim: width of each embedding vector (d_model).
    vocab_size: number of rows in the embedding table.
  """

  def __init__(self, embedding_dim, vocab_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.vocab_size = vocab_size
    self.embedding = nn.Embedding(vocab_size, embedding_dim)

  def forward(self, x):
    # Multiplying by sqrt(d_model) keeps the embedding magnitudes comparable to
    # the positional encodings that are added to them afterwards.
    scale = math.sqrt(self.embedding_dim)
    vectors = self.embedding(x)
    return vectors * scale

Positional Encoding

In [9]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal positional encodings to an embedded sequence, then apply dropout.

  PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim))

  Args:
    embedding_dim: model width (d_model); size of the last axis of the input.
    sequence_len: number of positions the encoding table is precomputed for.
    dropout: dropout probability applied after the encodings are added.
  """

  def __init__(self, embedding_dim, sequence_len, dropout):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.sequence_len = sequence_len
    self.dropout = nn.Dropout(dropout)

    PE = torch.zeros(sequence_len, embedding_dim)
    # Column vector of positions, shape (sequence_len, 1), so that multiplying it by the
    # row of frequencies broadcasts to a full (sequence_len, n_freq) table.
    position = torch.arange(0, sequence_len, dtype = torch.float).unsqueeze(1)

    # 1 / 10000^(2i / embedding_dim), computed in log space; length ceil(embedding_dim / 2).
    denominator_term = torch.exp(torch.arange(0, embedding_dim, step = 2).float() * (-math.log(10000.0) / embedding_dim))
    angles = position * denominator_term

    PE[:, 0::2] = torch.sin(angles)
    # With an odd embedding_dim there is one fewer odd column than there are frequencies.
    PE[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
    PE = PE.unsqueeze(0)  # (1, sequence_len, embedding_dim): leading axis broadcasts over batch

    # Registered as a buffer: part of state_dict and moved by .to(), but not a Parameter.
    self.register_buffer('PE', PE)

  def forward(self, x):
    # assumes x is (batch, seq, embedding_dim) with seq <= sequence_len — TODO confirm at call sites
    x = x + (self.PE[:, :x.shape[1], :]).requires_grad_(False)
    return self.dropout(x)

Multi-Head Attention

In [10]:
class MultiHeadAttentionBlock(nn.Module):
  """Multi-head scaled dot-product attention.

  Queries, keys and values are projected, split into ``h`` heads of width
  ``d_k = embedding_dim // h``, attended independently per head, then concatenated
  and passed through an output projection. The most recent attention weights are
  kept on ``self.attention_score``.

  Args:
    embedding_dim: model width (d_model); must be divisible by ``h``.
    h: number of attention heads.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, embedding_dim, h, dropout):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.h = h

    assert embedding_dim % h == 0, "embedding_dim is not divisible by h"

    self.d_k = embedding_dim // h
    # Built in q, k, v, o order so seeded initialisation and state_dict layout are unchanged.
    self.w_q, self.w_k, self.w_v, self.w_o = (
        nn.Linear(embedding_dim, embedding_dim) for _ in range(4)
    )

    self.dropout = nn.Dropout(dropout)

  @staticmethod
  def attention(query, key, value, mask, dropout):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Entries where ``mask == 0`` are filled with -1e9 before the softmax, so they get
    (near-)zero weight. Returns ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
      scores.masked_fill_(mask == 0, -1e9)
    weights = scores.softmax(dim = -1)
    if dropout is not None:
      weights = dropout(weights)
    return torch.matmul(weights, value), weights

  def _split_heads(self, t):
    """Reshape (batch, seq, h * d_k) into (batch, h, seq, d_k)."""
    return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask):
    query = self._split_heads(self.w_q(q))
    key = self._split_heads(self.w_k(k))
    value = self._split_heads(self.w_v(v))

    out, self.attention_score = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
    # Concatenate heads: (batch, h, seq, d_k) -> (batch, seq, h * d_k).
    merged = out.transpose(1, 2).contiguous().view(out.shape[0], -1, self.h * self.d_k)
    return self.w_o(merged)

Layer Normalization

In [11]:
class LayerNormalization(nn.Module):
  """Normalise over the last dimension, then apply a learned scale and shift.

  ``gamma`` and ``beta`` are single scalars (shape ``(1,)``) shared by all features.
  ``x.std`` uses the unbiased (N - 1) estimator, and ``eps`` is added to the standard
  deviation rather than to the variance. Results therefore differ slightly from
  ``nn.LayerNorm``.

  Args:
    eps: small constant added to the denominator to avoid division by zero.
  """

  def __init__(self, eps = 10**-6):
    super().__init__()
    self.eps = eps

    self.gamma = nn.Parameter(torch.ones(1))   # multiplicative scale
    self.beta = nn.Parameter(torch.zeros(1))   # additive shift

  def forward(self, x):
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    # beta is a bias and must be added. Multiplying by it (beta starts at 0)
    # zeroed every output and cut off the gradient to gamma.
    return (self.gamma * ((x - mean) / (std + self.eps))) + self.beta

Feed-Forward Network

In [12]:
class FeedForwardBlock(nn.Module):
  """Position-wise feed-forward network.

  Computes Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model)
  on the last dimension of the input.

  Args:
    d_model: input and output width.
    d_ff: hidden width.
    dropout: dropout probability applied to the hidden activations.
  """

  def __init__(self, d_model, d_ff, dropout):
    super().__init__()
    # linear_1 is created before linear_2 so seeded initialisation is unchanged.
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.dropout = nn.Dropout(dropout)
    self.linear_2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    hidden = torch.relu(self.linear_1(x))
    hidden = self.dropout(hidden)
    return self.linear_2(hidden)