# Assignment 7

Train a Transformer model for Machine Translation from Russian to English.  
Dataset: http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz   
Make all source and target text to lower case.  
Use following tokenization for english:  
```
import sentencepiece as spm

...
spm.SentencePieceTrainer.Train('--input=data/text.en --model_prefix=bpe_en --vocab_size=32000 --character_coverage=0.98 --model_type=bpe')

tok_en = spm.SentencePieceProcessor()
tok_en.load('bpe_en.model')

TGT = data.Field(
    fix_length=50,
    init_token='<s>',
    eos_token='</s>',
    lower=True,
    tokenize = lambda x: tok_en.encode_as_pieces(x),
    batch_first=True,
)

...
TGT.build_vocab(..., min_freq=5)
...

```
Score: corpus-bleu `nltk.translate.bleu_score.corpus_bleu`  
Use last 1000 sentences for model evalutation (test dataset).  
Use your target sequence tokenization for BLEU score.  
Use max_len=50 for sequence prediction.  


Hint: You may consider much smaller model, than shown in the example.  

Baselines:  
[4 point] BLEU = 0.05  
[6 point] BLEU = 0.10  
[9 point] BLEU = 0.15  

[1 point] Share weights between target embeddings and output dense layer. Notice, they have the same shape.


Readings:
1. BLUE score how to https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
1. Transformer code and comments http://nlp.seas.harvard.edu/2018/04/03/attention.html

In [38]:
!pip install sentencepiece



In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from torchtext import datasets, data
from tqdm.notebook import tqdm
import sentencepiece as spm
from transformer import *
import re

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

In [0]:
def preproc(txt):
    txt = txt.lower()
    txt = re.sub('\\xa0', ' ', txt)
    txt = re.sub(r'\\u0027', "'", txt)
    txt = re.sub(r'&[a-z]{0,7};', ' ', txt)
    return txt

In [47]:
# tokenize english 
with open('data/news-commentary-v13.ru-en.en') as f:
    with open('data/text.en', 'w') as out:
            out.write(preproc(f.read()))
        
spm.SentencePieceTrainer.Train('--input=data/text.en --model_prefix=bpe_en --vocab_size=32000 --character_coverage=0.98 --model_type=bpe')

True

In [48]:
# tokenize russian 
with open('data/news-commentary-v13.ru-en.ru') as f:
    with open('data/text.ru', 'w') as out:
            out.write(preproc(f.read()))
        
spm.SentencePieceTrainer.Train('--input=data/text.ru --model_prefix=bpe_ru --vocab_size=32000 --character_coverage=0.98 --model_type=bpe')

True

In [0]:
tok_ru = spm.SentencePieceProcessor()
tok_ru.load('bpe_ru.model')

tok_en = spm.SentencePieceProcessor()
tok_en.load('bpe_en.model')

SRC = data.Field(
    fix_length=50,
    init_token='<s>',
    eos_token='</s>',
    lower=True,
    tokenize = lambda x: tok_ru.encode_as_pieces(x),
    batch_first=True,
)

TGT = data.Field(
    fix_length=50,
    init_token='<s>',
    eos_token='</s>',
    lower=True,
    tokenize = lambda x: tok_en.encode_as_pieces(x),
    batch_first=True,
)

fields = (('src', SRC), ('tgt', TGT))

In [50]:
with open('data/text.ru') as f:
    src_snt = list(map(str.strip, f.readlines()))
    
with open('data/text.en') as f:
    tgt_snt = list(map(str.strip, f.readlines()))
    
examples = [data.Example.fromlist(x, fields) for x in tqdm(zip(src_snt, tgt_snt))]
test = data.Dataset(examples[-1000:], fields)
train, valid = data.Dataset(examples[:-1000], fields).split(0.9)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [51]:
print('src: ' + " ".join(train.examples[100].src))
print('tgt: ' + " ".join(train.examples[100].tgt))

src: ▁но , ▁если ▁бы ▁правительство ▁или ▁отдельные ▁люди ▁использовали ▁ э то ▁как ▁оправдание ▁уменьшения ▁помощи ▁самым ▁бедным ▁людям ▁в ▁мире , ▁они ▁только ▁при умно жили ▁бы ▁серьезность ▁проблемы ▁для ▁мира ▁в ▁целом .
tgt: ▁but ▁if ▁governments ▁or ▁individuals ▁use ▁this ▁as ▁an ▁e x cuse ▁to ▁reduce ▁assistance ▁to ▁the ▁world ’ s ▁poorest ▁people , ▁they ▁would ▁only ▁multiply ▁the ▁seriousness ▁of ▁the ▁problem ▁for ▁the ▁world ▁as ▁a ▁whole .


