In [None]:
import os
os.environ["HF_HOME"] = r"./.cache"
import pathlib
import re
import json

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, \
    EncoderDecoderModel, AutoTokenizer
from tokenizers import processors
from peft import PeftModel
from utils.metric import SacreBleu
from utils.dataset import Flores, WMTvat, EnJaDatasetMaker

In [None]:
def get_eval_dataset(name, src_lang, type, encoder_tokenizer=None, decoder_tokenizer=None):
    """Load an evaluation corpus and tokenize it for the requested model family.

    Args:
        name: Corpus selector, either ``"flores_dev"`` or ``"wmt_vat"``.
        src_lang: Source language code. ``"en"`` selects Japanese as the
            target; any other value selects English as the target.
        type: Model family. ``"mBART"`` uses the mBART tokenization map;
            any value starting with ``"BERT-GPT2"`` uses the BERT/GPT2 map.
        encoder_tokenizer: Tokenizer handed to the tokenization map (the only
            tokenizer used for ``"mBART"``).
        decoder_tokenizer: Decoder-side tokenizer; only used for the
            ``"BERT-GPT2*"`` family.

    Returns:
        The corpus returned by ``Flores.load`` / ``WMTvat.load`` with its
        ``{lang}_sentence`` columns renamed to ``source`` / ``target`` and the
        tokenization map applied.

    Raises:
        ValueError: If ``name`` or ``type`` is not one of the supported values.
    """
    trg_lang = "ja" if src_lang == "en" else "en"

    # Validate both selectors before touching any data so that a typo fails
    # immediately instead of after the corpus has been loaded.
    if name not in ("flores_dev", "wmt_vat"):
        raise ValueError(f"unknown dataset name: {name!r} (expected 'flores_dev' or 'wmt_vat')")
    is_mbart = type == "mBART"
    if not is_mbart and not type.startswith("BERT-GPT2"):
        raise ValueError(f"unknown model type: {type!r} (expected 'mBART' or 'BERT-GPT2*')")

    column_map = {f"{src_lang}_sentence": "source", f"{trg_lang}_sentence": "target"}
    if name == "flores_dev":
        data = Flores.load("dev")
    else:  # name == "wmt_vat"
        data = WMTvat.load(f"{src_lang}-{trg_lang}")
    data = data.rename_columns(column_map)

    if is_mbart:
        tokenize = EnJaDatasetMaker._get_map_compute_mBART_tokenization(
            tokenizer=encoder_tokenizer
        )
    else:
        tokenize = EnJaDatasetMaker._get_map_compute_BERT_GPT2_tokenization(
            encoder_tokenizer=encoder_tokenizer, decoder_tokenizer=decoder_tokenizer
        )
    return data.map(tokenize)

def get_tokenizer(type, src_lang):
    """Build the encoder/decoder tokenizers for a model family.

    Args:
        type: Model family. ``"mBART"`` or any value starting with
            ``"BERT-GPT2"``.
        src_lang: Source language code. ``"en"`` selects Japanese as the
            target; any other value is treated as Japanese source with
            English as the target.

    Returns:
        A dict with keys ``"encoder_tokenizer"`` and ``"decoder_tokenizer"``.
        For ``"mBART"`` both keys refer to the same tokenizer object.

    Raises:
        ValueError: If ``type`` is not a supported model family.
    """
    trg_lang = "ja" if src_lang == "en" else "en"

    if type == "mBART":
        shared = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50", src_lang=f"{src_lang}_XX", tgt_lang=f"{trg_lang}_XX"
        )
        return {"encoder_tokenizer": shared, "decoder_tokenizer": shared}

    if type.startswith("BERT-GPT2"):
        if src_lang == "en":
            encoder = "bert-base-uncased"
            decoder = "rinna/japanese-gpt2-small"
        else:  # src_lang == "ja"
            encoder = "cl-tohoku/bert-base-japanese-v3"
            decoder = "gpt2"
        encoder_tok = AutoTokenizer.from_pretrained(encoder, use_fast=True)
        decoder_tok = AutoTokenizer.from_pretrained(decoder, use_fast=True)

        # Fall back to EOS as the padding token when the decoder defines none.
        if decoder_tok.pad_token_id is None:
            decoder_tok.pad_token_id = decoder_tok.eos_token_id

        # Make the decoder tokenizer append EOS to every encoded sequence.
        # `backend_tokenizer` is the public accessor for the underlying
        # `tokenizers.Tokenizer` of a fast tokenizer.
        eos = decoder_tok.eos_token
        decoder_tok.backend_tokenizer.post_processor = processors.TemplateProcessing(
            single="$A " + eos,
            special_tokens=[(eos, decoder_tok.eos_token_id)],
        )
        return {"encoder_tokenizer": encoder_tok, "decoder_tokenizer": decoder_tok}

    raise ValueError(f"unknown model type: {type!r} (expected 'mBART' or 'BERT-GPT2*')")

