In [1]:
from fastai2.basics import *
from transformers import AutoTokenizer
from fastai_transformers_utils.all import *

from nmt_small.models.core import *
from nmt_small.models.tran2tran import TranEncoder, TranDecoder, Tran2Tran, GeneratedTran2Tran
from nmt_small.data.tatoeba import get_datasets
from nmt_small.metrics import compute_bleu

In [2]:
# all_slow

In [3]:
# --- Experiment configuration ---
# CSV handed to `get_datasets` when the datasets are built.
tok_data_loc = './data/tatoeba/tok_cmn.csv'

# Checkpoint names used to load the encoder-side and decoder-side tokenizers.
enc_model_name, dec_model_name = 'hfl/chinese-bert-wwm-ext', 'distilgpt2'

# Sequence-length settings for the source (encoder) and target (decoder) sides;
# they are passed to `get_datasets` and also size the position-id ranges.
enc_seq_len, dec_seq_len = 50, 40

In [4]:
# Source-side tokenizer, loaded from the `enc_model_name` checkpoint.
enc_tokenizer = AutoTokenizer.from_pretrained(enc_model_name)
# Target-side tokenizer, loaded from the `dec_model_name` checkpoint.
# NOTE(review): `GPT2DecoderTokenizer` is not imported by name in this notebook,
# so it presumably arrives through one of the star imports — confirm which.
dec_tokenizer = GPT2DecoderTokenizer.from_pretrained(dec_model_name)

## DataLoaders

In [5]:
# Build the datasets from the CSV at `tok_data_loc`, handing over both
# tokenizers and the encoder/decoder sequence-length settings.
dss = get_datasets(tok_data_loc, enc_tokenizer, dec_tokenizer, enc_seq_len, dec_seq_len)
# Spot check (disabled): show one training item next to its decoded form.
# dss.train[10], dss.decode(dss.train[10])

In [6]:
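# Debug aid (disabled): print shape, dtype, device and type of each of the
# three elements of the first training batch.
# dls = dss.dataloaders(bs=2)
# for x in dls.train:
#     print(x[0].shape, x[0].dtype, x[0].device, type(x[0]))
#     print(x[1].shape, x[1].dtype, x[1].device, type(x[1]))
#     print(x[2].shape, x[2].dtype, x[2].device, type(x[2]))
#     break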

## Model

In [7]:
# Encoder-side sizes, derived from the source tokenizer and `enc_seq_len`.
enc_vocab_size = len(enc_tokenizer)
enc_pad_id = enc_tokenizer.pad_token_id
enc_max_pos_id = enc_seq_len + 10  # sequence length plus 10 positions of headroom

# Decoder-side sizes, derived the same way from the target tokenizer.
dec_vocab_size = len(dec_tokenizer)
dec_pad_id = dec_tokenizer.pad_token_id
dec_max_pos_id = dec_seq_len + 10  # sequence length plus 10 positions of headroom

# Transformer hyperparameters.
# NOTE(review): in the visible notebook only `embeded_size` is read later (by the
# model-construction cell); `num_head`, `num_encoder_layers`, `num_decoder_layers`,
# `dim_feedforward` and `drop_p` are never passed anywhere — confirm whether the
# model classes should receive them.
# The spelling `embeded_size` is kept because later cells refer to this exact name.
embeded_size = 256
dim_feedforward = 512
num_head = 2
num_encoder_layers = num_decoder_layers = 2
drop_p = 0.1

In [33]:
# Assemble the encoder-decoder model. Each side receives its vocabulary size,
# the shared embedding size, its maximum position id and its pad-token id.
# NOTE(review): `num_head`, `num_encoder_layers`, `num_decoder_layers`,
# `dim_feedforward` and `drop_p` from the hyperparameter cell are not forwarded
# here, so the classes presumably fall back to their own defaults — confirm.
encoder = TranEncoder(enc_vocab_size, embeded_size, enc_max_pos_id, enc_pad_id)
decoder = TranDecoder(dec_vocab_size, embeded_size, dec_max_pos_id, dec_pad_id)
# The wrapper is additionally given the encoder pad id.
tran2tran = Tran2Tran(encoder, decoder, enc_pad_id)

## Learner and Train

In [34]:
# DataLoaders built from the datasets with a batch size of 64.
dls = dss.dataloaders(bs=64)

# The loss skips target positions equal to the decoder pad id.
# NOTE(review): the `accuracy` metric is not given an ignore index, so padded
# positions presumably count against it, which would explain the low reported
# accuracy — confirm, or switch to a pad-aware metric.
learn = Learner(
    dls,
    tran2tran,
    opt_func=Adam,
    loss_func=CrossEntropyLossFlat(ignore_index=dec_pad_id),
    metrics=[accuracy, Perplexity()],
)
# Switch the learner to mixed-precision training.
learn = learn.to_fp16()

In [35]:
# One-cycle schedule: 4 epochs with a peak learning rate of 5e-4.
learn.fit_one_cycle(n_epoch=4, lr_max=5e-4)

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,4.953203,4.350276,0.077128,77.499809,01:15
1,3.930136,3.668665,0.091063,39.199528,01:15
2,3.396621,3.298762,0.099817,27.079096,01:15
3,3.132953,3.198549,0.102947,24.496948,01:15


## BLEU

In [40]:
# Decoding settings used for both the BLEU evaluation and the sample generations:
# beam search with 3 beams, outputs capped at 20 tokens, neutral penalties and
# temperature. Sampling stays disabled.
generate_args = GenerateArgs(
    num_beams=3,
    max_length=20,
    length_penalty=1.0,
    repetition_penalty=1,
    temperature=1.0,
    # do_sample=True,
)

# Generation wrapper around the trained model plus both tokenizers.
generated_tran2tran = GeneratedTran2Tran(tran2tran, enc_tokenizer, dec_tokenizer)

# NOTE(review): `dls` is rebuilt with the same batch size used in the training
# cell; this looks redundant unless a fresh DataLoaders object is required.
dls = dss.dataloaders(bs=64)

In [41]:
# Score the generation wrapper on the validation DataLoader using `generate_args`.
# NOTE(review): the scale of the returned score (0-1 vs 0-100) is not visible
# from this notebook — confirm in `compute_bleu` before comparing with other results.
compute_bleu(generated_tran2tran, generate_args, dec_tokenizer, dls.valid)

0.0651092288036081

## Generate

In [36]:
# Short sample sentences written as (source, reference) pairs and then unzipped
# into two parallel lists. `tgt_strs` holds the reference translations for
# eyeballing the output; no visible cell reads it programmatically.
src_strs, tgt_strs = map(list, zip(
    ('你确定？', "Really?"),
    ('找到汤姆。', "Get Tom."),
    ('帮帮我们吧！', "Help us."),
    ('坚持。', "Hold on."),
))

In [37]:
# Generate outputs for the sample sources with the shared `generate_args`.
# NOTE(review): the device is hard-coded to 'cuda:0'; presumably this cell fails
# on a machine without a CUDA GPU — confirm whether a fallback is wanted.
result = generated_tran2tran.generate_from_strs(src_strs, generate_args, device='cuda:0')
# Bare expression so the notebook displays the generated strings.
result

['Are you going to go?',
 'Tom is going to go.',
 "Let's come to see us.",
 "It's too."]

In [38]:
# Longer sample sentences (written in traditional characters) as
# (source, reference) pairs, unzipped into two parallel lists. This rebinds
# `src_strs` and `tgt_strs` from the previous sample set.
src_strs, tgt_strs = map(list, zip(
    ('我很高興再次見到你。', "I'm very glad to see you again."),
    ('我有點累。', "I'm a little tired."),
    ('我不記得寄過信了。', "I don't remember mailing the letter."),
    ('它是我兄弟的。', "It's my brother's."),
))

In [39]:
# Generate outputs for the second sample set; this overwrites the previous `result`.
# NOTE(review): the device is hard-coded to 'cuda:0'; presumably this cell fails
# on a machine without a CUDA GPU — confirm whether a fallback is wanted.
result = generated_tran2tran.generate_from_strs(src_strs, generate_args, device='cuda:0')
# Bare expression so the notebook displays the generated strings.
result

["I'm sorry to see you.",
 'I have a lot of money.',
 "I don't think I'm very well.",
 "It's my brother."]