In [52]:
len(train), len(valid), len(test)

(210743, 23416, 1000)

In [0]:
TGT.build_vocab(train, min_freq=5)
SRC.build_vocab(train, min_freq=5)

In [0]:
from transformer import make_model, Batch

    
class BucketIteratorWrapper(DataLoader):
    __initialized = False

    def __init__(self, iterator: data.Iterator):
#         super(BucketIteratorWrapper,self).__init__()
        self.batch_size = iterator.batch_size
        self.num_workers = 1
        self.collate_fn = None
        self.pin_memory = False
        self.drop_last = False
        self.timeout = 0
        self.worker_init_fn = None
        self.sampler = iterator
        self.batch_sampler = iterator
        self.__initialized = True

    def __iter__(self):
        return map(
            lambda batch: Batch(batch.src, batch.tgt, pad=TGT.vocab.stoi['<pad>']),
            self.batch_sampler.__iter__()
        )

    def __len__(self):
        return len(self.batch_sampler)
    
class MyCriterion(nn.Module):
    def __init__(self, pad_idx):
        super(MyCriterion, self).__init__()
        self.pad_idx = pad_idx
        self.criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=pad_idx)
        
    def forward(self, x, target):
        x = x.contiguous().permute(0,2,1)
        ntokens = (target != self.pad_idx).data.sum()
        
        return self.criterion(x, target) / ntokens

In [0]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

In [0]:
torch.cuda.empty_cache()

batch_size = 64
num_epochs = 3

train_iter, valid_iter, test_iter = data.BucketIterator.splits((train, valid, test), 
                                              batch_sizes=(batch_size, batch_size, batch_size), 
                                  sort_key=lambda x: len(x.src),
                                  shuffle=True,
                                  device=DEVICE,
                                  sort_within_batch=False)
                                  
train_iter = BucketIteratorWrapper(train_iter)
valid_iter = BucketIteratorWrapper(valid_iter)
test_iter = BucketIteratorWrapper(test_iter)

#def make_model(src_vocab, tgt_vocab, N=6, 
#               d_model=512, d_ff=2048, h=8, dropout=0.1):


model = make_model(len(SRC.vocab), len(TGT.vocab), N = 4, d_model=256, d_ff=512)
model = model.to(DEVICE)
criterion = MyCriterion(pad_idx=TGT.vocab.stoi["<pad>"])
optimizer = get_std_opt(model)
#scheduler = <TODO>

# share weights
model.generator.weight = model.tgt_embed[0].lut.weight

