In [None]:
from fastai2.basics import *
from transformers import AutoTokenizer
from fastai_transformers_utils.all import *

from nmt_small.models.patch import *
from nmt_small.models.tran2tran import TranEncoder, TranDecoder, Tran2Tran, GeneratedTran2Tran
from nmt_small.data.tatoeba import get_datasets
from nmt_small.metrics import compute_bleu

In [None]:
# all_slow

In [None]:
# Pre-tokenized Tatoeba Mandarin->English pairs, plus the pretrained checkpoints
# whose tokenizers are used for the source (encoder) and target (decoder) sides.
tok_data_loc = './data/tatoeba/tok_cmn.csv'
enc_model_name, dec_model_name = 'hfl/chinese-bert-wwm-ext', 'distilgpt2'
# Fixed lengths that encoder inputs / decoder targets are padded to.
enc_seq_len, dec_seq_len = 50, 40

In [None]:
# Source-side tokenizer from the Chinese BERT checkpoint; target-side tokenizer
# is the project's GPT-2 decoder tokenizer built from the distilgpt2 vocab.
enc_tokenizer, dec_tokenizer = (
    AutoTokenizer.from_pretrained(enc_model_name),
    GPT2DecoderTokenizer.from_pretrained(dec_model_name),
)

## Datasets

In [None]:
# Same tokenization settings for both: a 20% subset for quick iteration and the
# full corpus. Displays the training-set sizes of each.
tok_args = (tok_data_loc, enc_tokenizer, dec_tokenizer, enc_seq_len, dec_seq_len)
small_dss = get_datasets(*tok_args, pct=0.2)
dss = get_datasets(*tok_args)
len(small_dss.train), len(dss.train)

(3392, 16964)

In [None]:
# Peek at one encoded training example next to its decoded form.
sample_item = dss.train[10]
sample_item, dss.decode(sample_item)

