<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://www.mihaileric.com/posts/transformers-attention-in-disguise/
# https://jalammar.github.io/illustrated-transformer/
# http://nlp.seas.harvard.edu/2018/04/03/attention.html

# position embedding has a "vocabulary" size of 100, model can accept sentences up to 100 tokens long
# we use a learned positional encoding instead of a static one
# we use the standard Adam optimizer with a static learning rate instead of one with warm-up and cool-down steps
# we do not use label smoothing


In [None]:
# @title setup

# https://pytorch.org/tutorials/beginner/translation_transformer.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/c64c91cf87c13c0e83586b8e66e4d74e/translation_transformer.ipynb

# https://github.com/pytorch/data
%pip install portalocker
%pip install torchdata

# Create source and target language tokenizer. Make sure to install the dependencies.
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm


In [1]:
# @title data

from torchtext.datasets import multi30k, Multi30k
# modify the URLs for the dataset since the links to the original dataset are broken https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

from torchtext.data.utils import get_tokenizer
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, bigining, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

from torchtext.vocab import build_vocab_from_iterator
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
de_vocab.set_default_index(UNK_IDX)
en_vocab.set_default_index(UNK_IDX)

import torch

def de_transform(o):
    o=de_tokenizer(o)
    o=de_vocab(o)
    return torch.cat((torch.tensor([BOS_IDX]), torch.tensor(o), torch.tensor([EOS_IDX])))

def en_transform(o):
    o=en_tokenizer(o)
    o=en_vocab(o)
    return torch.cat((torch.tensor([BOS_IDX]), torch.tensor(o), torch.tensor([EOS_IDX])))


from torch.nn.utils.rnn import pad_sequence
# function to collate data samples into batch tensors
def collate_fn(batch): # convert a batch of raw strings into batch tensors
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(de_transform(src_sample.rstrip("\n")))
        tgt_batch.append(en_transform(tgt_sample.rstrip("\n")))
    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_IDX)
    return src_batch, tgt_batch


torch.manual_seed(0)

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
batch_size = 128 # 128
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

# vocab_transform = {SRC_LANGUAGE:de_vocab, TGT_LANGUAGE:en_vocab}
# text_transform = {SRC_LANGUAGE:de_transform, TGT_LANGUAGE:en_transform}




In [14]:
# @title 6att down
import torch
import torch.nn as nn
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import math
class PositionalEncoder(nn.Module):
    def __init__(self, emb_dim, max_seq_length=512, dropout=0.1):
        super(PositionalEncoder, self).__init__()
        self.emb_dim = emb_dim
        self.drop = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_length, emb_dim)
        pos = torch.arange(0, max_seq_length).unsqueeze(1) # https://nlp.seas.harvard.edu/annotated-transformer/
        div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x*math.sqrt(self.emb_dim) # ?
        # return self.drop(x + Variable(self.pe[:, :x.size(1)], requires_grad=False).to(x.device))
        x = x + self.pe[:, : x.size(1)].requires_grad_(False) # batch_first
        return self.drop(x)


class MHA(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.q = nn.Linear(hid_dim, hid_dim)
        self.k = nn.Linear(hid_dim, hid_dim)
        self.v = nn.Linear(hid_dim, hid_dim)
        self.out = nn.Linear(hid_dim, hid_dim)
        self.drop = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.tensor((self.head_dim,), dtype=torch.float, device=device))

    def forward(self, query, key, value, mask = None):
        batch_size = query.shape[0]
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
        Q = self.q(query) #Q = [batch size, query len, hid dim]
        K = self.k(key) #K = [batch size, key len, hid dim]
        V = self.v(value) #V = [batch size, value len, hid dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) #Q = [batch size, n heads, query len, head dim]
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) #K = [batch size, n heads, key len, head dim]
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) #V = [batch size, n heads, value len, head dim]
        # scaled dot-product attention
        energy = torch.matmul(Q, K.transpose(2,3)) / self.scale #energy = [batch size, n heads, query len, key len]
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim = -1) #attention = [batch size, n heads, query len, key len]
        # why dropout applied directly to the attn?
        x = torch.matmul(self.drop(attention), V) #x = [batch size, n heads, query len, head dim]
        x = x.transpose(1, 2).contiguous() #x = [batch size, query len, n heads, head dim]
        x = x.view(batch_size, -1, self.hid_dim) #x = [batch size, query len, hid dim]
        x = self.out(x) #x = [batch size, query len, hid dim]
        return x, attention

class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, ff_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(hid_dim)
        self.norm2 = nn.LayerNorm(hid_dim)
        self.self_attn = MHA(hid_dim, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(hid_dim, ff_dim), nn.GELU(), # ReLU
            nn.Dropout(dropout), nn.Linear(ff_dim, hid_dim)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, src, src_mask): #src = [batch size, src len, hid dim] #src_mask = [batch size, 1, 1, src len]
        src = self.norm1(src + self.drop(self.self_attn(src, src, src, src_mask)[0])) #src = [batch size, src len, hid dim]
        src = self.norm2(src + self.drop(self.ff(src))) #src = [batch size, src len, hid dim]
        return src

