In [None]:
import optuna
from fastai2.basics import *
from transformers import AutoTokenizer

from fastai2_utils.pytorch.model import *
from fastai_transformers_utils.all import *

from nmt_small.models.patch import *
from nmt_small.models.bert2gpt2 import BertEncoder, GPT2Decoder, Bert2GPT2, GeneratedBert2GPT2
from nmt_small.data.tatoeba import get_datasets
from nmt_small.metrics import compute_bleu

In [None]:
# all_slow

In [None]:
# Path handed to `get_datasets` as the data source.
tok_data_loc = './data/tatoeba/tok_cmn.csv'

# Pretrained checkpoint names for the encoder and the decoder, in that order.
enc_model_name, dec_model_name = 'hfl/chinese-bert-wwm-ext', 'distilgpt2'

# Sequence lengths for the encoder and decoder inputs, in that order.
enc_seq_len, dec_seq_len = 50, 40

In [None]:
# Encoder-side tokenizer, loaded from the checkpoint named by `enc_model_name`.
enc_tokenizer = AutoTokenizer.from_pretrained(enc_model_name)
# Decoder-side tokenizer. `GPT2DecoderTokenizer` is not imported by name in this
# notebook, so it presumably comes from one of the star imports — TODO confirm which.
dec_tokenizer = GPT2DecoderTokenizer.from_pretrained(dec_model_name)

## Datasets

In [None]:
# Positional arguments shared by both dataset builds.
_dataset_args = (tok_data_loc, enc_tokenizer, dec_tokenizer, enc_seq_len, dec_seq_len)

# `small_dss` is built with pct=0.2 (presumably a 20% subset for quick runs — confirm
# against `get_datasets`); `dss` relies on the function's default.
small_dss = get_datasets(*_dataset_args, pct=0.2)
dss = get_datasets(*_dataset_args)
# len(small_dss.train), len(dss.train)

In [None]:
# dss.train[10], dss.decode(dss.train[10])

In [None]:
# dls = dss.dataloaders(bs=2)
# for x in dls.train:
#     print(x[0].shape, x[0].dtype, x[0].device, type(x[0]))
#     print(x[1].shape, x[1].dtype, x[1].device, type(x[1]))
#     print(x[2].shape, x[2].dtype, x[2].device, type(x[2]))
#     break

## Model

In [None]:
def get_model(device):
    """Build a fresh `Bert2GPT2` model and place it on `device`.

    The encoder is a `BertEncoder` created from `enc_model_name`. The decoder is a
    `GPT2Decoder` created from `dec_model_name` with `num_heads=2`, `num_layers=2`
    and `drop_p=0`, and a vocabulary size equal to `len(dec_tokenizer)`. The decoder
    tokenizer's pad-token id goes to the decoder; the encoder tokenizer's pad-token
    id goes to `Bert2GPT2`.

    Args:
        device: target passed to the model's `.to(...)` call, e.g. 'cpu'.

    Returns:
        The assembled `Bert2GPT2` instance after the `.to(device)` call.
    """
    decoder_kwargs = dict(vocab_size=len(dec_tokenizer), num_heads=2, drop_p=0, num_layers=2)
    seq2seq = Bert2GPT2(
        BertEncoder(enc_model_name),
        GPT2Decoder(dec_model_name, dec_tokenizer.pad_token_id, **decoder_kwargs),
        enc_tokenizer.pad_token_id,
    )
    seq2seq.to(device)
    return seq2seq
# Notebook-level model instance, created on the CPU.
model = get_model('cpu')

In [None]:
# Freeze part of each sub-network through `freeze_to`; the meaning of the index is
# defined by that helper.
# NOTE(review): 13 presumably covers every encoder layer group — confirm.
freeze_to(model.encoder.layer_groups, 13)
# Per the note in the training cell further down, -3 is meant to leave only the
# cross attention and later decoder layers trainable.
freeze_to(model.decoder.layer_groups, -3)

