In [None]:
from fastai2.basics import *
from transformers import AutoTokenizer
from fastai_transformers_utils.all import *

from nmt_small.models.patch import *
from nmt_small.models.gru2gru import GRUEncoder, GRUDecoder, GRU2GRU, GeneratedGRU2GRU
from nmt_small.data.tatoeba import get_datasets
from nmt_small.metrics import compute_bleu

In [None]:
# all_slow

In [None]:
# Notebook configuration: tokenised-corpus location, pretrained model names
# used to load the two tokenizers, and the sequence lengths handed to
# `get_datasets` for the encoder and decoder sides.
tok_data_loc = './data/tatoeba/tok_cmn.csv'
enc_model_name, dec_model_name = 'hfl/chinese-bert-wwm-ext', 'distilgpt2'
enc_seq_len, dec_seq_len = 50, 40

In [None]:
# Source-side tokenizer, resolved by model name through Hugging Face's AutoTokenizer.
enc_tokenizer = AutoTokenizer.from_pretrained(enc_model_name)
# Target-side tokenizer. `GPT2DecoderTokenizer` arrives via one of the star
# imports above (presumably fastai_transformers_utils — confirm).
# NOTE(review): the sample item shown further down contains ids 50257/50258,
# which suggests this class adds extra special tokens (e.g. BOS/PAD) — confirm.
dec_tokenizer = GPT2DecoderTokenizer.from_pretrained(dec_model_name)

## Datasets

In [None]:
# Both dataset objects are built from the same inputs; only `pct` differs.
_dataset_args = (tok_data_loc, enc_tokenizer, dec_tokenizer, enc_seq_len, dec_seq_len)

small_dss = get_datasets(*_dataset_args, pct=0.2)  # reduced variant for quick experiments
dss = get_datasets(*_dataset_args)                 # default `pct`

# Display the training-set sizes of the small and the full variant.
tuple(len(d.train) for d in (small_dss, dss))

(3392, 16964)

In [None]:
# Sanity check: show one raw training item (index 10, chosen arbitrarily)
# next to the result of passing that same item through `dss.decode`.
dss.train[10], dss.decode(dss.train[10])