class Encoder(nn.Module):
    def __init__(self, hid_dim, n_layers, n_heads, ff_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, n_heads, ff_dim, dropout) for _ in range(n_layers)])

    def forward(self, src, src_mask):
        for layer in self.layers:
            src = layer(src, src_mask) #src = [batch size, src len, hid dim]
        return src

class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, ff_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(hid_dim)
        self.norm2 = nn.LayerNorm(hid_dim)
        self.norm3 = nn.LayerNorm(hid_dim)
        self.self_attn = MHA(hid_dim, n_heads, dropout)
        self.enc_attn = MHA(hid_dim, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(hid_dim, ff_dim), nn.GELU(), # ReLU
            nn.Dropout(dropout), nn.Linear(ff_dim, hid_dim)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        #enc_src = [batch size, src len, hid dim] #src_mask = [batch size, 1, 1, src len]
        #trg = [batch size, trg len, hid dim] #trg_mask = [batch size, 1, trg len, trg len]
        trg = self.norm1(trg + self.drop(self.self_attn(trg, trg, trg, trg_mask)[0])) #trg = [batch size, trg len, hid dim]
        trg = self.norm2(trg + self.drop(self.enc_attn(trg, enc_src, enc_src, src_mask)[0])) #trg = [batch size, trg len, hid dim]
        trg = self.norm3(trg + self.drop(self.ff(trg))) # og 6att
        return trg #trg = [batch size, trg len, hid dim]

class Decoder(nn.Module):
    def __init__(self, hid_dim, n_layers, n_heads, ff_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, n_heads, ff_dim, dropout) for _ in range(n_layers)])

    def forward(self, trg, enc_src, trg_mask, src_mask):
        for layer in self.layers:
            trg = layer(trg, enc_src, trg_mask, src_mask)
        return trg

class Seq2Seq(nn.Module):
    def __init__(self, in_dim, out_dim, d_model = 512, nhead = 8, enc_layers = 3, dec_layers = 3, ff_dim = 512, dropout = 0.1):
        super().__init__()
        self.encoder = Encoder(d_model, num_encoder_layers, nhead, ff_dim, dropout)
        self.decoder = Decoder(d_model, num_decoder_layers, nhead, ff_dim, dropout)
        self.pos_enc = PositionalEncoder(d_model, dropout=dropout)
        self.src_tok_emb = nn.Embedding(in_dim, d_model)
        self.trg_tok_emb = nn.Embedding(out_dim, d_model)
        self.d_model = d_model
        self.fc_out = nn.Linear(d_model, out_dim)

        for p in self.parameters(): # must be at the end of __init__
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # src_mask, same shape as the source sentence,
    # value of 1 when the token in the source sentence is not <pad> token,
    # for encoder layers to mask the multi-head attention mechanisms,
    # calculate and apply attention over the source sentence, so dun pay attention to <pad> tokens
    def make_src_mask(self, src): #src = [batch size, src len]
        return (src != PAD_IDX).unsqueeze(1).unsqueeze(2).to(device) #src_mask = [batch size, 1, 1, src len]

    # subsequent mask, tril, nth in tgt can only see up to nth in out
    # bitwise & with trg_pad_mask: dun pay attn to <pad>
    def make_trg_mask(self, trg): #trg = [batch size, trg len]
        trg_pad_mask = (trg != PAD_IDX).unsqueeze(1).unsqueeze(2).to(device) #trg_pad_mask = [batch size, 1, 1, trg len]
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = device)).bool() #trg_sub_mask = [trg len, trg len]
        trg_mask = trg_pad_mask & trg_sub_mask #trg_mask = [batch size, 1, trg len, trg len]
        return trg_mask

    def forward(self, src, trg): #src = [batch size, src len] #trg = [batch size, trg len]
        src_mask = self.make_src_mask(src) #src_mask = [batch size, 1, 1, src len]
        trg_mask = self.make_trg_mask(trg) #trg_mask = [batch size, 1, trg len, trg len]

        src = self.pos_enc(self.src_tok_emb(src.long()) * math.sqrt(self.d_model))
        trg = self.pos_enc(self.trg_tok_emb(trg.long()) * math.sqrt(self.d_model))

        enc_src = self.encoder(src, src_mask) #enc_src = [batch size, src len, hid dim]
        trg = self.decoder(trg, enc_src, trg_mask, src_mask) #output = [batch size, trg len, hid dim]
        trg = self.fc_out(trg) #output = [batch size, trg len, output dim]
        return trg

    def encode(self, src, src_mask):
        return self.encoder(self.pos_enc(self.src_tok_emb(src.long()) * math.sqrt(self.d_model)), src_mask)

    def decode(self, trg, memory, trg_mask, src_mask):
        trg = self.decoder(self.pos_enc(self.trg_tok_emb(trg.long()) * math.sqrt(self.d_model)), memory, trg_mask, src_mask)
        return self.fc_out(trg)

