In [None]:
import optuna
from fastai2.basics import *
from transformers import AutoTokenizer

from fastai2_utils.pytorch.model import *
from fastai_transformers_utils.all import *

from nmt_small.models.patch import *
from nmt_small.models.bert2gpt2 import *
from nmt_small.data.tatoeba import *
from nmt_small.metrics import compute_bleu

In [None]:
# all_slow

In [None]:
# Tokenized Tatoeba fixture (Mandarin source, English target), relative to the notebook's working directory.
tok_data_loc = './test_data/tok_cmn.csv'

# Pretrained checkpoints: Chinese BERT (whole-word-masking) encoder, distilled GPT-2 decoder.
enc_model_name, dec_model_name = 'hfl/chinese-bert-wwm-ext', 'distilgpt2'

# Encoder-input and decoder-target sequence lengths handed to get_tatoeba_dss
# (presumably truncation/padding limits — confirm in nmt_small.data.tatoeba).
enc_seq_len, dec_seq_len = 50, 40

In [None]:
# Source-side (Chinese BERT) and target-side (GPT-2) tokenizers, loaded in that order.
# NOTE(review): stock GPT-2 tokenizers define no pad token, yet dec_tokenizer.pad_token_id
# is used later for the decoder and the loss; presumably the project's
# GPT2DecoderTokenizer adds one — TODO confirm.
enc_tokenizer, dec_tokenizer = (
    AutoTokenizer.from_pretrained(enc_model_name),
    GPT2DecoderTokenizer.from_pretrained(dec_model_name),
)

# Full Test of Bert2GPT2

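> End-to-end smoke test: build the Tatoeba datasets, construct a small Bert2GPT2 model, and train it for one epoch.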

## Datasets

In [None]:
# Same arguments twice: first a 20% subset (pct=0.2) for the quick training run,
# then the full set (used by the commented-out full-data run and the
# hyperparameter search). Calls happen in that order.
small_dss, dss = (
    get_tatoeba_dss(tok_data_loc, enc_tokenizer, dec_tokenizer,
                    enc_seq_len, dec_seq_len, **subset_kwargs)
    for subset_kwargs in ({'pct': 0.2}, {})
)
len(small_dss.train), len(dss.train)

(3392, 16964)

In [None]:
# dss.train[10], dss.decode(dss.train[10])

In [None]:
# dls = dss.dataloaders(bs=2)
# for x in dls.train:
#     print(x[0].shape, x[0].dtype, x[0].device, type(x[0]))
#     print(x[1].shape, x[1].dtype, x[1].device, type(x[1]))
#     print(x[2].shape, x[2].dtype, x[2].device, type(x[2]))
#     break

## Model

In [None]:
def get_model(device, num_heads=2, drop_p=0, num_layers=2):
    """Build a fresh Bert2GPT2 model and move it to `device`.

    The encoder is built from `enc_model_name`. The decoder is built from
    `dec_model_name`, with `dec_tokenizer.pad_token_id` (presumably its padding
    id) and a vocabulary sized to `len(dec_tokenizer)`. The encoder tokenizer's
    pad id is handed to `Bert2GPT2`.

    Args:
        device: any device spec accepted by the model's `.to()`, e.g. 'cpu',
            'cuda:0' or `default_device()`.
        num_heads, drop_p, num_layers: forwarded unchanged to `GPT2Decoder`.
            The defaults (2, 0, 2) are the values this notebook previously
            hard-coded. Exposing them lets a hyperparameter search vary the
            decoder without editing this function.

    Returns:
        The `Bert2GPT2` model, already moved to `device`.
    """
    encoder = BertEncoder(enc_model_name)
    decoder = GPT2Decoder(
        dec_model_name, dec_tokenizer.pad_token_id,
        vocab_size=len(dec_tokenizer),
        num_heads=num_heads, drop_p=drop_p, num_layers=num_layers,
    )
    model = Bert2GPT2(encoder, decoder, enc_tokenizer.pad_token_id)
    model.to(device)
    return model

# Built on CPU. NOTE(review): training relies on fastai moving the model to the DataLoaders' device.
model = get_model('cpu')

## Learner

In [None]:
# Quick run: 20% subset, batch size 16.
# For the full dataset use `dls = dss.dataloaders(bs=64)` instead.
dls = small_dss.dataloaders(bs=16)
learn = Learner(
    dls,
    model,
    # Target positions equal to the decoder pad id contribute nothing to the loss.
    loss_func=CrossEntropyLossFlat(ignore_index=dec_tokenizer.pad_token_id),
    opt_func=Adam,
    metrics=[accuracy, Perplexity()],
).to_fp16()  # mixed-precision training

In [None]:
# Optional: train only the cross-attention and later layers.
# freeze_to(decoder.layer_groups, -3)
# NOTE(review): `decoder` is local to get_model(), so the line above would raise
# NameError if re-enabled as-is. Reach the decoder through `learn.model`
# (attribute name unconfirmed).
learn.fit_one_cycle(n_epoch=1, lr_max=5e-4)  # a single one-cycle epoch, peak LR 5e-4

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,6.631876,5.479174,0.025,239.64859,00:29


## BLEU

In [None]:
# generate_args = GenerateArgs(   
#     max_length=20,
# #     do_sample=True,
#     num_beams=3,
#     temperature=1.0,
#     repetition_penalty=1,
#     length_penalty=1.0,
# )
# generated_bert2gpt2 = GeneratedBert2GPT2(model, enc_tokenizer, dec_tokenizer)

In [None]:
# compute_bleu(generated_bert2gpt2, generate_args, dec_tokenizer, dls.valid)

## Generate

In [None]:
# src_strs = ['你确定？', 
#             '找到汤姆。', 
#             '帮帮我们吧！',
#             '坚持。']
# tgt_strs = ["Really?",
#            "Get Tom.",
#            "Help us.",
#            "Hold on."]
# result = generated_bert2gpt2.generate_from_strs(src_strs, generate_args, device='cuda:0')
# result

In [None]:
# src_strs = ['我很高興再次見到你。', 
#             '我有點累。', 
#             '我不記得寄過信了。',
#             '它是我兄弟的。']
# tgt_strs = ["I'm very glad to see you again.",
#            "I'm a little tired.",
#            "I don't remember mailing the letter.",
#            "It's my brother's."]
# result = generated_bert2gpt2.generate_from_strs(src_strs, generate_args, device='cuda:0')
# result

## Find Hyperparams

In [None]:
# class Objective():
#     def __init__(self):
#         self.dls = dss.dataloaders(bs=64)
#         self.model = get_model(default_device())
#         torch.save(self.model.state_dict(), './models/bert2gpt2_ori.pt')

#     def objective(self, trial):
#         lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)
#         self.model.load_state_dict(torch.load('./models/bert2gpt2_ori.pt'))
#         learn = Learner(self.dls, 
#                     self.model, 
#                     loss_func=CrossEntropyLossFlat(ignore_index=dec_tokenizer.pad_token_id), 
#                     opt_func=Adam,
#                     metrics=[accuracy, Perplexity()],
#                    ).to_fp16()
#         print(f'Current trial: {trial.number} lr: {lr}')
#         learn.fit_one_cycle(1, lr)
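#         # recorder.log = [epoch, train_loss, valid_loss, accuracy, perplexity, time];
#         # index 4 is validation perplexity, which Optuna minimizes by default.
#         return learn.recorder.log[4]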
    
#     def clear(self):
#         self.model.to('cpu')
#         self.dls.cpu()

In [None]:
# study = optuna.create_study()
# objective = Objective()
# study.optimize(objective.objective, n_trials=10)