def get_base_model(type, src_lang, base_step=25000):
    """Return the base model that per-checkpoint weights are applied on top of.

    Args:
        type: One of ``"mBART"``, ``"BERT-GPT2-xattn-LoRA"`` or
            ``"BERT-GPT2-xattn"``.
        src_lang: Source language code; ``"en"`` selects Japanese as the
            target, anything else selects English.
        base_step: Training step of the local ``BERT-GPT2-xattn`` checkpoint
            used as the base for ``"BERT-GPT2-xattn-LoRA"``. Defaults to
            25000, the value that used to be hard-coded.

    Returns:
        * ``"mBART"``: ``facebook/mbart-large-50`` loaded with
          ``MBartForConditionalGeneration``.
        * ``"BERT-GPT2-xattn-LoRA"``: the ``EncoderDecoderModel`` stored under
          ``./.ckp/{src}-{trg}-BERT-GPT2-xattn/checkpoint-{base_step}``.
        * ``"BERT-GPT2-xattn"``: ``None`` (no base model is loaded here).

    Raises:
        ValueError: If ``type`` is not one of the supported values.
    """
    if type == "mBART":
        return MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
    if type == "BERT-GPT2-xattn-LoRA":
        trg_lang = "ja" if src_lang == "en" else "en"
        path_to_ckp = f"./.ckp/{src_lang}-{trg_lang}-BERT-GPT2-xattn/checkpoint-{base_step}"
        return EncoderDecoderModel.from_pretrained(path_to_ckp, local_files_only=True)
    if type == "BERT-GPT2-xattn":
        return None
    raise ValueError(
        f"unknown model type: {type!r} "
        "(expected 'mBART', 'BERT-GPT2-xattn' or 'BERT-GPT2-xattn-LoRA')"
    )

In [None]:
def compute_ckp_scores(model, path_to_ckp, dataset, tokenizer, trg_lang):
    """Run generation-based evaluation for one checkpoint and return its metrics.

    Args:
        model: Base model to load the checkpoint onto via
            ``PeftModel.from_pretrained``, or ``None`` to load
            ``path_to_ckp`` directly as an ``EncoderDecoderModel``.
        path_to_ckp: Path of the checkpoint directory to evaluate.
        dataset: Tokenized evaluation dataset passed to ``trainer.predict``.
        tokenizer: Tokenizer given to the metric factory and the data collator.
        trg_lang: Target language code forwarded to the metric factory.

    Returns:
        The ``metrics`` attribute of the ``Seq2SeqTrainer.predict`` output.
    """
    if model is None:
        # No base model supplied: the checkpoint is loaded as a full model.
        model = EncoderDecoderModel.from_pretrained(path_to_ckp, local_files_only=True)
    else:
        model = PeftModel.from_pretrained(model=model, model_id=path_to_ckp)

    bleu_fn = SacreBleu.get_mBART_metric(tokenizer=tokenizer, target_language=trg_lang)
    collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    eval_args = Seq2SeqTrainingArguments(
        output_dir="./.ckp",
        report_to="none",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        prediction_loss_only=False,
        bf16=True,
        group_by_length=True,
        length_column_name="length",
        label_smoothing_factor=0.2,
    )

    trainer = Seq2SeqTrainer(
        model,
        args=eval_args,
        data_collator=collator,
        compute_metrics=bleu_fn,
    )

    model.cuda()
    model.eval()

    # Generation settings are forwarded to `predict` as keyword arguments.
    prediction = trainer.predict(
        dataset,
        max_length=256,
        early_stopping=True,
        no_repeat_ngram_size=4,
        length_penalty=1.0,
        num_beams=5,
    )
    return prediction.metrics
    