in_dim = len(de_vocab)
out_dim = len(en_vocab)
d_model=512 # hid_dim
nhead=8
num_encoder_layers=3
num_decoder_layers=3
dim_feedforward=512 # pf_dim
dropout=0.1

model = Seq2Seq(in_dim, out_dim, d_model = 512, nhead = 8, enc_layers = 3, dec_layers = 3, ff_dim = 512, dropout = 0.1).to(device)
# torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation=<function relu>, custom_encoder=None, custom_decoder=None, layer_norm_eps=1e-05, batch_first=False, norm_first=False, bias=True, device=None, dtype=None)
# torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1)


In [10]:
# @title train eval

def train(model, dataloader, optimizer, loss_fn):
    model.train()
    epoch_loss = 0
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device) #trg = [batch size, trg len]
        optimizer.zero_grad()
        output = model(src, trg[:,:-1]) #output = [batch size, trg len - 1, output dim]
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim) #output = [batch size * trg len - 1, output dim]
        trg = trg[:,1:].contiguous().view(-1) #trg = [batch size * trg len - 1]
        loss = loss_fn(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1) # from og 6attsalluneed
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(list(dataloader))

def evaluate(model, dataloader, loss_fn):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for src, trg in dataloader:
            src, trg = src.to(device), trg.to(device) #trg = [batch size, trg len]
            output = model(src, trg[:,:-1]) #output = [batch size, trg len - 1, output dim]
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim) #output = [batch size * trg len - 1, output dim]
            trg = trg[:,1:].contiguous().view(-1) #trg = [batch size * trg len - 1]
            loss = loss_fn(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(list(dataloader))



In [17]:
# @title translate

def translate(model, src_sentence):
    model.eval()
    src = de_transform(src_sentence).view(1,-1).to(device)
    num_tokens = src.shape[0]
    src_mask = torch.zeros((num_tokens, num_tokens), dtype=bool, device=device)
    with torch.no_grad():
        enc_src = model.encode(src, src_mask)
    trg_indexes = [BOS_IDX]
    max_len = src.shape[1]+5
    for i in range(max_len):
        trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output = model.decode(trg_tensor, enc_src, trg_mask, src_mask)
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        if pred_token == EOS_IDX: break
    tgt_tokens = torch.tensor(trg_indexes[1:-1]).flatten()
    return " ".join(en_vocab.lookup_tokens(list(tgt_tokens.cpu().numpy())))

# UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, bigining, end of sentence
print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))


A group of people are standing in front of a group of people .


In [16]:
# @title run
# import math
import time
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) # lr=0.0001

for epoch in range(20):
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn)
    val_loss = evaluate(model, val_loader, loss_fn)
    end_time = time.time()
    print((f"Epoch: {epoch+1}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))
#  A group of people standing in front of an igloo

# sine pos enc, scale after pos
# Epoch: 20, Train loss: 3.238, Val loss: 3.504, Epoch time = 44.329s
# A person is playing a trick in a race .

# sine pos enc, token,scale,pos
# Epoch: 20, Train loss: 1.506, Val loss: 2.149, Epoch time = 43.637s
# A crowd of people of a crowd .

# scale, gelu, sine pos enc
# Epoch: 20, Train loss: 1.547, Val loss: 2.114, Epoch time = 41.686s
# A group of people are standing in a doorway .


Epoch: 1, Train loss: 5.319, Val loss: 4.194, Epoch time = 40.968s
A group of people are in a group of a group of a group of
Epoch: 2, Train loss: 3.955, Val loss: 3.655, Epoch time = 43.725s
A group of people are standing on a group of people .
Epoch: 3, Train loss: 3.518, Val loss: 3.366, Epoch time = 42.490s
A group of people are standing in front of a crowd .
Epoch: 4, Train loss: 3.218, Val loss: 3.157, Epoch time = 42.668s
A group of people are standing in front of a crowd .
Epoch: 5, Train loss: 2.981, Val loss: 2.984, Epoch time = 43.254s
A group of people are standing in front of a group of people .
Epoch: 6, Train loss: 2.782, Val loss: 2.842, Epoch time = 42.966s
A group of people are standing in front of a group of people .
Epoch: 7, Train loss: 2.610, Val loss: 2.740, Epoch time = 42.803s
A group of people are standing in front of a group of people .
Epoch: 8, Train loss: 2.451, Val loss: 2.637, Epoch time = 43.211s
A group of people are standing in front of a group of peo

KeyboardInterrupt: 