<a href="https://colab.research.google.com/github/eisbetterthanpi/transformer/blob/main/transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Attention Is All You Need https://arxiv.org/pdf/1706.03762.pdf
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://www.mihaileric.com/posts/transformers-attention-in-disguise/
# https://jalammar.github.io/illustrated-transformer/
# http://nlp.seas.harvard.edu/2018/04/03/attention.html

# position embedding <-> "vocabulary" size 100 <-> model can accept sentences up to 100 tokens long
# learned positional encoding, warm-up and cool-down steps, label smoothing

In [None]:
# @title setup
# Notebook-only cell: installs the packages and the spaCy language models that the data cell loads.

# https://pytorch.org/tutorials/beginner/translation_transformer.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/c64c91cf87c13c0e83586b8e66e4d74e/translation_transformer.ipynb

# https://github.com/pytorch/data
%pip install portalocker
%pip install torchdata

# Install/upgrade the tokenizer dependencies and download the English and German spaCy pipelines
# that get_tokenizer('spacy', ...) is asked for in the data cell.
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm


In [None]:
# @title data

from torchtext.datasets import multi30k, Multi30k
# modify the URLs for the dataset since the links to the original dataset are broken https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

# Translation direction: German source -> English target.
SRC_LANGUAGE = 'de'
TRG_LANGUAGE = 'en'

from torchtext.data.utils import get_tokenizer
# spaCy-backed tokenizers, one per language.
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


# Indices of the special tokens; they line up with the order of special_symbols below,
# which are placed at the front of each vocabulary (special_first=True).
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, beginning, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

from torchtext.vocab import build_vocab_from_iterator
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))

# Tokenize the training split once per language (element 0 is the source text, element 1 the target text).
de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

# One vocabulary per language; min_freq=1 keeps every token seen in the training split.
de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
# Tokens missing from a vocabulary are mapped to UNK_IDX instead of raising.
de_vocab.set_default_index(UNK_IDX)
en_vocab.set_default_index(UNK_IDX)

import torch

def de_transform(o):
    """Convert a raw German sentence into a 1-D tensor of vocabulary indices.

    The text is tokenized with ``de_tokenizer``, mapped through ``de_vocab``
    and wrapped in ``BOS_IDX`` / ``EOS_IDX``.

    Args:
        o: Raw source sentence (str).

    Returns:
        1-D ``torch.long`` tensor ``[BOS_IDX, *token_ids, EOS_IDX]``.
    """
    token_ids = de_vocab(de_tokenizer(o))
    # Build a single tensor with an explicit integer dtype. torch.tensor([])
    # defaults to a floating dtype, so an empty sentence must not be converted
    # on its own and then concatenated with the integer BOS/EOS tensors.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

def en_transform(o):
    """Convert a raw English sentence into a 1-D tensor of vocabulary indices.

    The text is tokenized with ``en_tokenizer``, mapped through ``en_vocab``
    and wrapped in ``BOS_IDX`` / ``EOS_IDX``.

    Args:
        o: Raw target sentence (str).

    Returns:
        1-D ``torch.long`` tensor ``[BOS_IDX, *token_ids, EOS_IDX]``.
    """
    token_ids = en_vocab(en_tokenizer(o))
    # Build a single tensor with an explicit integer dtype. torch.tensor([])
    # defaults to a floating dtype, so an empty sentence must not be converted
    # on its own and then concatenated with the integer BOS/EOS tensors.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)