In [81]:
def train_epoch(data_iter, model, criterion):
    total_loss = 0
    data_iter = tqdm(data_iter)
    counter = 0
    for batch in data_iter:
        #model.zero_grad()
        out = model.forward(batch)
        loss = criterion(out, batch.tgt_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        optimizer.step()

        total_loss += loss
        data_iter.set_postfix(loss = loss)
        counter +=1
        
    total_loss /= counter
    return total_loss


def valid_epoch(data_iter, model, criterion):
    total_loss = 0
    data_iter = tqdm(data_iter)
    counter = 0
    for batch in data_iter:
        
        out = model.forward(batch)
        loss = criterion(out, batch.tgt_y)
        
        total_loss += loss
        data_iter.set_postfix(loss = loss)
        counter +=1
        
    total_loss /= counter
    return total_loss


for epoch in range(num_epochs):
    model.train()
    loss = train_epoch(train_iter, model, criterion)
    print('train', loss)
    
    model.eval()
    with torch.no_grad():
        loss = valid_epoch(valid_iter, model, criterion)
        #scheduler.step(loss)
        print('valid', loss)

HBox(children=(IntProgress(value=0, max=3293), HTML(value='')))

KeyboardInterrupt: ignored

In [0]:
start_symbol = TGT.vocab.stoi["<s>"]
end_symbol = TGT.vocab.stoi["</s>"]

In [0]:
def beam_search(model, src, src_mask, max_len=10, k=5):
    memory = model.encode(src, src_mask)
    ys = torch.ones(src.size(0), 1).fill_(start_symbol).long().to(src.device)
    beam = [(0, ys)]

    for l in range(max_len):
        candidates = []
        for snt_proba, snt in beam:
            proba = model.decode(snt.long(), subsequent_mask(snt.size(1)).type_as(src.data),
                           memory,
                           src_mask)
            _, pos_k = torch.topk(proba[:, -1, :], k=k, dim=-1)

            candidates += [(sum([proba[0, i, snt[0, i]].item() for i in range(l)]) + proba[0, -1, next_word].item(), torch.cat([snt, next_word.resize(snt.size(0), 1)], dim=1))
                for next_word in pos_k.view(k, -1)]

        beam = sorted(candidates, key=lambda x: x[0])[-k:]

    return beam

In [132]:
model.eval()
with torch.no_grad():
    for idx, batch in enumerate(valid_iter):
        src = batch.src[:1]
        src_key_padding_mask = src != SRC.vocab.stoi["<pad>"]
        beam = beam_search(model, src, src_key_padding_mask, k=5)
        
        seq = []
        for i in range(1, src.size(1)):
            sym = SRC.vocab.itos[src[0, i]]
            if sym == "</s>": break
            seq.append(sym)
        seq = tok_ru.decode_pieces(seq)
        print("\nSource:", seq)
        
        print("Translation:")
        for pred_proba, pred in beam:                
            seq = []
            for i in range(1, pred.size(1)):
                sym = TGT.vocab.itos[pred[0, i]]
                if sym == "</s>": break
                seq.append(sym)
            seq = tok_en.decode_pieces(seq)
            print(f"pred {pred_proba:.2f}:", seq)
                
        seq = []
        for i in range(1, batch.tgt.size(1)):
            sym = TGT.vocab.itos[batch.tgt[0, i]]
            if sym == "</s>": break 
            seq.append(sym)
        seq = tok_en.decode_pieces(seq)
        print("Target:", seq)
        if idx == 4:
            break




Source: ответ ирану
Translation:
pred 89.57: 
pred 89.71: 
pred 89.84: 
pred 89.90: 
pred 94.34: 
Target: answering iran

Source: полиции нет.
Translation:
pred 65.75: but but but it.
pred 65.78: but but but it.
pred 66.39: but but but it.
pred 69.88: but but but it.
pred 71.17: but but but it.
Target: police are unavailable.

Source: дни двойной рецессии
Translation:
pred 88.61: 
pred 88.76: 
pred 88.89: 
pred 88.95: 
pred 93.31: 
Target: double-dip days

Source: скромная миссия европы
Translation:
pred 90.46: 
pred 90.60: 
pred 90.72: 
pred 90.78: 
pred 95.34: 
Target: europe's modest mission

Source: будущее азии после цунами
Translation:
pred 66.63: but in 
pred 67.00: but in 
pred 67.12: but in 
pred 67.40: but in 
pred 69.34: but in 
Target: asia’s post-tsunami future


In [0]:
from nltk.translate.bleu_score import corpus_bleu
from nltk import translate

In [108]:
hypotheses = []
references = []


model.eval()
with torch.no_grad():
  for batch in tqdm(test_iter):
    for sent in range(len(batch.src)):
      src = batch.src[sent:sent+1]
      src_key_padding_mask = src != SRC.vocab.stoi["<pad>"]
      beam = beam_search(model, src, src_key_padding_mask, k=5)
      for pred_proba, pred in beam[:1]:                
        seqns = []
        for i in range(1, pred.size(1)):
          sym = TGT.vocab.itos[pred[0, i]]
          if sym == "</s>": 
            break
          seqns.append(sym)
        hypotheses.append(seqns)

        my_tgt = batch.tgt[sent:sent+1].tolist()[0]
        refrnc = []
        for i in range(1, batch.tgt.size(1)):
          sym =  TGT.vocab.itos[my_tgt[i]]
          if sym == "</s>" or sym  == "<pad>": 
            break
          refrnc.append(sym)

        references.append(refrnc)

HBox(children=(IntProgress(value=0, max=16), HTML(value='')))






In [123]:
len(hypotheses)

1000

In [124]:
corpus_bleu(references, hypotheses, 
            smoothing_function=translate.bleu_score.SmoothingFunction().method3,
            auto_reweigh=True
           )

0.01347807093236565
