In [1]:
import locale
def get_enc():
  return "UTF-8"
locale.getpreferredencoding = get_enc

In [2]:
%matplotlib inline

In [3]:
!pip install torchdata -q
!pip install portalocker -q


# Language Translation with ``nn.Transformer`` and torchtext

This tutorial shows:
    - How to train a translation model from scratch using Transformer.
    - Use torchtext library to access  [Multi30k](http://www.statmt.org/wmt16/multimodal-task.html#task1)_ dataset to train a German to English translation model.


## Data Sourcing and Processing

[torchtext library](https://pytorch.org/text/stable/)_ has utilities for creating datasets that can be easily
iterated through for the purposes of creating a language translation
model. In this example, we show how to use torchtext's inbuilt datasets,
tokenize a raw text sentence, build vocabulary, and numericalize tokens into tensor. We will use
[Multi30k dataset from torchtext library](https://pytorch.org/text/stable/datasets.html#multi30k)_
that yields a pair of source-target raw sentences.

To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.




In [4]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from torch.utils.data import random_split
from typing import Iterable, List
import re

SRC_LANGUAGE = 'crh'
TGT_LANGUAGE = 'rus'

token_transform = {}
vocab_transform = {}

In [5]:
def read_dataset(path='data.txt'):
  dataset = []
  for line in open(path):
    pair = line.split('\t')
    if len(pair) < 2:
      continue
    src, trg = line.split('\t')
    trg = re.sub(r'\s(ст\s|ч\s)+', ' ', trg)
    trg = trg.strip('\n').strip('')
    data = {}
    data[TGT_LANGUAGE] = trg
    data[SRC_LANGUAGE] = src
    dataset.append(data)
  return dataset

def replicate_phrases(dataset = []):
  ext = []
  for pair in dataset:
    trg = pair['rus']
    words = len(trg.split(' '))
    if words > 0 and words < 6:
      ext.append(pair)
  dataset.extend(ext)
  return dataset

dataset = read_dataset('./data_clean_final.tsv')
proportions = [.8, .15, .05] 
lengths = [int(p * len(dataset)) for p in proportions]
lengths[-1] = len(dataset) - sum(lengths[:-1])
train_data, valid_data, test_data = random_split(dataset, lengths)
print(len(train_data))
train_data = replicate_phrases(list(train_data))
print(len(train_data))


23281
31422


In [9]:
from random import shuffle 

shuffle(train_data)

In [11]:
train_data[1]

{'rus': 'контроль за выполнением настоящего постановления возложить на комитет государ ственного совета республики крым по инвестиционной и налоговой политике и комитет государственного совета республики крым по информационной полити ке информационным технологиям и связи',
 'crh': 'мезкюр къарар ерине кетирильмеси узеринде незарет къырым джумхуриети девлет шурасынынъ ятырым ве берги сиясети боюнджа комитети ве къырым джум хуриети девлет шурасынынъ малюмат сиясети малюмат технологиялары ве алякъа боюнджа комитетине авале этильсин'}

In [12]:
from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.models import BPE
from tokenizers.normalizers import Lowercase, NFKC, Sequence
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing

In [13]:
# First we create an empty Byte-Pair Encoding model (i.e. not trained model)
tokenizer_src = Tokenizer(BPE(unk_token = '<unk>'))
tokenizer_trg = Tokenizer(BPE(unk_token = '<unk>'))

# Then we enable lower-casing and unicode-normalization
# The Sequence normalizer allows us to combine multiple Normalizer that will be
# executed in order.
tokenizer_src.normalizer = Sequence([
  NFKC(),
  Lowercase()
])
tokenizer_trg.normalizer = Sequence([
  NFKC(),
  Lowercase()
])
tokenizer_src.add_special_tokens(['<unk>', '<sos>', '<eos>', '<pad>'])
tokenizer_trg.add_special_tokens(['<unk>', '<sos>', '<eos>', '<pad>'])

tokenizer_src.enable_padding(pad_token = '<pad>')
tokenizer_trg.enable_padding(pad_token = '<pad>')
# Our tokenizer also needs a pre-tokenizer responsible for converting the input to a ByteLevel representation.
tokenizer_src.pre_tokenizer = ByteLevel()
tokenizer_trg.pre_tokenizer = ByteLevel()

# And finally, let's plug a decoder so we can recover from a tokenized input to the original one
tokenizer_src.decoder = ByteLevelDecoder()
tokenizer_trg.decoder = ByteLevelDecoder()

tokenizer_src.decoder = ByteLevelDecoder()
tokenizer_trg.decoder = ByteLevelDecoder()

tokenizer_src.post_processor = TemplateProcessing(
    single="<sos> $A <eos>",
    #pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[("<sos>", 1), ("<eos>", 2)],
)

tokenizer_trg.post_processor = TemplateProcessing(
    single="<sos> $A <eos>",
    #pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[("<sos>", 1), ("<eos>", 2)],
)

In [14]:
def create_vocab(tokenizer, data, lang, save_folder):  
  def yield_tokens(data, lang):
    for sample in data:
      yield sample[lang]
  # We initialize our trainer, giving him the details about the vocabulary we want to generate
  trainer = BpeTrainer(vocab_size = 15000, show_progress=True, 
                       specials=['<unk>', '<sos>', '<eos>', '<pad>'],
                       initial_alphabet=ByteLevel.alphabet())
  tokenizer.train_from_iterator(yield_tokens(data, lang), trainer=trainer)

  print("Trained vocab size: {}".format(tokenizer.get_vocab_size()))
  tokenizer.model.save(save_folder)

In [16]:
!mkdir trg src

mkdir: cannot create directory ‘trg’: File exists
mkdir: cannot create directory ‘src’: File exists


In [15]:
create_vocab(tokenizer_trg, train_data, SRC_LANGUAGE, './trg')
create_vocab(tokenizer_src, train_data, TGT_LANGUAGE, './src')

token_transform[SRC_LANGUAGE] = tokenizer_src
token_transform[TGT_LANGUAGE] = tokenizer_trg

vocab_transform[SRC_LANGUAGE] = token_transform[SRC_LANGUAGE].get_vocab()
vocab_transform[TGT_LANGUAGE] = token_transform[TGT_LANGUAGE].get_vocab()

Trained vocab size: 15004
Trained vocab size: 15004


In [17]:
tokenizer_src.model = BPE('./src/vocab.json', './src/merges.txt')
tokenizer_trg.model = BPE('./trg/vocab.json', './trg/merges.txt')

  tokenizer_src.model = BPE('./src/vocab.json', './src/merges.txt')
  tokenizer_trg.model = BPE('./trg/vocab.json', './trg/merges.txt')


In [18]:
print(tokenizer_src.encode(train_data[4]['crh']).tokens)
print(tokenizer_trg.encode(train_data[4]['rus']).tokens)
print(tokenizer_src.encode(train_data[4]['crh']).ids)
print(tokenizer_src.decode(tokenizer_src.encode(train_data[4]['rus']).ids))
print(tokenizer_trg.decode(tokenizer_trg.encode(train_data[4]['crh']).ids))

['<sos>', 'ĠÑĤÐµ', 'ÑĦ', 'ÑģÐ¸', 'ÑĢ', 'Ð»ÐµÑĢ', 'Ð´Ðµ', 'ĠÑĤÐ°ÑĢ', 'ÑĤÑĭ', 'ÑĪ', 'Ð¼Ð°', '<eos>']
['<sos>', 'ĠÐ½Ðµ', 'ĠÑģÐ¿', 'Ð¾ÑĢ', 'ÑĮ', 'ĠÐ²', 'ĠÐºÐ¾Ð¼', 'Ð¼ÐµÐ½', 'ÑĤÐ°ÑĢ', 'Ð¸Ñı', 'Ñħ', '<eos>']
[1, 445, 431, 676, 264, 9808, 484, 12608, 638, 380, 486, 2]
 не спорь в комментариях
 тефсирлерде тартышма


In [19]:
# token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='ru_core_news_sm')
# token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

# # helper function to yield list of tokens
# def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
#     for data_sample in data_iter:
#         yield token_transform[language](data_sample[language])

# # Define special symbols and indices
UNK_IDX = 0
SOS_IDX = 1 
EOS_IDX = 2
PAD_IDX = tokenizer_src.padding['pad_id']
# # Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
#     # Training data Iterator
#     train_iter = iter(dataset) # todo: change
#     # Create torchtext's Vocab object
#     vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_data, ln),
#                                                     min_freq=1,
#                                                     specials=special_symbols,
#                                                     special_first=True)

# # Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# # If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
# for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
#   vocab_transform[ln].set_default_index(UNK_IDX)

In [20]:
# print(f"Vocab has {len(vocab_transform['en'])} english words and {len(vocab_transform['ru'])} russian words")

## Seq2Seq Network using Transformer

Transformer is a Seq2Seq model introduced in [“Attention is all you
need”](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)_
paper for solving machine translation tasks.
Below, we will create a Seq2Seq network that uses Transformer. The network
consists of three parts. First part is the embedding layer. This layer converts tensor of input indices
into corresponding tensor of input embeddings. These embedding are further augmented with positional
encodings to provide position information of input tokens to the model. The second part is the
actual [Transformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)_ model.
Finally, the output of the Transformer model is passed through linear layer
that gives unnormalized probabilities for each token in the target language.




In [39]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

During training, we need a subsequent word mask that will prevent the model from looking into
the future words when making predictions. We will also need masks to hide
source and target padding tokens. Below, let's define a function that will take care of both.




In [40]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Let's now define the parameters of our model and instantiate the same. Below, we also
define our loss function which is the cross-entropy loss and the optimizer used for training.




In [41]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 96
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

## Collation

As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings.
We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network
defined previously. Below we define our collate function that converts a batch of raw strings into batch tensors that
can be fed directly into our model.




In [42]:
from torch.nn.utils.rnn import pad_sequence


# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([SOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

def my_transform(tokenizer, sample):
  return tensor_transform(tokenizer.encode(sample).ids)

# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for sample in batch:
        src_sample, tgt_sample = sample[SRC_LANGUAGE], sample[TGT_LANGUAGE]
        src_batch.append(my_transform(token_transform[SRC_LANGUAGE], src_sample.rstrip("\n")))
        tgt_batch.append(my_transform(token_transform[TGT_LANGUAGE], tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

Let's define training and evaluation loop that will be called for each
epoch.




In [43]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    
    train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))


def evaluate(model):
    model.eval()
    losses = 0

    val_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [44]:
import gc
torch.cuda.empty_cache()
gc.collect()

104

Now we have all the ingredients to train our model. Let's do it!




In [45]:
from timeit import default_timer as timer
NUM_EPOCHS = 25

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))


Epoch: 1, Train loss: 4.769, Val loss: 2.978, Epoch time = 97.574s
Epoch: 2, Train loss: 2.712, Val loss: 2.293, Epoch time = 99.285s
Epoch: 3, Train loss: 2.226, Val loss: 1.954, Epoch time = 99.458s
Epoch: 4, Train loss: 1.948, Val loss: 1.745, Epoch time = 98.997s
Epoch: 5, Train loss: 1.757, Val loss: 1.607, Epoch time = 97.580s
Epoch: 6, Train loss: 1.615, Val loss: 1.503, Epoch time = 99.393s
Epoch: 7, Train loss: 1.500, Val loss: 1.420, Epoch time = 99.070s
Epoch: 8, Train loss: 1.403, Val loss: 1.356, Epoch time = 98.769s
Epoch: 9, Train loss: 1.321, Val loss: 1.306, Epoch time = 99.052s
Epoch: 10, Train loss: 1.248, Val loss: 1.266, Epoch time = 97.806s
Epoch: 11, Train loss: 1.183, Val loss: 1.229, Epoch time = 98.956s
Epoch: 12, Train loss: 1.125, Val loss: 1.197, Epoch time = 99.356s
Epoch: 13, Train loss: 1.071, Val loss: 1.173, Epoch time = 102.528s
Epoch: 14, Train loss: 1.022, Val loss: 1.147, Epoch time = 103.725s
Epoch: 15, Train loss: 0.978, Val loss: 1.133, Epoch ti

In [46]:
token_transform[SRC_LANGUAGE].encode("я ем тост с маслом").ids

[1, 539, 7349, 1093, 269, 283, 3352, 1812, 2]

In [47]:

# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask).to(DEVICE)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        # memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    # model.eval()
    src = torch.tensor(token_transform[SRC_LANGUAGE].encode(src_sentence).ids).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=SOS_IDX).flatten()
    return "".join(token_transform[TGT_LANGUAGE].decode(list(tgt_tokens))).replace("<sos>", "").replace("<eos>", "")