In [None]:
# Build the dummy batches on the model's own device. If the model has already been
# moved to the GPU (e.g. after this notebook's Learner cell has run), CPU index
# tensors fail in the embedding lookup with a cuda/cpu device-mismatch RuntimeError.
_summary_device = next(model.parameters()).device
model.summary(
    torch.ones((16, enc_seq_len), dtype=torch.long, device=_summary_device),
    torch.ones((16, dec_seq_len), dtype=torch.long, device=_summary_device),
)

RuntimeError: Expected object of device type cuda but got device type cpu for argument #3 'index' in call to _th_index_select

## Learner

In [None]:
# Loaders come from `small_dss`; the commented line switches to the full `dss`.
dls = small_dss.dataloaders(bs=64)
# dls = dss.dataloaders(bs=64)

# Targets equal to the decoder pad-token id are ignored by the loss.
seq_loss = CrossEntropyLossFlat(ignore_index=dec_tokenizer.pad_token_id)
learn = Learner(
    dls, model,
    loss_func=seq_loss, opt_func=Adam, metrics=[accuracy, Perplexity()],
).to_fp16()

In [None]:
# Show per-layer output shapes, parameter counts and trainable flags for one batch.
learn.summary()

Bert2GPT2 (Input shape: ['64 x 50', '64 x 40'])
Layer (type)         Output Shape         Param #    Trainable 
Embedding            64 x 50 x 768        16,226,304 False     
________________________________________________________________
Embedding            64 x 50 x 768        393,216    False     
________________________________________________________________
Embedding            64 x 50 x 768        1,536      False     
________________________________________________________________
LayerNorm            64 x 50 x 768        1,536      False     
________________________________________________________________
Dropout              64 x 50 x 768        0          False     
________________________________________________________________
Linear               64 x 50 x 768        590,592    False     
________________________________________________________________
Linear               64 x 50 x 768        590,592    False     
__________________________________________________

In [None]:
# freeze_to(decoder.layer_groups, -3) # only train cross attention and later layer
# Train for a single one-cycle epoch, passing 5e-4 as the learning rate.
learn.fit_one_cycle(1, 5e-4)

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,6.514963,5.392271,0.024853,219.701767,00:17


## Find Hyperparams

In [None]:
class Objective():
    """Optuna objective that searches the learning rate of the Bert2GPT2 model.

    One model and one set of dataloaders are built up front. The model's initial
    weights are written to `init_state_path`, and every trial reloads them so each
    sampled learning rate starts from the same state.
    """

    # Checkpoint holding the initial weights that every trial is reset to.
    init_state_path = './models/bert2gpt2_ori.pt'
    # Name of the recorder column whose value is returned to Optuna.
    target_metric = 'perplexity'

    def __init__(self):
        """Create the dataloaders and the model, and snapshot the initial weights."""
        import os  # local import: only needed to create the checkpoint directory

        self.dls = dss.dataloaders(bs=64)
        self.model = get_model(default_device())
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(self.init_state_path), exist_ok=True)
        torch.save(self.model.state_dict(), self.init_state_path)

    def objective(self, trial):
        """Train for one epoch at a sampled learning rate and return the metric.

        Args:
            trial: Optuna trial used to sample `lr` log-uniformly in [1e-4, 1e-2].

        Returns:
            The entry of the recorder's last log row named `target_metric`.
        """
        lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)
        # Reset to the saved initial weights so trials do not build on each other.
        self.model.load_state_dict(torch.load(self.init_state_path))
        learn = Learner(self.dls,
                    self.model,
                    loss_func=CrossEntropyLossFlat(ignore_index=dec_tokenizer.pad_token_id),
                    opt_func=Adam,
                    metrics=[accuracy, Perplexity()],
                   ).to_fp16()
        print(f'Current trial: {trial.number} lr: {lr}')
        learn.fit_one_cycle(1, lr)
        # Look the metric up by column name rather than by position, so adding or
        # reordering metrics cannot silently change what is optimised.
        last_row = dict(zip(learn.recorder.metric_names, learn.recorder.log))
        return last_row[self.target_metric]

    def clear(self):
        """Move the model and the dataloaders to the CPU."""
        self.model.to('cpu')
        self.dls.cpu()