((TensorText([ 101,  800, 6651,  749,  511,  102,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0]),
  TensorText([50257,  1544,  4966,    13, 50256, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258]),
  TensorText([ 1544,  4966,    13, 50256, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 5

In [None]:
# Debug aid (disabled): uncomment to print shape, dtype, device and type of each
# of the three elements of the first training batch.
# dls = dss.dataloaders(bs=2)
# for x in dls.train:
#     print(x[0].shape, x[0].dtype, x[0].device, type(x[0]))
#     print(x[1].shape, x[1].dtype, x[1].device, type(x[1]))
#     print(x[2].shape, x[2].dtype, x[2].device, type(x[2]))
#     break

## Model

In [None]:
# Per-side sizes derived from the sequence lengths and tokenizers. The max position
# id leaves a margin of 10 beyond the padded length (presumably the positional
# embedding table size — confirm against TranEncoder/TranDecoder).
enc_max_pos_id, dec_max_pos_id = enc_seq_len + 10, dec_seq_len + 10
enc_vocab_size, dec_vocab_size = len(enc_tokenizer), len(dec_tokenizer)
enc_pad_id, dec_pad_id = enc_tokenizer.pad_token_id, dec_tokenizer.pad_token_id

# Transformer hyperparameters.
# NOTE(review): only `embeded_size` is passed to TranEncoder/TranDecoder in the
# model cell below; the remaining values are not forwarded there — confirm they
# match the constructors' defaults or wire them through.
embeded_size = 256
num_head = 2
num_encoder_layers = 2
num_decoder_layers = 2
dim_feedforward = 512
drop_p = 0.1

In [None]:
# Encoder and decoder share the embedding width; the encoder's pad id is also
# handed to Tran2Tran (presumably to build the source padding mask — confirm).
# NOTE(review): num_head, num_*_layers, dim_feedforward and drop_p are not
# passed to these constructors.
encoder = TranEncoder(enc_vocab_size, embeded_size,
                      enc_max_pos_id, enc_pad_id)
decoder = TranDecoder(dec_vocab_size, embeded_size,
                      dec_max_pos_id, dec_pad_id)
tran2tran = Tran2Tran(encoder, decoder, enc_pad_id)

## Learner and training

In [None]:
def masked_accuracy(inp, targ, pad_id=dec_pad_id):
    "Token-level accuracy computed only where the target is not `pad_id`."
    # fastai's `accuracy` averages over every position, including the padding
    # that fills most decoder targets. The loss ignores those positions, so
    # counting them distorts the metric. Mask them out to match the loss.
    pred = inp.argmax(dim=-1)
    keep = targ != pad_id
    return (pred[keep] == targ[keep]).float().mean()

dls = small_dss.dataloaders(bs=64)
# dls = dss.dataloaders(bs=64)  # full dataset
learn = Learner(dls,
                tran2tran,
                # Padding positions are excluded from both the loss and the accuracy.
                loss_func=CrossEntropyLossFlat(ignore_index=dec_pad_id),
                opt_func=Adam,
                metrics=[masked_accuracy, Perplexity()],
               ).to_fp16()

In [None]:
# 4 epochs of the one-cycle schedule with a peak learning rate of 5e-4.
learn.fit_one_cycle(n_epoch=4, lr_max=5e-4)

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,4.960009,4.365407,0.077918,78.681381,01:23
1,3.909053,3.646549,0.091069,38.342098,01:22
2,3.410864,3.293607,0.099853,26.939867,01:22
3,3.098508,3.188998,0.1029,24.264109,01:19


## BLEU

In [None]:
# Decoding settings: beam search with 3 beams and max_length=20. Uncomment
# do_sample to sample instead (temperature presumably only matters then — confirm).
generate_args = GenerateArgs(
    max_length=20,
    num_beams=3,
    # do_sample=True,
    temperature=1.0,
    repetition_penalty=1,
    length_penalty=1.0,
)
# Wrap the trained model with both tokenizers so it can translate raw strings.
generated_tran2tran = GeneratedTran2Tran(tran2tran, enc_tokenizer, dec_tokenizer)
# dls = dss.dataloaders(bs=64)  # use the full dataset's loaders for evaluation

In [None]:
# BLEU of the generated translations on the validation loader of the current
# `dls` (the 20% subset built in the Learner cell unless reassigned).
valid_dl = dls.valid
compute_bleu(generated_tran2tran, generate_args, dec_tokenizer, valid_dl)

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


0.01570638370580385

## Generate

In [None]:
# Short Mandarin phrases paired by position with reference English translations.
src_strs = ['你确定？', '找到汤姆。', '帮帮我们吧！', '坚持。']
tgt_strs = ["Really?", "Get Tom.", "Help us.", "Hold on."]

In [None]:
# Generate on whatever device the model's parameters live on instead of a
# hard-coded 'cuda:0', so the cell also runs on CPU-only machines.
model_device = str(next(tran2tran.parameters()).device)
result = generated_tran2tran.generate_from_strs(src_strs, generate_args, device=model_device)
# Show each source sentence with its reference and the model's translation.
list(zip(src_strs, tgt_strs, result))

['Are you busy?',
 'Tom is going to see.',
 "Let's play the door.",
 "There's a car."]

In [None]:
# Longer sentences (traditional characters) with reference English translations,
# paired by position.
src_strs = [
    '我很高興再次見到你。',
    '我有點累。',
    '我不記得寄過信了。',
    '它是我兄弟的。',
]
tgt_strs = [
    "I'm very glad to see you again.",
    "I'm a little tired.",
    "I don't remember mailing the letter.",
    "It's my brother's.",
]

In [None]:
# Generate on whatever device the model's parameters live on instead of a
# hard-coded 'cuda:0', so the cell also runs on CPU-only machines.
model_device = str(next(tran2tran.parameters()).device)
result = generated_tran2tran.generate_from_strs(src_strs, generate_args, device=model_device)
# Show each source sentence with its reference and the model's translation.
list(zip(src_strs, tgt_strs, result))

["I'm glad to see you.",
 "I'm afraid of it.",
 "I don't think it's too much.",
 "It's my sister."]