In [2]:
import torch
from model import Transformer
from data_utils import create_dataloader, generate_mask_src, generate_mask_trg
from tqdm import tqdm

In [3]:
# Batch size handed to create_dataloader; the training loop also divides the
# summed loss by this value.
BATCH_SIZE = 32
# Sequence length shared by the dataloader, the Transformer and the decoder's
# max_seq_length.
SEQ_LEN = 64
# Number of full passes over the training dataloader.
EPOCHS = 5
# Vocabulary index of the padding token; used to build the masks, as the
# loss's ignore_index and as the decoder's pad_token_id.
PAD_TOKEN_ID = 1

In [4]:
# Build the training dataloader together with the vocabularies and tokenizers.
(
    dataloader,
    valid_iter,
    en_vocab,
    de_vocab,
    en_tokenizer,
    de_tokenizer,
) = create_dataloader(batch_size=BATCH_SIZE, seq_len=SEQ_LEN)

# `de_vocab` sizes the source side of the model and `en_vocab` the target side.
model = Transformer(
    seq_len=SEQ_LEN,
    src_vocab_size=len(de_vocab),
    trg_vocab_size=len(en_vocab),
    dropout_p=0.1,
)

Filtering the dataset. Initial size: 1000000
Removing special characters...
Filtered successfully. Final size: 938171


100%|██████████| 891262/891262 [01:24<00:00, 10558.77it/s]
100%|██████████| 891262/891262 [02:13<00:00, 6695.14it/s]


In [5]:
import pickle

# Persist the source-side vocabulary (`de_vocab`) so it can be reloaded
# without rebuilding the dataset.
with open('ru_vocab.pkl', 'wb') as vocab_file:
    pickle.dump(de_vocab, vocab_file)

In [6]:
import pickle

# Persist the target-side vocabulary (`en_vocab`) next to the source one.
with open('en_vocab.pkl', 'wb') as vocab_file:
    pickle.dump(en_vocab, vocab_file)

In [7]:
# Move the model's parameters to the GPU. Binding the result to `_` keeps the
# cell from echoing the module's repr as its output.
_ = model.cuda()

In [7]:
# Adam over every model parameter.
optim = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=[0.9, 0.98],
    eps=1e-9,
)

# Label-smoothed cross-entropy that skips padding positions and returns the
# *sum* over all remaining tokens (the training loop normalises it).
loss_obj = torch.nn.CrossEntropyLoss(
    ignore_index=PAD_TOKEN_ID,
    label_smoothing=0.1,
    reduction='sum',
)

In [None]:
# Training loop: EPOCHS passes over `dataloader`; each batch does a forward
# pass, a backward pass and one optimizer step.
for epoch in range(EPOCHS):
    model.train()
    # Wrap the dataloader itself (not the `enumerate` object) so tqdm can read
    # its length and show a total / ETA instead of a bare counter.
    iterator = tqdm(dataloader)
    for i, (trg_seq, src_seq) in enumerate(iterator):
        optim.zero_grad()

        # The decoder is fed the target without its last token and is trained
        # to predict the same sequence shifted left by one position.
        trg_input_seq = trg_seq[:, :-1].cuda()
        trg_label_seq = trg_seq[:, 1:].cuda()
        src_seq = src_seq.cuda()

        # The mask helpers also return a token count, which is not needed here.
        src_mask, _ = generate_mask_src(src_seq, PAD_TOKEN_ID)
        trg_mask, _ = generate_mask_trg(trg_input_seq, PAD_TOKEN_ID)

        # The model also returns the encoder output; only the logits are used.
        _, logits = model(src_seq, trg_input_seq, src_mask, trg_mask)

        # `loss_obj` sums over tokens, so divide by the number of sequences
        # actually present in this batch; the final batch may be smaller than
        # BATCH_SIZE. `reshape` is used instead of `view` because the label
        # tensor is a slice and is not guaranteed to be contiguous.
        loss = loss_obj(
            logits.reshape(-1, logits.shape[-1]),
            trg_label_seq.reshape(-1),
        ) / float(trg_seq.size(0))

        loss.backward()
        optim.step()
        # `.item()` extracts a plain Python float so the progress bar shows a
        # number rather than the tensor repr (device, grad_fn, ...).
        iterator.set_postfix_str(f"Epoch/Iteration {epoch}/{i}. Loss: {loss.item():.4f}")

In [5]:
from datasets import YandexDataset

In [6]:
# Build the train/validation/test iterators from the Yandex dataset directory.
# NOTE: this rebinds `valid_iter`, which the create_dataloader cell also defines.
train_iter, valid_iter, test_iter = YandexDataset('data/datasets/Yandex').get_iters()