In [54]:
test_data[0]

{'rus': 'нам нужен номер на двоих', 'crh': 'визге эки кишилик номер керек'}

In [55]:
for example in list(test_data)[600:660]:
  print("<<<", example['crh'])
  print("===", example['rus'])
  print(">>>", translate(transformer, example['crh']))
  print("*" * 20)

<<< къырым джумхуриети девлет шур асынынъ къар ары айванлар алеми акъкъында къырым джумхуриети къанунынынъ маддесине денъишмелер кирсетильмеси акъкъында къырым джумхуриетининъ къануны акъкъында къырым джумхуриети анаясасынынъ маддесиндеки пунктына маддесиндеки къысмына къырым джумхуриети девлет шурасы регламентининъ мад десиндеки къысмына маддесине мувафыкъ къырым джумхуриетининъ девлет шурасы къарар бере
=== постановление государственного совета республики крым о законе республики крым о внесении изменений в статью закона республики крым о животном мире в соответствии с пунктом стат ьи частью статьи конституции республики крым частью статьи статьей регламента государственного совета республики крым государственный совет республики крым постановляет
>>>  постановление государственного совета республики крым о законе республики крым о внесении изменений в статью закона республики крым о животном мире в соответствии с пунктом статьи частью статьи конституции республики крым частью статьи

In [50]:
translate(transformer, 'внесенный главой республики крым аксёновым государственный совет республики крым постановляет')

' капитальный ремонт советский пер архюродная'

In [51]:
from nltk.translate import bleu_score
import tqdm

target = []
predictions = []

for example in tqdm.tqdm(list(test_data)[:1000]):
  pred = translate(transformer, example['crh']).strip()
  pred = token_transform[TGT_LANGUAGE].encode(pred).tokens
  trg = example['rus'].rstrip("\n").strip()
  trg = token_transform[TGT_LANGUAGE].encode(trg).tokens
  
  target.append(trg)
  predictions.append([pred])


100%|██████████| 1000/1000 [06:56<00:00,  2.40it/s]


In [52]:
bleu_score.corpus_bleu(predictions, target)

0.5183140124796218

In [53]:
torch.save(transformer.state_dict(), "crh_to_rus.weights")

## References

1. Attention is all you need paper.
   https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
2. The annotated transformer. https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding

