## Building a Transformer from scratch
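Here we go: this notebook implements an encoder–decoder Transformer in PyTorch, builds a bilingual translation dataset with word-level tokenizers, and trains the model on it.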

In [29]:
import torch 
import torch.nn as nn
import math
import warnings
from torch.utils.data import random_split, DataLoader, Dataset
# from config import get_config, get_weights_file_path
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from tqdm.auto import tqdm
# from dataset_prep import BilingualDataset, causal_mask
# from mini_transformer import build_transformer


In [30]:

class InputEmbedding(nn.Module):
    """Token-id lookup table whose output is multiplied by sqrt(d_model).

    Args:
        d_model: Width of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Return the embeddings of the ids in ``x`` scaled by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        return self.embedding(x) * scale
    

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is computed once in ``__init__`` and stored as the non-trainable
    buffer ``pe`` with shape ``(1, seq_len, d_model)``: even feature columns
    hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature (last) dimension; may be even or odd.
        seq_len: Number of positions for which encodings are precomputed.
        dropout: Dropout probability applied to the sum in ``forward``.
    """

    def __init__(self, d_model, seq_len, dropout):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angles; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.shape[1]`` position encodings to ``x`` and apply dropout."""
        # Buffers never require grad, so the slice can be added directly.
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
    

In [31]:

class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply a learned scalar scale and shift.

    Args:
        epsilon: Constant added to the standard deviation before dividing.
    """

    def __init__(self, epsilon: float = 10**-6):
        super().__init__()
        self.epsilon = epsilon
        self.alpha = nn.Parameter(torch.ones(1))  # multiplicative gain
        self.bias = nn.Parameter(torch.zeros(1))  # additive offset

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + epsilon) + bias`` over the last dim."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.epsilon
        return self.alpha * centered / spread + self.bias


class FeedForwardLayer(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply linear1 -> ReLU -> dropout -> linear2."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))


In [32]:

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed over ``nh`` parallel heads.

    Args:
        d_model: Feature size of the inputs and output; must be divisible by ``nh``.
        nh: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, nh, dropout) -> None:
        super().__init__()
        self.d_model = d_model
        self.nh = nh
        assert d_model % nh == 0, 'd_model is not divisble by nh'
        self.d_k = d_model // nh  # feature size handled by each head

        # Input projections for query / key / value, then the output projection.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(weights @ value, weights)`` for scaled dot-product attention.

        Scores at positions where ``mask == 0`` are set to -1e9 before the
        softmax. ``mask`` and ``dropout`` may each be ``None`` to skip that step.
        """
        scale = math.sqrt(query.shape[-1])
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)

        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        return torch.matmul(weights, value), weights

    def _split_heads(self, projected):
        """Reshape ``(batch, length, d_model)`` into ``(batch, nh, length, d_k)``."""
        batch, length = projected.shape[0], projected.shape[1]
        return projected.view(batch, length, self.nh, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        """Project q/k/v, attend per head, merge the heads and project the result.

        The attention weights of the latest call are kept on
        ``self.attention_scores``.
        """
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        query, key, value = (self._split_heads(t) for t in (query, key, value))

        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)
        # (batch, nh, length, d_k) -> (batch, length, nh * d_k)
        merged = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.nh * self.d_k)

        return self.w_o(merged)


class ResidualConnection(nn.Module):
    """Pre-norm skip connection: ``x + dropout(sublayer(norm(x)))``.

    Args:
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Normalise ``x``, run ``sublayer`` on it, apply dropout and add ``x`` back."""
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch


In [33]:

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual wrapper.

    Args:
        self_attention_block: Attention module called with the same tensor as q, k and v.
        feed_forward: Position-wise feed-forward module.
        dropout: Dropout probability for the two residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, feed_forward: FeedForwardLayer, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_layer = feed_forward
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, source_mask):
        """Run the self-attention sublayer, then the feed-forward sublayer."""
        attention_residual, feed_forward_residual = self.residual_connections

        def attend(normed):
            # Query, key and value are all the normalised input.
            return self.self_attention_block(normed, normed, normed, source_mask)

        x = attention_residual(x, attend)
        return feed_forward_residual(x, self.feed_forward_layer)

class Encoder(nn.Module):
    """Apply a stack of encoder layers in order, then a final layer normalisation.

    Args:
        layers: Modules called as ``layer(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        """Feed ``x`` through every layer with ``mask`` and normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward.

    Each of the three sublayers runs inside its own residual connection.

    Args:
        self_attention_block: Attention over the decoder input itself.
        cross_attention_block: Attention whose keys and values are the encoder output.
        feed_forward: Position-wise feed-forward module.
        dropout: Dropout probability for the three residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward: FeedForwardLayer, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward = feed_forward
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, source_mask, target_mask):
        """Run the three sublayers in order and return the result."""
        self_residual, cross_residual, feed_forward_residual = self.residual_connections

        def attend_to_self(normed):
            return self.self_attention_block(normed, normed, normed, target_mask)

        def attend_to_encoder(normed):
            # Queries come from the decoder; keys and values from the encoder.
            return self.cross_attention_block(normed, encoder_output, encoder_output, source_mask)

        x = self_residual(x, attend_to_self)
        x = cross_residual(x, attend_to_encoder)
        return feed_forward_residual(x, self.feed_forward)


class Decoder(nn.Module):
    """Apply a stack of decoder layers in order, then a final layer normalisation.

    Args:
        layers: Modules called as
            ``layer(x, encoder_output, source_mask, target_mask)``.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, source_mask, target_mask):
        """Feed ``x`` through every layer and return the normalised result."""
        for layer in self.layers:
            x = layer(x, encoder_output, source_mask, target_mask)

        # The normalisation must run after the loop: returning from inside it
        # would stop after the first layer and skip the rest of the stack.
        return self.norm(x)

        
class ProjectionLayer(nn.Module):
    """Linear map from ``d_model`` features to log-probabilities over the vocabulary.

    Args:
        d_model: Input feature size.
        vocab_size: Number of output classes.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.project = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return the log-softmax over the last dimension of the projected input."""
        logits = self.project(x)
        return logits.log_softmax(dim=-1)


In [34]:
# Transformer class

class Transformer(nn.Module):
    """Container that wires embeddings, position encodings, encoder, decoder and projection.

    It defines no ``forward``; callers use ``encode``, ``decode`` and
    ``project`` separately.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, source_embed: InputEmbedding, target_embed: InputEmbedding, source_pos: PositionalEncoding, target_pos: PositionalEncoding, projection: ProjectionLayer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.source_embed = source_embed
        self.target_embed = target_embed
        self.source_pos = source_pos
        self.target_pos = target_pos
        self.projection = projection

    def encode(self, source, source_mask):
        """Embed ``source``, add position encodings and run the encoder."""
        embedded = self.source_pos(self.source_embed(source))
        return self.encoder(embedded, source_mask)

    def decode(self, encoder_output, source_mask, target, target_mask):
        """Embed ``target``, add position encodings and run the decoder."""
        embedded = self.target_pos(self.target_embed(target))
        return self.decoder(embedded, encoder_output, source_mask, target_mask)

    def project(self, x):
        """Apply the projection layer to ``x``."""
        return self.projection(x)



In [35]:
# Default hyper-parameters of build_transformer: n_blocks=6, nh=8, dropout=0.1, d_ff=2048
def build_transformer(source_vocab_size, target_vocab_size, source_seq_len, target_seq_len, d_model, nh=8, n_blocks=6, d_ff=2048, dropout=0.1):
    """Assemble a ``Transformer`` and Xavier-initialise its weight matrices.

    Args:
        source_vocab_size: Vocabulary size of the source embedding.
        target_vocab_size: Vocabulary size of the target embedding and projection.
        source_seq_len: Number of positions precomputed for the source side.
        target_seq_len: Number of positions precomputed for the target side.
        d_model: Feature size used throughout the model.
        nh: Number of attention heads.
        n_blocks: Number of encoder blocks and of decoder blocks.
        d_ff: Hidden size of the feed-forward layers.
        dropout: Dropout probability used by every sub-module.

    Returns:
        The constructed ``Transformer``.
    """

    def new_attention():
        return MultiHeadAttention(d_model, nh, dropout)

    def new_feed_forward():
        return FeedForwardLayer(d_model, d_ff, dropout)

    def new_encoder_block():
        attention = new_attention()
        feed_forward = new_feed_forward()
        return EncoderBlock(attention, feed_forward, dropout)

    def new_decoder_block():
        # Cross-attention is constructed before self-attention so that the
        # sequence of random draws during construction stays the same.
        cross_attention = new_attention()
        self_attention = new_attention()
        feed_forward = new_feed_forward()
        return DecoderBlock(self_attention, cross_attention, feed_forward, dropout)

    source_embed = InputEmbedding(d_model, source_vocab_size)
    target_embed = InputEmbedding(d_model, target_vocab_size)

    source_pos = PositionalEncoding(d_model, source_seq_len, dropout)
    target_pos = PositionalEncoding(d_model, target_seq_len, dropout)

    encoder_blocks = [new_encoder_block() for _ in range(n_blocks)]
    decoder_blocks = [new_decoder_block() for _ in range(n_blocks)]

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    projection_layer = ProjectionLayer(d_model, target_vocab_size)

    transformer = Transformer(encoder, decoder, source_embed, target_embed, source_pos, target_pos, projection_layer)

    # Only parameters with more than one dimension are re-initialised.
    for weight in transformer.parameters():
        if weight.dim() > 1:
            nn.init.xavier_uniform_(weight)

    return transformer


In [36]:
# Dataset preparation
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        target_text = src_target_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(target_text).ids

        enc_num_pad_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_pad_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_pad_tokens < 0 or dec_num_pad_tokens < 0:
            raise ValueError('Sentence is too long')

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token, 
                torch.tensor([self.pad_token] * enc_num_pad_tokens, dtype=torch.int64)
            ], dim=0,
        )

        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_pad_tokens, dtype=torch.int64),
            ], dim=0,
        )
        
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_pad_tokens, dtype=torch.int64)
            ], dim=0,
        )
        
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        
        return {
            'encoder_input': encoder_input,
            'decoder_input': decoder_input,
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            'label': label,
            'source_text': src_text,
            'target_text': target_text
        }
        
        
def causal_mask(size):
    """Return a ``(1, size, size)`` boolean mask that is True on and below the diagonal."""
    above_diagonal = torch.ones((1, size, size)).triu(diagonal=1).type(torch.int)
    return above_diagonal.eq(0)

In [37]:
# Config


def get_config():
    """Return a fresh dictionary holding the training configuration."""
    return dict(
        batch_size=8,
        lr=10**-4,
        epochs=25,
        seq_len=350,
        d_model=512,
        lang_src="en",
        lang_tgt="es",
        model_folder="weights",
        model_file="mini_transformer_",
        preload=None,
        tokenizer_file="tokenizer_{0}.json",
        exp_name="runs/mini_transformer",
    )

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``./<model_folder>/<model_file><epoch>.pt`` as a string."""
    checkpoint_name = f"{config['model_file']}{epoch}.pt"
    return str(Path('.') / config['model_folder'] / checkpoint_name)



In [38]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` sentence of every translation record in ``ds``."""
    yield from (record['translation'][lang] for record in ds)
           
def create_tokenizer(config, ds, lang):
    """Load the tokenizer for ``lang`` from disk, training and saving it first if absent.

    The file location is ``config['tokenizer_file']`` formatted with ``lang``.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    # Reuse a previously saved tokenizer when one exists.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [42]:
def get_dataset(config):
    """Load the corpus, build both tokenizers and return the data loaders.

    Reads ``lang_src``, ``lang_tgt``, ``seq_len`` and ``batch_size`` from
    ``config``. The 'opus_books' train split is divided 90/10 into training
    and validation subsets.

    Returns:
        ``(train_loader, val_loader, tokenizer_src, tokenizer_tgt)``.
    """
    lang_src = config['lang_src']
    lang_tgt = config['lang_tgt']

    ds_raw = load_dataset('opus_books', lang_src + '-' + lang_tgt, split='train')
    tokenizer_src = create_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = create_tokenizer(config, ds_raw, lang_tgt)

    train_size = int(0.9 * len(ds_raw))
    val_size = len(ds_raw) - train_size
    train_data, val_data = random_split(ds_raw, [train_size, val_size])

    train_dataset = BilingualDataset(train_data, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
    valid_dataset = BilingualDataset(val_data, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])

    # Report the longest tokenised sentence per language over the whole corpus.
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][lang_src]).ids
        # Target sentences must be measured with the target tokenizer.
        tgt_ids = tokenizer_tgt.encode(item['translation'][lang_tgt]).ids

        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(valid_dataset, batch_size=1, shuffle=True)

    return train_loader, val_loader, tokenizer_src, tokenizer_tgt



In [43]:
# Training loop
def build_model(config, vocab_src_len, vocab_tgt_len):
    """Build the transformer with ``config['seq_len']`` for both sides and ``config['d_model']``."""
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, d_model=config['d_model'])


def training_loop(config):
    """Train the transformer and write one checkpoint per epoch.

    Reads ``model_folder``, ``exp_name``, ``lr``, ``preload`` and ``epochs``
    from ``config`` (plus whatever ``get_dataset`` and ``build_model`` read).
    When ``config['preload']`` is truthy it names the checkpoint to resume
    from: model weights, optimizer state, epoch and global step are restored.

    Args:
        config: Configuration dictionary as returned by ``get_config``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using {device} device')

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_loader, _val_loader, tokenizer_src, tokenizer_tgt = get_dataset(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = build_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    # Tensorboard
    writer = SummaryWriter(config['exp_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    initial_epoch = 0
    step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Loading weights from {model_filename}')
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        state = torch.load(model_filename, map_location=device)
        # Restore the weights too; otherwise training would resume with a
        # freshly initialised model but a stale optimizer state.
        model.load_state_dict(state['model'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer'])
        step = state['step']

    # Labels are target-language ids, so the ignored PAD id must come from the
    # target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    try:
        for epoch in tqdm(range(initial_epoch, config['epochs'])):
            model.train()
            batch_iterator = tqdm(train_loader, desc=f'Epoch {epoch}')

            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)
                decoder_input = batch['decoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)
                decoder_mask = batch['decoder_mask'].to(device)
                label = batch['label'].to(device)

                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                projection_output = model.project(decoder_output)

                # Flatten to (tokens, vocab) vs (tokens,) for the loss.
                loss = loss_fn(projection_output.view(-1, tgt_vocab_size), label.view(-1))
                batch_iterator.set_postfix({'loss': f'{loss.item():.2f}'})

                writer.add_scalar('loss', loss.item(), step)
                writer.flush()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                step += 1

            # Checkpoint once per finished epoch. Resuming starts at
            # state['epoch'] + 1, which is only right for a completed epoch,
            # and saving on every batch would dominate the run time.
            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'step': step,
                'optimizer': optimizer.state_dict()
            }, model_filename)
    finally:
        # Release the event-file handle even if training is interrupted.
        writer.close()

In [None]:
# Driver cell: suppress warnings, build the configuration and start training.
# NOTE(review): filterwarnings("ignore") hides every warning category for the
# rest of the session, deprecation notices included - consider narrowing it.
warnings.filterwarnings("ignore")
# `config` stays bound at notebook level after this cell runs.
config = get_config()
training_loop(config)

Using cuda device
Max length of source sentence: 309
Max length of target sentence: 274


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/3638 [00:00<?, ?it/s]