In [None]:
def compute_config_scores(type, ckp_name, src_lang, dataset_name, last_only=False):
    """Score every checkpoint of a training run, caching results on disk.

    Checkpoints are discovered as ``./.ckp/{ckp_name}/checkpoint-<step>``
    entries. Results are stored in ``./eval/{ckp_name}/{dataset_name}.json``
    after each checkpoint, and checkpoints already present in that file are
    skipped, so an interrupted run can be resumed.

    Args:
        type: ``"mBART"``, ``"BERT-GPT2-xattn"`` or ``"BERT-GPT2-xattn-LoRA"``.
        ckp_name: Name of the run directory under ``./.ckp``.
        src_lang: ``"en"`` or ``"ja"``; the other one is used as the target.
        dataset_name: ``"flores_dev"`` or ``"wmt_vat"``.
        last_only: When true, only the checkpoint with the highest step is
            evaluated.

    Returns:
        Dict mapping checkpoint step (``int``) to its metrics dict. It holds
        every entry from the cache file plus any newly computed ones.

    Raises:
        AssertionError: If an argument is outside the supported values or the
            run directory does not exist.
    """
    assert type in ["mBART", "BERT-GPT2-xattn", "BERT-GPT2-xattn-LoRA"], "invalid type"
    assert os.path.exists(f"./.ckp/{ckp_name}"), "invalid ckp id"
    assert src_lang in ["en", "ja"], "invalid language"
    assert dataset_name in ["flores_dev", "wmt_vat"], "invalid dataset"
    trg_lang = "ja" if src_lang == "en" else "en"

    # Collect checkpoint steps. Only `checkpoint-<digits>` entries count, so
    # other files or directories in the run folder (and hyphens in `ckp_name`)
    # cannot break the parsing.
    step_pattern = re.compile(r"checkpoint-(\d+)")
    ckps = sorted(
        int(found.group(1))
        for found in (
            step_pattern.fullmatch(entry.name)
            for entry in pathlib.Path(f"./.ckp/{ckp_name}").glob("checkpoint-*")
        )
        if found is not None
    )
    if last_only:
        # Slice (not index) so `ckps` stays a list, possibly empty.
        ckps = ckps[-1:]

    eval_dir = f"./eval/{ckp_name}"
    path_to_save = f"{eval_dir}/{dataset_name}.json"
    os.makedirs(eval_dir, exist_ok=True)

    scores = {}
    if os.path.isfile(path_to_save):  # resume from a previous run
        with open(path_to_save, "r") as fp:
            # JSON object keys are always strings; convert back to int steps
            # so cached and freshly computed entries share one key type.
            scores = {int(step): metrics for step, metrics in json.load(fp).items()}

    pending = [ckp for ckp in ckps if ckp not in scores]
    if not pending:
        # Nothing to compute: skip loading tokenizers, model and dataset.
        return scores

    # generate tokenizer, dataset, model
    tokenizers = get_tokenizer(type, src_lang)
    model = get_base_model(type, src_lang=src_lang)
    dataset = get_eval_dataset(name=dataset_name, src_lang=src_lang, type=type, **tokenizers)

    # NOTE(review): the same base `model` object is handed to
    # `compute_ckp_scores` for every checkpoint; confirm that loading several
    # PEFT checkpoints onto one base model in sequence behaves as intended.
    for ckp in pending:
        path_to_ckp = f"./.ckp/{ckp_name}/checkpoint-{ckp}"
        scores[ckp] = compute_ckp_scores(
            model, path_to_ckp, dataset, tokenizers["decoder_tokenizer"], trg_lang=trg_lang
        )

        # Serialize before opening so a serialization error cannot truncate
        # the existing cache file.
        payload = json.dumps(scores)
        with open(path_to_save, "w") as fp:
            fp.write(payload)
    return scores

In [None]:
# scores = compute_config_scores("BERT-GPT2-xattn", "en-ja-BERT-GPT2-xattn", "en", "flores_dev", last_only=False)
# scores = compute_config_scores("BERT-GPT2-xattn-LoRA", "en-ja-BERT-GPT2-xattn-LoRA", "en", "flores_dev", last_only=False)
# scores = compute_config_scores("mBART", "en-ja-mixed-250k+bt-250k", "en", "flores_dev", last_only=False)
# scores = compute_config_scores("mBART", "mixed-500k", "en", "flores_dev", last_only=False)
# scores = compute_config_scores("mBART", "ckp-25000-bt-500k", "en", "flores_dev", last_only=False)
# scores = compute_config_scores("mBART", "news-250k", "en", "flores_dev", last_only=False)