from torch.nn.utils.rnn import pad_sequence
# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a list of (source_text, target_text) pairs into two padded index tensors.

    Each string has trailing newlines stripped, is numericalized with
    ``de_transform`` / ``en_transform``, and the resulting sequences are
    right-padded with ``PAD_IDX`` to a common length (batch-first layout).

    Returns:
        Tuple ``(src_batch, trg_batch)`` of padded tensors.
    """
    encoded_src = [de_transform(src_text.rstrip("\n")) for src_text, _ in batch]
    encoded_trg = [en_transform(trg_text.rstrip("\n")) for _, trg_text in batch]
    return (
        pad_sequence(encoded_src, batch_first=True, padding_value=PAD_IDX),
        pad_sequence(encoded_trg, batch_first=True, padding_value=PAD_IDX),
    )


# Fixed seed for PyTorch's global random number generator.
torch.manual_seed(0)

# Fresh dataset objects for the training and validation splits.
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
batch_size = 128 # 128
# collate_fn numericalizes and pads the raw string pairs of every batch.
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

# vocab_transform = {SRC_LANGUAGE:de_vocab, TRG_LANGUAGE:en_vocab}
# text_transform = {SRC_LANGUAGE:de_transform, TRG_LANGUAGE:en_transform}


In [None]:
# @title model
import torch
import torch.nn as nn
import math
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    The table is precomputed for ``max_seq_length`` positions and stored as the
    buffer ``pe`` with shape ``[1, max_seq_length, emb_dim]``. Even columns hold
    sines and odd columns hold cosines.

    Args:
        emb_dim: Size of the embedding dimension (even or odd).
        max_seq_length: Number of positions to precompute.
        dropout: Dropout probability applied to the summed output.
    """

    def __init__(self, emb_dim, max_seq_length=512, dropout=0.1):
        super().__init__()
        self.emb_dim = emb_dim
        self.drop = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_length, emb_dim)
        pos = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))
        angles = pos * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd emb_dim there is one cosine column fewer than sine columns,
        # so only the first emb_dim // 2 frequencies are used here. For an even
        # emb_dim this slice is the whole table.
        pe[:, 1::2] = torch.cos(angles[:, : emb_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``; ``x`` is indexed as ``[batch, seq, emb_dim]``."""
        return self.drop(x + self.pe[:, : x.size(1)])


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, dropout=0):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.lin = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        # Plain Python float: no dependency on a module-level `device`, and the
        # scale needs no moving when the module is sent to another device.
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key`` / ``value``.

        Inputs are indexed as ``[batch, length, d_model]``. ``mask`` (optional)
        must broadcast against the score tensor ``[batch, n_heads, q_len, k_len]``;
        positions where it equals 0 are suppressed.

        Returns:
            Tuple ``(output, attention)``: ``output`` is ``[batch, q_len, d_model]``
            and ``attention`` holds the softmax weights before dropout.
        """
        batch_size = query.shape[0]
        # Project, then split the last dimension into heads: [batch, n_heads, len, head_dim].
        Q = self.q(query).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k(key).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v(value).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        attn = Q @ K.transpose(2, 3) / self.scale
        if mask is not None:
            # A large negative score becomes a zero weight after softmax.
            attn = attn.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(attn, dim=-1)
        x = self.drop(attention) @ V
        # Merge the heads back into a single d_model-wide vector per position.
        x = x.transpose(1, 2).reshape(batch_size, -1, self.d_model)
        x = self.lin(x)
        return x, attention


class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each
    followed by dropout, a residual connection and layer normalization."""

    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order so parameter ordering (and
        # therefore seeded initialization) stays stable.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout=0)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        )

    def forward(self, src, src_mask):
        """Apply the block to ``src``; ``src_mask`` is passed to self-attention."""
        attended, _ = self.self_attn(src, src, src, src_mask)
        src = self.norm1(src + self.drop(attended))
        transformed = self.ff(src)
        return self.norm2(src + self.drop(transformed))

class Encoder(nn.Module):
    """Stack of ``n_layers`` EncoderLayer blocks applied one after another."""

    def __init__(self, d_model, n_layers, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(EncoderLayer(d_model, n_heads, ff_dim, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, src, src_mask):
        """Feed ``src`` through every layer in order, reusing the same ``src_mask``."""
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, attention over the encoder
    output, and a feed-forward network, each followed by dropout, a residual
    connection and layer normalization."""

    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order so parameter ordering (and
        # therefore seeded initialization) stays stable.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout=0)
        self.enc_attn = MultiHeadAttention(d_model, n_heads, dropout=0)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Apply the block to ``trg`` using ``enc_src`` as the attention memory."""
        self_attended, _ = self.self_attn(trg, trg, trg, trg_mask)
        trg = self.norm1(trg + self.drop(self_attended))
        cross_attended, _ = self.enc_attn(trg, enc_src, enc_src, src_mask)
        trg = self.norm2(trg + self.drop(cross_attended))
        transformed = self.ff(trg)
        return self.norm3(trg + self.drop(transformed))

class Decoder(nn.Module):
    """Stack of ``n_layers`` DecoderLayer blocks applied one after another."""

    def __init__(self, d_model, n_layers, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(DecoderLayer(d_model, n_heads, ff_dim, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Feed ``trg`` through every layer in order; all layers share ``enc_src`` and both masks."""
        hidden = trg
        for block in self.layers:
            hidden = block(hidden, enc_src, trg_mask, src_mask)
        return hidden

class Seq2Seq(nn.Module):
    """Encoder-decoder transformer mapping source token ids to target-vocabulary logits.

    Args:
        in_dim: Size of the source vocabulary (rows of the source embedding).
        out_dim: Size of the target vocabulary (target embedding rows and output width).
        d_model: Model width.
        nhead: Attention heads per layer.
        enc_layers: Number of encoder layers.
        dec_layers: Number of decoder layers.
        ff_dim: Hidden width of the feed-forward sublayers.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, in_dim, out_dim, d_model=512, nhead=8, enc_layers=3, dec_layers=3, ff_dim=512, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(d_model, enc_layers, nhead, ff_dim, dropout)
        self.decoder = Decoder(d_model, dec_layers, nhead, ff_dim, dropout)
        self.pos_enc = PositionalEncoder(d_model, dropout=dropout)
        self.src_tok_emb = nn.Embedding(in_dim, d_model)
        self.trg_tok_emb = nn.Embedding(out_dim, d_model)
        self.d_model = d_model
        self.lin = nn.Linear(d_model, out_dim)
        # Re-initialize every weight matrix (anything with more than one dimension)
        # with Xavier-uniform; 1-D parameters keep their default initialization.
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def _embed(self, table, tokens):
        """Look up ``tokens`` in ``table``, scale by sqrt(d_model), add position encodings."""
        return self.pos_enc(table(tokens) * math.sqrt(self.d_model))

    def forward(self, src, trg, src_mask=None, trg_mask=None):
        """Return logits for every position of ``trg`` given the source ``src``."""
        src_emb = self._embed(self.src_tok_emb, src)
        trg_emb = self._embed(self.trg_tok_emb, trg)
        memory = self.encoder(src_emb, src_mask)
        decoded = self.decoder(trg_emb, memory, trg_mask, src_mask)
        return self.lin(decoded)

    def encode(self, src, src_mask=None):
        """Run only the embedding and encoder stack on ``src``."""
        return self.encoder(self._embed(self.src_tok_emb, src), src_mask)

    def decode(self, trg, memory, trg_mask=None, src_mask=None):
        """Run the decoder on ``trg`` against a precomputed encoder output ``memory`` and return logits."""
        decoded = self.decoder(self._embed(self.trg_tok_emb, trg), memory, trg_mask, src_mask)
        return self.lin(decoded)


# Vocabulary sizes: in_dim sizes the source embedding table, out_dim sizes the target embedding table and the output projection.
in_dim = len(de_vocab)
out_dim = len(en_vocab)
# Build the model and move it to the selected device.
model = Seq2Seq(in_dim, out_dim, d_model=512, nhead=8, enc_layers=3, dec_layers=3, ff_dim=512, dropout=0.1).to(device)


In [None]:
# @title mask translate

def make_src_mask(src):
    """Boolean padding mask for the source: True wherever ``src`` is not ``PAD_IDX``.

    Two singleton axes are inserted after the first axis, so a 2-D
    ``[batch, src_len]`` input yields ``[batch, 1, 1, src_len]``. The result is
    placed on ``device``.
    """
    not_pad = src != PAD_IDX
    return not_pad[:, None, None].to(device)

def make_trg_mask(trg):
    """Boolean target mask combining padding and causality.

    A position is True only if the key token is not ``PAD_IDX`` and it does not
    lie after the query position (lower-triangular pattern). For a 2-D
    ``[batch, trg_len]`` input the result broadcasts to
    ``[batch, 1, trg_len, trg_len]`` and lives on ``device``.
    """
    steps = trg.shape[1]
    causal = torch.tril(torch.ones((steps, steps), device=device)).bool()
    not_pad = (trg != PAD_IDX)[:, None, None].to(device)
    return not_pad & causal

def translate(model, src_sentence):
    """Greedily translate one German sentence into English text.

    The source is encoded once, then target tokens are generated one at a time
    by taking the arg-max of the last position until ``EOS_IDX`` is produced or
    ``source_length + 5`` tokens have been generated.

    Args:
        model: Model exposing ``encode(src, src_mask)`` and
            ``decode(trg, memory, trg_mask, src_mask)`` as ``Seq2Seq`` does.
            It is switched to eval mode.
        src_sentence: Raw source sentence (str).

    Returns:
        The generated tokens joined by single spaces, without the leading
        ``<bos>`` and without a trailing ``<eos>`` if one was produced.
    """
    model.eval()
    src = de_transform(src_sentence).view(1, -1).to(device)
    src_mask = make_src_mask(src)
    max_len = src.shape[1] + 5
    trg_indexes = [BOS_IDX]
    with torch.no_grad():
        # The source does not change during decoding, so encode it a single time.
        memory = model.encode(src, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            trg_mask = make_trg_mask(trg_tensor)
            output = model.decode(trg_tensor, memory, trg_mask, src_mask)
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == EOS_IDX:
                break
    generated = trg_indexes[1:]  # drop <bos>
    # Strip the final token only when it really is <eos>; if decoding stopped at
    # the length limit, the last token is regular output and must be kept.
    if generated and generated[-1] == EOS_IDX:
        generated.pop()
    return " ".join(en_vocab.lookup_tokens(generated))

# Reminder: UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 (unknown, pad, beginning, end of sentence)
# Quick check that translate() runs end to end; if the cells are run in order the model is still untrained here.
print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))


In [None]:
# @title train test

def train(model, dataloader, optimizer, loss_fn):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    The decoder input is the target without its last token and the loss is
    computed against the target shifted left by one (teacher forcing).

    Args:
        model: Model called as ``model(src, trg_input, src_mask, trg_mask)``.
        dataloader: Iterable of ``(src, trg)`` batches, indexed as ``[batch, length]``.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Loss taking ``(logits [N, vocab], targets [N])``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0
    n_batches = 0
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device) #trg = [batch size, trg len]
        trg_input = trg[:,:-1]
        src_mask, trg_mask = make_src_mask(src), make_trg_mask(trg_input)
        output = model(src, trg_input, src_mask, trg_mask) #output = [batch size, trg len - 1, output dim]
        optimizer.zero_grad()
        loss = loss_fn(output.reshape(-1, output.shape[-1]), trg[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Count batches while iterating instead of materializing the whole loader a
    # second time just to measure its length.
    return total_loss / n_batches

def test(model, dataloader, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    Inputs and targets are shifted exactly as in training: the decoder sees the
    target without its last token and is scored against the target without its
    first token.

    Args:
        model: Model called as ``model(src, trg_input, src_mask, trg_mask)``;
            it is switched to eval mode.
        dataloader: Iterable of ``(src, trg)`` batches, indexed as ``[batch, length]``.
        loss_fn: Loss taking ``(logits [N, vocab], targets [N])``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.eval()
    epoch_loss = 0
    n_batches = 0
    with torch.no_grad():
        for src, trg in dataloader:
            src, trg = src.to(device), trg.to(device) #trg = [batch size, trg len]
            trg_input = trg[:,:-1]
            src_mask, trg_mask = make_src_mask(src), make_trg_mask(trg_input)
            output = model(src, trg_input, src_mask, trg_mask) #output = [batch size, trg len - 1, output dim]
            loss = loss_fn(output.reshape(-1, output.shape[-1]), trg[:,1:].reshape(-1))
            epoch_loss += loss.item()
            n_batches += 1
    # Count batches while iterating instead of materializing the whole loader a
    # second time just to measure its length.
    return epoch_loss / n_batches


In [None]:
# @title run
import time

# Cross-entropy over target tokens; targets equal to PAD_IDX are excluded via ignore_index.
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) # lr=0.0001

# Each epoch: one training pass, one validation pass, then report both losses,
# the elapsed wall-clock time and a sample translation.
for epoch in range(20):
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn)
    val_loss = test(model, val_loader, loss_fn)
    end_time = time.time()
    print((f"Epoch: {epoch+1}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu .")) # A group of people standing in front of an igloo .


Epoch: 1, Train loss: 5.402, Val loss: 4.186, Epoch time = 41.608s
A group of people are are are are are are in a .
Epoch: 2, Train loss: 3.898, Val loss: 3.545, Epoch time = 41.068s
A group of people are standing in a crowd of people .
Epoch: 3, Train loss: 3.353, Val loss: 3.125, Epoch time = 41.566s
A group of people standing in front of a crowd .
Epoch: 4, Train loss: 2.944, Val loss: 2.830, Epoch time = 40.756s
A group of people standing in front of a building .
Epoch: 5, Train loss: 2.630, Val loss: 2.596, Epoch time = 41.468s
A group of people standing in front of a crowd .
Epoch: 6, Train loss: 2.375, Val loss: 2.429, Epoch time = 41.023s
A group of people standing in front of a house .
Epoch: 7, Train loss: 2.166, Val loss: 2.307, Epoch time = 41.604s
A group of people stand in front of a house .
Epoch: 8, Train loss: 1.984, Val loss: 2.210, Epoch time = 40.876s
A group of people stand in front of an audience .
Epoch: 9, Train loss: 1.834, Val loss: 2.131, Epoch time = 41.496s

In [7]:
# @title inference
# Translate a few German sentences with the current model; the trailing comment on each line
# appears to be the intended English rendering, kept for eyeball comparison with the printed output.
print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu .")) # A group of people stand in front of an igloo .
print(translate(model, "Ein Koch in weißer Uniform bereitet Essen in einer Restaurantküche zu .")) # A chef in a white uniform prepares food in a restaurant kitchen .
print(translate(model, "Zwei junge Mädchen spielen Fußball auf einem Feld. .")) # Two young girls play soccer on a field. .
print(translate(model, "Eine Frau mit Hut und Sonnenbrille steht am Strand .")) # A woman wearing a hat and sunglasses stands on the beach .
print(translate(model, "Zwei Freunde lachen und genießen ein Eis auf einer wunderschönen Wiese .")) # Two friends laugh and enjoy ice cream on a beautiful meadow .


A group of people stand in front of an igloo .
A chef in a white uniform preparing food in a restaurant kitchen .
Two young girls playing soccer on a field .
A woman wearing a hat and sunglasses is standing on the beach .
Two friends laugh and enjoy an ice during a beautiful field .


In [None]:
# @title bleu
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50):
    """Translate every example in ``data`` and score the results with ``bleu_score``.

    Each example must expose ``src`` and ``trg`` attributes. The source is
    translated with ``translate_sentence``, the last predicted token (the
    ``<eos>`` marker) is dropped, and each target is wrapped in a list so it
    serves as the single reference for its example.

    Returns:
        The value returned by ``bleu_score(predictions, references)``.
    """
    references, hypotheses = [], []
    for example in data:
        source = vars(example)['src']
        target = vars(example)['trg']
        prediction, _ = translate_sentence(source, src_field, trg_field, model, device, max_len)
        # The final predicted token is <eos>; leave it out of the hypothesis.
        hypotheses.append(prediction[:-1])
        references.append([target])
    return bleu_score(hypotheses, references)
# Store the result under its own name: rebinding `bleu_score` would replace the
# imported metric function with a float and break every later call to it.
test_bleu = calculate_bleu(test_data, SRC, TRG, model, device)
print(f'BLEU score = {test_bleu*100:.2f}')
# 36.52, which beats the ~34 of the convolutional sequence-to-sequence model and ~28 of the attention based RNN model.

def translate_sentence_vectorized(src_tensor, src_field, trg_field, model, device, max_len=50):
    """Greedily decode a whole batch of source sentences at once.

    Args:
        src_tensor: ``torch.Tensor`` of source token ids; its first axis is the batch.
        src_field: Unused here; kept for interface compatibility.
        trg_field: Object exposing ``vocab.stoi``, ``vocab.itos``, ``init_token``
            and ``eos_token`` for the target language.
        model: Object exposing ``make_src_mask``, ``make_trg_mask``, ``encoder``
            and a ``decoder`` that returns ``(output, attention)``.
        device: Device the target tensors are moved to.
        max_len: Maximum number of decoding steps.

    Returns:
        Tuple ``(pred_sentences, attention)``: one list of target tokens per
        batch row (cut at the first ``<eos>``), and the attention returned by
        the last decoder call (``None`` if no step was run).
    """
    assert isinstance(src_tensor, torch.Tensor)

    model.eval()
    src_mask = model.make_src_mask(src_tensor)

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
    # enc_src = [batch_sz, src_len, hid_dim]

    # Resolve the vocabulary entries once instead of on every token.
    stoi = trg_field.vocab.stoi
    itos = trg_field.vocab.itos
    sos_idx = stoi[trg_field.init_token]
    eos_idx = stoi[trg_field.eos_token]

    batch_size = len(src_tensor)
    trg_indexes = [[sos_idx] for _ in range(batch_size)]
    # Rows that already produced <eos> are still fed through the model because
    # the batch is decoded as a unit; decoding stops once every row is done.
    translations_done = [False] * batch_size
    attention = None  # stays None when max_len <= 0 and no decoder call happens
    for _ in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        # Convert to plain Python ints so the token lists hold numbers, not 0-dim tensors.
        pred_tokens = output.argmax(2)[:, -1].tolist()
        for row, pred_token in enumerate(pred_tokens):
            trg_indexes[row].append(pred_token)
            if pred_token == eos_idx:
                translations_done[row] = True
        if all(translations_done):
            break

    # Skip the initial token and keep everything up to (not including) the first <eos>.
    pred_sentences = []
    for trg_sentence in trg_indexes:
        pred_sentence = []
        for token in trg_sentence[1:]:
            if token == eos_idx:
                break
            pred_sentence.append(itos[token])
        pred_sentences.append(pred_sentence)
    return pred_sentences, attention

from torchtext.data.metrics import bleu_score

def calculate_bleu_alt(iterator, src_field, trg_field, model, device, max_len = 50):
    """Compute BLEU over batches using batched greedy decoding.

    Args:
        iterator: Iterable of batches exposing ``src`` and ``trg`` attributes.
        src_field: Passed through to ``translate_sentence_vectorized``.
        trg_field: Object exposing ``vocab.stoi``, ``vocab.itos``, ``eos_token``
            and ``pad_token`` for the target language.
        model: Model handed to ``translate_sentence_vectorized``.
        device: Device handed to ``translate_sentence_vectorized``.
        max_len: Maximum number of decoding steps per batch.

    Returns:
        Tuple ``(pred_trgs, trgs, score)``: predicted token lists, reference
        token lists (each wrapped in a one-element list), and the value of
        ``bleu_score(pred_trgs, trgs)``.
    """
    trgs = []
    pred_trgs = []
    with torch.no_grad():
        for batch in iterator:
            src = batch.src
            trg = batch.trg
            _trgs = []
            for sentence in trg:
                tmp = []
                # Start from the second token, which skips the <start> token.
                for token in sentence[1:]:
                    # Targets are padded, so stop at the first padding or eos token.
                    if token == trg_field.vocab.stoi[trg_field.eos_token] or token == trg_field.vocab.stoi[trg_field.pad_token]:
                        break
                    tmp.append(trg_field.vocab.itos[token])
                _trgs.append([tmp])
            trgs += _trgs
            # Forward max_len so the caller's decoding limit is actually honoured.
            pred_trg, _ = translate_sentence_vectorized(src, src_field, trg_field, model, device, max_len)
            pred_trgs += pred_trg
    return pred_trgs, trgs, bleu_score(pred_trgs, trgs)