Filtering the dataset. Initial size: 1000000
Removing special characters...
Filtered successfully. Final size: 938171


In [18]:

# NOTE(review): despite its name, `train_iter_` iterates over the *validation*
# split (`valid_iter`), not the training data.
train_iter_ = iter(valid_iter)

In [60]:
# Pull the next example from the iterator and display it. Items unpack as
# (target, source); the recorded output shows the English sentence first and
# the Russian sentence second.
trg, src = next(train_iter_)
trg, src

('we were even given additional towels and special sheets for our baby.',
 'все просьбы к персоналу выполнялись быстро и безотказно.')

In [12]:
from decoding import GreedyDecoder

In [13]:
def _tokenize_source(sent):
    """Split a source sentence into a list of token strings."""
    return [token.text for token in de_tokenizer(sent)]


# Greedy decoder over the trained model, using the source tokenizer and both
# vocabularies; generation is capped at SEQ_LEN tokens.
decoder = GreedyDecoder(
    model,
    _tokenize_source,
    src_vocab=de_vocab,
    trg_vocab=en_vocab,
    eos_token_id=3,
    sos_token_id=2,
    pad_token_id=PAD_TOKEN_ID,
    max_seq_length=SEQ_LEN,
)

In [14]:
# Smoke test: translate one hand-written Russian sentence and show the result.
decoder.decode('Я языковая модель, которая была обучена переводить русский на английский')

'i am an language model , which was inspired by english translation of russian'

In [26]:
# Look up the index that the source vocabulary assigns to a word.
de_vocab['покупает']

14817

In [195]:
# The recorded output for this word is 0, presumably the index used for
# out-of-vocabulary tokens (TODO confirm against the vocab's default index).
de_vocab['верну']

0

In [127]:
#torch.save(model.state_dict(), 'weights/weights.pth')

In [8]:
# Wrap the model in DataParallel before loading the checkpoint; presumably the
# weights were saved from a DataParallel model, so their keys carry the
# `module.` prefix (TODO confirm). Rebinding `model` makes this cell
# non-idempotent: running it twice nests the wrapper.
model = torch.nn.DataParallel(model)

In [9]:
# Load the checkpoint onto cuda:0. With strict=False, missing or unexpected
# keys are reported in the return value instead of raising, so check the cell
# output to make sure every key actually matched.
model.load_state_dict(torch.load('weights/Yandex/weights_filtered.pth', map_location='cuda:0'), strict=False)

<All keys matched successfully>

In [10]:
# Unwrap the DataParallel container to get the underlying model back.
model = model.module

In [11]:
# Save the unwrapped model's weights. `state_dict` has to be *called*: passing
# the bound method itself would pickle the method (and through it the whole
# model object) rather than the parameter dictionary that
# `load_state_dict` expects.
torch.save(model.state_dict(), 'yandex_weights.pth')

In [12]:
from nltk.translate.bleu_score import corpus_bleu

def calculate_bleu_score(decoder, dataloader, src_preprocess_fn, trg_preprocess_fn, max_len=10):
    """Translate every source sentence and score the results with corpus BLEU.

    Args:
        decoder: object whose ``decode(src)`` returns the translated text.
        dataloader: iterable of ``(target, source)`` pairs; it is fully
            materialised into a list before decoding starts.
        src_preprocess_fn: accepted for interface compatibility; currently
            unused because the raw source is passed straight to the decoder.
        trg_preprocess_fn: callable that tokenizes both the decoder output and
            the reference translation.
        max_len: accepted for interface compatibility; currently unused.

    Returns:
        The value returned by ``corpus_bleu`` (also printed, together with the
        number of scored sentences).
    """
    hypotheses, references = [], []
    for reference_text, source_text in tqdm(list(dataloader)):
        translation = decoder.decode(source_text)
        hypotheses.append(trg_preprocess_fn(translation))
        # corpus_bleu takes a *list* of references per hypothesis; here there
        # is exactly one reference for each sentence.
        references.append([trg_preprocess_fn(reference_text)])

    bleu_score = corpus_bleu(references, hypotheses)
    print(f'BLEU-4 corpus score = {bleu_score}, corpus length = {len(references)}.')
    return bleu_score

In [14]:
# Score the first 1000 test pairs. Both languages are tokenized into plain
# token strings with their respective tokenizers.
calculate_bleu_score(
    decoder,
    list(test_iter)[:1000],
    lambda x: [token.text for token in de_tokenizer(x)],
    lambda x: [token.text for token in en_tokenizer(x)],
    max_len=64,
)

100%|██████████| 1000/1000 [08:08<00:00,  2.05it/s]

BLEU-4 corpus score = 0.23002674009458673, corpus length = 1000.





0.23002674009458673