<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/translation_transformer_down.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title setup

# https://pytorch.org/tutorials/beginner/translation_transformer.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/c64c91cf87c13c0e83586b8e66e4d74e/translation_transformer.ipynb

# https://github.com/pytorch/data
%pip install portalocker
%pip install torchdata

# Create source and target language tokenizer. Make sure to install the dependencies.
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm


In [None]:
# @title data

from torchtext.datasets import multi30k, Multi30k
# modify the URLs for the dataset since the links to the original dataset are broken https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

from torchtext.data.utils import get_tokenizer
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, bigining, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

from torchtext.vocab import build_vocab_from_iterator
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
de_vocab.set_default_index(UNK_IDX)
en_vocab.set_default_index(UNK_IDX)

import torch

def de_transform(o):
    o=de_tokenizer(o)
    o=de_vocab(o)
    return torch.cat((torch.tensor([BOS_IDX]), torch.tensor(o), torch.tensor([EOS_IDX])))

def en_transform(o):
    o=en_tokenizer(o)
    o=en_vocab(o)
    return torch.cat((torch.tensor([BOS_IDX]), torch.tensor(o), torch.tensor([EOS_IDX])))


from torch.nn.utils.rnn import pad_sequence
# function to collate data samples into batch tensors
def collate_fn(batch): # convert a batch of raw strings into batch tensors
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(de_transform(src_sample.rstrip("\n")))
        tgt_batch.append(en_transform(tgt_sample.rstrip("\n")))
    # src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    # tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_IDX)
    return src_batch, tgt_batch


torch.manual_seed(0)

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
batch_size = 128 # 128
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

# vocab_transform = {SRC_LANGUAGE:de_vocab, TGT_LANGUAGE:en_vocab}
# text_transform = {SRC_LANGUAGE:de_transform, TGT_LANGUAGE:en_transform}


In [3]:
# @title model
import torch
import torch.nn as nn
import math
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size, dropout, maxlen = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1) # .reshape(-1, 1)
        pos_emb = torch.zeros((maxlen, emb_size))
        pos_emb[:, 0::2] = torch.sin(pos * den) # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_emb[:, 1::2] = torch.cos(pos * den) # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_emb = pos_emb.unsqueeze(0) # batch_first=F -> unsqueeze(-2)
        self.register_buffer('pos_emb', pos_emb) # register as buffer so optimizer wont update it

    def forward(self, token_emb):
        return self.dropout(token_emb + self.pos_emb[:, :token_emb.size(1)]) # batch_first=F -> [:token_emb.size(0), :]


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        # self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
        # self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, emb_size)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

        self.src_tok_emb = TokenEmbedding(src_vocab_size, d_model)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, d_model)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # def forward(self, src, tgt, src_mask, tgt_mask):
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        src_emb = self.pos_enc(self.src_tok_emb(src))
        tgt_emb = self.pos_enc(self.tgt_tok_emb(tgt))

        # out = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        out = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask)
        return self.generator(out)

    def encode(self, src, src_mask=None):
        return self.transformer.encoder(self.pos_enc(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None):
        return self.transformer.decoder(self.pos_enc(self.tgt_tok_emb(tgt)), memory, tgt_mask, memory_mask)


src_vocab_size = len(de_vocab)
tgt_vocab_size = len(en_vocab)

emb_size = 512 # d_model
nhead = 8
dim_feedforward = 512
num_encoder_layers = 3
num_decoder_layers = 3

model = Transformer(src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1).to(device)


In [6]:
# @title mask translate

# subsequent word mask that will prevent the model from looking into the future words when making predictions.
# also need masks to hide source and target padding token
def generate_square_subsequent_mask(sz):
    mask = torch.tril(torch.ones((sz, sz), device=device))#.bool()
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[1] # batch_first=F -> [0]
    tgt_seq_len = tgt.shape[1] # batch_first=F -> [0]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)
    src_padding_mask = (src == PAD_IDX) # batch_first=F -> .transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX) # batch_first=F -> .transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


def translate(model, src_sentence):
    model.eval()
    src = de_transform(src_sentence).view(1,-1).to(device)
    num_tokens = src.shape[1] # batch_first=F -> [0]
    src_mask = torch.zeros((num_tokens, num_tokens), dtype=bool, device=device)
    trg_indexes = [BOS_IDX]
    max_len = src.shape[1]+5
    for i in range(max_len):
        trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
        trg_mask = generate_square_subsequent_mask(trg_tensor.size(1))
        with torch.no_grad():
            output = model(src, trg_tensor, src_mask, trg_mask)
        pred_token = output.argmax(2)[:,-1].item() # batch_first=F -> ?
        trg_indexes.append(pred_token)
        if pred_token == EOS_IDX: break
    tgt_tokens = torch.tensor(trg_indexes[1:-1]).flatten()
    return " ".join(en_vocab.lookup_tokens(list(tgt_tokens.cpu().numpy())))

# UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, bigining, end of sentence
print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))


