In [1]:
import torch
from torchtext.datasets import Multi30k
import spacy
from model import Transformer
from data_utils import generate_mask_src, generate_mask_trg, make_vocab, Specials, TextPadderDataset
from tqdm import tqdm

In [2]:
BATCH_SIZE = 48
SEQ_LEN = 32
EPOCHS = 30
PAD_TOKEN_ID = 1

In [3]:
def create_dataloader(batch_size=32, seq_len=64, device='cpu'):
    train_iter, valid_iter, test_iter = Multi30k('./datasets/Multi30k/', )
    
    en_tokenizer = spacy.load("en_core_web_sm")
    de_tokenizer = spacy.load('de_core_news_sm')
    # Disable all of the unnecessary pipes to accelerate the data pipeline.
    en_tokenizer.disable_pipes(['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer', 'ner'])
    de_tokenizer.disable_pipes(['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer', 'ner'])
    
    de_vocab = make_vocab(
        text_iter=[trg for trg, src in train_iter], 
        tokenizer=de_tokenizer, 
        specials=(Specials.UNK, Specials.PAD, Specials.SOS, Specials.EOS), 
        min_freq=1,
        voc_cache_name='./datasets/Multi30k/de_vocab.pkl'
    )
    en_vocab = make_vocab(
        text_iter=[src for trg, src in train_iter], 
        tokenizer=en_tokenizer, 
        min_freq=1, 
        specials=(Specials.UNK, Specials.PAD), 
        voc_cache_name='./datasets/Multi30k/en_vocab.pkl'
    )
    
    my_dataset = TextPadderDataset(
        src_vocab=en_vocab, trg_vocab=de_vocab,
        text_iterator=train_iter,
        src_tokenizer_fn=lambda line: [token.text for token in en_tokenizer(line.lower())],
        trg_tokenizer_fn=lambda line: [token.text for token in de_tokenizer(line.lower())],
        max_seq_length=seq_len,
        device=device
    )
    
    loader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size)
    return loader, valid_iter, en_vocab, de_vocab, en_tokenizer, de_tokenizer

In [4]:
dataloader, valid_iter, en_vocab, de_vocab, en_tokenizer, de_tokenizer = create_dataloader(batch_size=BATCH_SIZE, seq_len=SEQ_LEN)
model = Transformer(seq_len=SEQ_LEN, src_vocab_size=len(en_vocab), trg_vocab_size=len(de_vocab), dropout_p=0.1)



In [5]:
_ = model.cuda()

In [6]:
optim = torch.optim.Adam(model.parameters(), lr=1e-4, betas=[0.9, 0.98], eps=1e-9)
loss_obj = torch.nn.NLLLoss(reduction='sum', ignore_index=PAD_TOKEN_ID)

In [7]:
for epoch in range(EPOCHS):
    model.train()
    iterator = tqdm(enumerate(dataloader))
    for i, (trg_seq, src_seq) in iterator:
        optim.zero_grad()

        trg_input_seq = trg_seq[:, :-1]
        trg_label_seq = trg_seq[:, 1:]
        trg_input_seq = trg_input_seq.cuda(); src_seq = src_seq.cuda()
        trg_label_seq = trg_label_seq.cuda()

        src_mask = generate_mask_src(src_seq, PAD_TOKEN_ID)
        trg_mask = generate_mask_trg(trg_input_seq, PAD_TOKEN_ID)

        enc_out, logits = model(src_seq, trg_input_seq, src_mask, trg_mask)
        loss = loss_obj(logits.view(-1, logits.shape[-1]), trg_label_seq.view(-1)) / float(BATCH_SIZE)

        loss.backward()
        optim.step()
        iterator.set_postfix_str(f"Epoch/Iteration {epoch}/{i}. Loss: {loss}")

603it [02:31,  3.98it/s, Epoch/Iteration 0/602. Loss: 24.68532371520996] 
603it [02:15,  4.46it/s, Epoch/Iteration 1/602. Loss: 21.51140594482422] 
603it [02:04,  4.86it/s, Epoch/Iteration 2/602. Loss: 21.921279907226562]
603it [02:02,  4.92it/s, Epoch/Iteration 3/602. Loss: 18.377609252929688]
603it [02:06,  4.76it/s, Epoch/Iteration 4/602. Loss: 17.299686431884766]
603it [02:08,  4.67it/s, Epoch/Iteration 5/602. Loss: 14.54904842376709] 
603it [02:01,  4.97it/s, Epoch/Iteration 6/602. Loss: 13.037145614624023]
603it [02:02,  4.91it/s, Epoch/Iteration 7/602. Loss: 14.913864135742188]
603it [02:03,  4.88it/s, Epoch/Iteration 8/602. Loss: 12.168251037597656]
603it [02:06,  4.76it/s, Epoch/Iteration 9/602. Loss: 11.236723899841309]
603it [02:05,  4.82it/s, Epoch/Iteration 10/602. Loss: 8.170421600341797] 
603it [02:09,  4.65it/s, Epoch/Iteration 11/602. Loss: 9.904566764831543] 
603it [02:07,  4.74it/s, Epoch/Iteration 12/602. Loss: 9.033177375793457] 
603it [02:07,  4.73it/s, Epoch/Iter

In [8]:
valid_iter_ = iter(valid_iter)

In [9]:
trg, src = next(valid_iter_)
trg, src

('Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen',
 'A group of men are loading cotton onto a truck')

In [12]:
from decoding import GreedyDecoder

In [13]:
decoder = GreedyDecoder(
    model=model, 
    tokenizer=lambda sent: [token.text for token in en_tokenizer(sent)], 
    src_vocab=en_vocab, trg_vocab=de_vocab, 
    eos_token_id=3, sos_token_id=2, pad_token_id=PAD_TOKEN_ID, 
    max_seq_length=SEQ_LEN
)

In [14]:
transl = decoder.decode(src)
transl

'eine gruppe von männern beugt sich nach links eines lkw .'

In [10]:
from torchtext.datasets import Multi30k

In [11]:
train, valid, test = Multi30k('./datasets/Multi30k/')

In [16]:
from train import calculate_bleu_score

In [17]:
calculate_bleu_score(
    decoder=decoder, 
    dataloader=[(trg, src) for trg, src in test], 
    src_preprocess_fn=lambda x: [token.text for token in en_tokenizer(x.lower())], 
    trg_preprocess_fn=lambda x: [token.text for token in de_tokenizer(x.lower())], 
    max_len=SEQ_LEN
)

100%|██████████| 1000/1000 [03:29<00:00,  4.77it/s]

BLEU-4 corpus score = 0.2122980863894124, corpus length = 998.





0.2122980863894124