In [None]:
# The objective returns a perplexity, so the study minimises it; 'minimize' is also
# Optuna's default direction and is spelled out here only for the reader.
study = optuna.create_study(direction='minimize')
objective = Objective()
study.optimize(objective.objective, n_trials=10)

Current trial: 0 lr: 0.00012952249097646494


epoch,train_loss,valid_loss,accuracy,perplexity,time
0,5.607232,5.438638,0.024859,230.128586,02:42


[I 2020-02-10 08:46:42,761] Finished trial#0 resulted in value: 230.1285858154297. Current best value is 230.1285858154297 with parameters: {'lr': 0.00012952249097646494}.


Current trial: 1 lr: 0.004542378374396687


epoch,train_loss,valid_loss,accuracy,perplexity,time
0,6.320603,5.823337,0.025,338.098419,02:39


[I 2020-02-10 08:49:24,085] Finished trial#1 resulted in value: 338.0984191894531. Current best value is 230.1285858154297 with parameters: {'lr': 0.00012952249097646494}.


Current trial: 2 lr: 0.0011319755374436717


epoch,train_loss,valid_loss,accuracy,perplexity,time
0,5.880634,5.791884,0.025,327.629669,02:42


[I 2020-02-10 08:52:08,650] Finished trial#2 resulted in value: 327.6296691894531. Current best value is 230.1285858154297 with parameters: {'lr': 0.00012952249097646494}.


Current trial: 3 lr: 0.00016100477132948145


epoch,train_loss,valid_loss,accuracy,perplexity,time


## BLEU

In [None]:
# Decoding settings reused by the BLEU and generation cells: 3 beams, at most 20 tokens.
generate_args = GenerateArgs(
    max_length=20,
#     do_sample=True,
    num_beams=3,
    temperature=1.0,
    repetition_penalty=1,
    length_penalty=1.0,
)
# Wrap the `model` built and trained above. The name `bert2gpt2` used here previously
# is never assigned in this notebook, so the cell depended on leftover kernel state.
generated_bert2gpt2 = GeneratedBert2GPT2(model, enc_tokenizer, dec_tokenizer)

In [None]:
# dls = dss.dataloaders(bs=64)
# Score the generator on the validation loader of whichever `dls` is currently bound.
compute_bleu(generated_bert2gpt2, generate_args, dec_tokenizer, dls.valid)

## Generate

In [None]:
# Short Chinese source sentences paired with their English reference translations.
_short_pairs = [
    ('你确定？', "Really?"),
    ('找到汤姆。', "Get Tom."),
    ('帮帮我们吧！', "Help us."),
    ('坚持。', "Hold on."),
]
src_strs = [source for source, reference in _short_pairs]
tgt_strs = [reference for source, reference in _short_pairs]

In [None]:
# Generate translations for `src_strs` on device 'cuda:0'; the bare `result` displays them.
# NOTE(review): the recorded output is one repeated token for every input, so the
# model is not producing usable translations at this point.
result = generated_bert2gpt2.generate_from_strs(src_strs, generate_args, device='cuda:0')
result

['IIIIIIIIIIIIIIIIII',
 'IIIIIIIIIIIIIIIIII',
 'IIIIIIIIIIIIIIIIII',
 'IIIIIIIIIIIIIIIIII']

In [None]:
# Longer Chinese source sentences paired with their English reference translations.
_long_pairs = [
    ('我很高興再次見到你。', "I'm very glad to see you again."),
    ('我有點累。', "I'm a little tired."),
    ('我不記得寄過信了。', "I don't remember mailing the letter."),
    ('它是我兄弟的。', "It's my brother's."),
]
src_strs = [source for source, reference in _long_pairs]
tgt_strs = [reference for source, reference in _long_pairs]

In [None]:
# Same generation call as for the short sentences, now on the longer `src_strs`.
# NOTE(review): the recorded output is again one repeated token for every input.
result = generated_bert2gpt2.generate_from_strs(src_strs, generate_args, device='cuda:0')
result

["'t't't't't't't't't't't't't't't't't't",
 "'t't't't't't't't't't't't't't't't't't",
 "'t't't't't't't't't't't't't't't't't't",
 "'t't't't't't't't't't't't't't't't't't"]