Russia cloth spoof Russia sewing Madrid Madrid Russia silhouetted Madrid Russia Madrid Madrid Russia cloth


In [8]:
# @title train test

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    losses = 0
    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)
        tgt_input = tgt[:, :-1] # batch_first=F -> [:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask)
        # logits = model(src, tgt_input, src_mask, tgt_mask)
        optimizer.zero_grad()
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1)) # batch_first=F -> tgt[1:, :]
        loss.backward()
        optimizer.step()
        losses += loss.item()
    return losses / len(list(dataloader))

def test(dataloader, model, loss_fn):
    model.eval()
    losses = 0
    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)
        tgt_input = tgt[:, :-1] # batch_first=F -> [:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        # logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        logits = model(src, tgt_input, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask)
        tgt_out = tgt[:, 1:] # batch_first=F -> [1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(list(dataloader))


In [9]:
# @title run

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) # lr=0.0001

import time
for epoch in range(18):
    start_time = time.time()
    train_loss = train(train_loader, model, loss_fn, optimizer)
    end_time = time.time()
    val_loss = test(val_loader, model, loss_fn)
    print((f"Epoch: {epoch+1}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))




Epoch: 1, Train loss: 5.339, Val loss: 4.117, Epoch time = 43.515s
A group of people are playing in a red .
Epoch: 2, Train loss: 3.791, Val loss: 3.371, Epoch time = 44.869s
A group of people are standing in front of a crowd of people .
Epoch: 3, Train loss: 3.205, Val loss: 2.948, Epoch time = 43.700s
A group of people standing in front of a building .
Epoch: 4, Train loss: 2.807, Val loss: 2.674, Epoch time = 45.057s
A group of people standing in front of a crowd of people stand .
Epoch: 5, Train loss: 2.513, Val loss: 2.467, Epoch time = 44.487s
A group of people standing in front of a store .
Epoch: 6, Train loss: 2.274, Val loss: 2.322, Epoch time = 45.964s
A group of people standing in front of an orange instrument in front of an
Epoch: 7, Train loss: 2.076, Val loss: 2.226, Epoch time = 45.011s
A group of people standing in front of an orange instrument in front of an
Epoch: 8, Train loss: 1.908, Val loss: 2.127, Epoch time = 45.747s
A group of people standing in front of an or

In [10]:
# @title inference
print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu .")) # A group of people stand in front of an igloo .
print(translate(model, "Ein Koch in weißer Uniform bereitet Essen in einer Restaurantküche zu .")) # A chef in a white uniform prepares food in a restaurant kitchen .
print(translate(model, "Zwei junge Mädchen spielen Fußball auf einem Feld. .")) # Two young girls play soccer on a field. .
print(translate(model, "Eine Frau mit Hut und Sonnenbrille steht am Strand .")) # A woman wearing a hat and sunglasses stands on the beach .
print(translate(model, "Zwei Freunde lachen und genießen ein Eis auf einer wunderschönen Wiese .")) # Two friends laugh and enjoy ice cream on a beautiful meadow .


A group of people standing in front an auditorium in an igloo .
A chef in a white uniform preparing food in a restaurant kitchen kitchen .
Two young girls playing soccer on a field in a field .
A woman wearing a hat and sunglasses stands on the beach at the beach .
Two friends smile and enjoying ice cream on a beautiful field of a beautiful field .