((TensorText([ 101,  800, 6651,  749,  511,  102,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0]),
  TensorText([50257,  1544,  4966,    13, 50256, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258]),
  TensorText([ 1544,  4966,    13, 50256, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258,
          50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 50258, 5

In [None]:
# Debug aid (kept disabled): print the shape, dtype, device and Python type of
# each of the three tensors in the first training batch.
# dls = dss.dataloaders(bs=2)
# for x in dls.train:
#     print(x[0].shape, x[0].dtype, x[0].device, type(x[0]))
#     print(x[1].shape, x[1].dtype, x[1].device, type(x[1]))
#     print(x[2].shape, x[2].dtype, x[2].device, type(x[2]))
#     break

## Model

In [None]:
# Vocabulary sizes and padding ids are read from the two tokenizers so the
# embedding tables and the padding handling stay in sync with tokenisation.
enc_vocab_size, dec_vocab_size = len(enc_tokenizer), len(dec_tokenizer)
enc_pad_id, dec_pad_id = enc_tokenizer.pad_token_id, dec_tokenizer.pad_token_id

# Model hyperparameters shared by the encoder and decoder constructors.
# (`embeded_size` keeps its original spelling because later cells use that name.)
embeded_size = 256
num_encoder_layers = num_decoder_layers = 2
drop_p = 0.1

In [None]:
# Build the encoder and decoder from the hyperparameters defined above, then
# combine them into the sequence-to-sequence model used for training.
# NOTE(review): all arguments are positional; their meaning is inferred from
# the variable names only — confirm the order against the GRUEncoder /
# GRUDecoder / GRU2GRU signatures in nmt_small.models.gru2gru.
encoder = GRUEncoder(enc_vocab_size, embeded_size, enc_pad_id, num_encoder_layers, drop_p)
decoder = GRUDecoder(dec_vocab_size, embeded_size, dec_pad_id, num_decoder_layers, drop_p)
gru2gru = GRU2GRU(encoder, decoder, num_encoder_layers, num_decoder_layers)

## Learner and Training

In [None]:
def non_pad_accuracy(inp, targ, axis=-1):
    """Token-level accuracy computed only where the target is not padding.

    The loss below skips `dec_pad_id` targets via `ignore_index`, so the model
    is never trained to emit padding. fastai's plain `accuracy` still scores
    those positions, which caps the reported value at the fraction of
    non-padding tokens. This metric applies the same mask as the loss.

    Args:
        inp: model output; the class dimension is `axis`.
        targ: target token ids.
        axis: dimension of `inp` to take the argmax over.

    Returns:
        Scalar tensor with the mean accuracy over non-padding positions
        (NaN if a batch contains no such position).
    """
    pred, targ = flatten_check(inp.argmax(dim=axis), targ)
    keep = targ != dec_pad_id
    return (pred[keep] == targ[keep]).float().mean()


# Train on the reduced dataset by default; switch to the commented line for
# the full dataset.
dls = small_dss.dataloaders(bs=128)
# dls = dss.dataloaders(bs=128)
learn = Learner(dls,
                gru2gru,
                loss_func=CrossEntropyLossFlat(ignore_index=dec_pad_id),
                opt_func=Adam,
                metrics=[non_pad_accuracy, Perplexity()],
               ).to_fp16()

In [None]:
# One-cycle schedule: 4 epochs with a peak learning rate of 1e-2.
learn.fit_one_cycle(n_epoch=4, lr_max=1e-2)

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,5.083679,4.304209,0.078808,74.010666,00:48
1,3.637716,3.218691,0.098456,24.995388,00:48
2,2.612543,2.774764,0.111159,16.03484,00:48
3,1.717289,2.69778,0.113753,14.846741,00:48


## BLEU

In [None]:
# Decoding settings shared by the BLEU evaluation and the example generations
# below: 3 beams, outputs limited to 20 tokens, neutral penalties/temperature.
generate_args = GenerateArgs(
    num_beams=3,
    max_length=20,
    length_penalty=1.0,
    repetition_penalty=1,
    temperature=1.0,
    # do_sample=True,  # sampling intentionally left disabled
)

# Wrap the trained model together with both tokenizers for generation.
generated_gru2gru = GeneratedGRU2GRU(gru2gru, enc_tokenizer, dec_tokenizer)

In [None]:
# Score the generation wrapper on the validation DataLoader with the decoding
# settings defined above; the returned value is displayed as the cell output.
compute_bleu(generated_gru2gru, generate_args, dec_tokenizer, dls.valid)

0.12515677326505098

## Generate

In [None]:
# Short example sentences, kept as (source, reference translation) pairs so
# each input sits next to its expected output.
_example_pairs = [
    ('你确定？', "Really?"),
    ('找到汤姆。', "Get Tom."),
    ('帮帮我们吧！', "Help us."),
    ('坚持。', "Hold on."),
]
src_strs = [source for source, _ in _example_pairs]
tgt_strs = [reference for _, reference in _example_pairs]

In [None]:
# Generate on the first GPU when CUDA is available (the original behaviour);
# otherwise fall back to the CPU so the cell still runs on GPU-less machines.
gen_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
result = generated_gru2gru.generate_from_strs(src_strs, generate_args, device=gen_device)
result

['Are you sure you want to do that?',
 ' Tom is Tom to find Tom.',
 "Let's go to us.",
 "It's a bit of children."]

In [None]:
# Longer example sentences, kept as (source, reference translation) pairs so
# each input sits next to its expected output.
_example_pairs = [
    ('我很高興再次見到你。', "I'm very glad to see you again."),
    ('我有點累。', "I'm a little tired."),
    ('我不記得寄過信了。', "I don't remember mailing the letter."),
    ('它是我兄弟的。', "It's my brother's."),
]
src_strs = [source for source, _ in _example_pairs]
tgt_strs = [reference for _, reference in _example_pairs]

In [None]:
# Generate on the first GPU when CUDA is available (the original behaviour);
# otherwise fall back to the CPU so the cell still runs on GPU-less machines.
gen_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
result = generated_gru2gru.generate_from_strs(src_strs, generate_args, device=gen_device)
result

["I'm glad to see you again.",
 'I am tired of being tired.',
 "I don't remember the letter better than I remember.",
 "It's my brother is my brother."]