In [None]:
import random
import os
os.environ["HF_HOME"] = r"./.cache"

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from peft import PeftModel

In [None]:
# Training-set variant whose checkpoints are evaluated; keep exactly one line active.
# TRAIN_DATASET_NAME = "mixed-500k"
# TRAIN_DATASET_NAME = "news-250k"
TRAIN_DATASET_NAME = "mixed-250k+bt-250k"

# Translation direction: choose the source language here.
SOURCE_LANG = "en"
# SOURCE_LANG = "ja"

# The target is "ja" when translating from English, and "en" for any other source.
TARGET_LANG = "ja" if SOURCE_LANG == "en" else "en"

In [None]:
import pathlib

# Run identifier "<source>-<target>-<training set>", shared by the checkpoint
# and evaluation directories below.
model_path = "-".join((SOURCE_LANG, TARGET_LANG, TRAIN_DATASET_NAME))

# Directory the LoRA checkpoints are read from.
ckp_path = "./.ckp/" + model_path
# Directory the per-dataset result files are written to.
eval_path = "./.eval/" + model_path

# Make sure the evaluation directory exists before any result file is written.
pathlib.Path(eval_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Hub identifier shared by the base model and its tokenizer.
_BASE_CHECKPOINT = "facebook/mbart-large-50"

# Base model; the LoRA adapters of each checkpoint are loaded on top of it later.
model = MBartForConditionalGeneration.from_pretrained(_BASE_CHECKPOINT)

# Tokenizer configured with the "<lang>_XX" codes of the chosen direction.
tokenizer = MBart50TokenizerFast.from_pretrained(
    _BASE_CHECKPOINT,
    src_lang=f"{SOURCE_LANG}_XX",
    tgt_lang=f"{TARGET_LANG}_XX",
)

In [None]:
from utils.dataset import EnJaDatasetMaker

# Load the prepared dataset for this run and pull out the two splits
# that are handed to the trainer further down.
dataset = EnJaDatasetMaker.load_dataset(model_path)
train_data, valid_data = dataset["train"], dataset["valid"]

In [None]:
from utils.metric import SacreBleu

# Metric callback handed to the trainer as `compute_metrics`.
compute_metrics = SacreBleu.get_mBART_metric(
    tokenizer=tokenizer,
    target_language=TARGET_LANG,
)

In [None]:
# Trainer arguments used for evaluation; collected in a dict first so the
# individual settings can be commented on.
_train_arg_values = {
    # no experiment-tracking integration
    "report_to": "none",

    # keep predictions (not only the loss) and obtain them through generation
    "prediction_loss_only": False,
    "predict_with_generate": True,

    "bf16": True,  # bf16, qint 8 ???
    # NOTE(review): this is "./ckp", while checkpoints are read from "./.ckp/..."
    # (see `ckp_path`) — confirm the difference is intentional.
    "output_dir": "./ckp",

    "group_by_length": True,
    "length_column_name": "length",

    "label_smoothing_factor": 0.2,  # 0.1, 0.2

    "per_device_eval_batch_size": 8,
}

train_args = Seq2SeqTrainingArguments(**_train_arg_values)

In [None]:
# Registry of evaluation sets, keyed by a short name; the cells below add to it.
datasets = dict()

# Uncomment to also evaluate on the validation split.
# datasets["valid"] = valid_data

In [None]:
from utils.dataset import Flores

# Rename the language-specific sentence columns to "source"/"target".
_flores_dev_columns = {
    f"{SOURCE_LANG}_sentence": "source",
    f"{TARGET_LANG}_sentence": "target",
}
_flores_dev_renamed = Flores.load("dev").rename_columns(_flores_dev_columns)

# Apply the mBART tokenization map function to every example.
_flores_dev_tokenize = EnJaDatasetMaker._get_map_compute_mBART_tokenization(tokenizer=tokenizer)
flores_dev_data = _flores_dev_renamed.map(_flores_dev_tokenize)

# Register the set so the evaluation loop picks it up.
datasets["flores_dev"] = flores_dev_data

In [None]:
# Second FLORES evaluation set, prepared the same way as `flores_dev_data`:
# sentence columns renamed to "source"/"target", then tokenized for mBART.
# NOTE(review): this loads the "dev" split, exactly like `flores_dev_data`, so
# the two sets would be identical. Presumably a different split name was
# intended here — confirm against `Flores.load` before enabling this set.
flores_test_data = Flores.load("dev").rename_columns({f"{SOURCE_LANG}_sentence": "source", f"{TARGET_LANG}_sentence": "target"})

flores_test_data = flores_test_data.map(
    EnJaDatasetMaker._get_map_compute_mBART_tokenization(tokenizer=tokenizer)
)

# Not registered for evaluation unless the next line is uncommented.
# datasets["flores_test"] = flores_test_data

In [None]:
from utils.dataset import WMTvat

# Load the set for the current direction and rename the language-specific
# sentence columns to "source"/"target".
_wmt_columns = {
    f"{SOURCE_LANG}_sentence": "source",
    f"{TARGET_LANG}_sentence": "target",
}
_wmt_renamed = WMTvat.load(f"{SOURCE_LANG}-{TARGET_LANG}").rename_columns(_wmt_columns)

# Apply the mBART tokenization map function to every example.
_wmt_tokenize = EnJaDatasetMaker._get_map_compute_mBART_tokenization(tokenizer=tokenizer)
wmt_data = _wmt_renamed.map(_wmt_tokenize)

# Not registered for evaluation unless the next line is uncommented.
# datasets["wmt"] = wmt_data

## Checkpoint evaluation
Evaluate sacreBLEU score on all checkpoints saved during training.

In [None]:
import os.path
import json

# Evaluation results per dataset key: {dataset key: {checkpoint: metrics}}.
eval_sets = {}

# Evaluate every registered dataset, or restrict to a manual selection.
keys = datasets.keys()
# keys = ["flores_dev"]

for eval_key in keys:
    results_file = f"{eval_path}/{eval_key}_results.json"

    # Nothing stored yet for this dataset: start from an empty result table.
    if not os.path.isfile(results_file):
        eval_sets[eval_key] = {}
        continue

    # Resume from the results written by a previous run.
    with open(results_file) as f:
        eval_sets[eval_key] = json.load(f)

In [None]:
# Checkpoint steps to evaluate, chosen from the training set and, for
# "news-250k", the source language.

# Training sets whose checkpoint schedule does not depend on the direction.
_fixed_schedules = {
    "mixed-500k": range(5000, 55000, 5000),
    "mixed-250k+bt-250k": range(2500, 27500, 2500),
}

if TRAIN_DATASET_NAME in _fixed_schedules:
    checkpoints = _fixed_schedules[TRAIN_DATASET_NAME]
elif TRAIN_DATASET_NAME == "news-250k":
    if SOURCE_LANG == "en":
        checkpoints = range(3500, 24500, 3500)
    else:
        # regular schedule plus two extra steps
        checkpoints = [*range(3750, 22500, 3750), 20000, 21250]
else:
    raise ValueError()


# Manual override of the list above. Comment/uncomment based on needs.
checkpoints = [25000]

In [None]:
# Generation keyword arguments forwarded to `trainer.predict`.
gen_config = dict(
    max_length=256,
    early_stopping=True,

    no_repeat_ngram_size=4,
    length_penalty=1.0,

    num_beams=5,
    # Alternative decoding options that are currently disabled:
    # num_beam_groups=5,
    # diversity_penalty=0.5,
    # do_sample=True,
    # penalty_alpha=0.6,
    # top_k=4,
)

In [None]:
# When enabled, the predictions generated during evaluation are kept in memory.
# They are currently not saved to disk.
# TODO maybe possible to save them
KEEP_PREDICTIONS = True

# {dataset key: {checkpoint: predictions}}; left empty when predictions are not kept.
predictions_dict = {pred_key: {} for pred_key in keys} if KEEP_PREDICTIONS else {}

In [None]:
# Evaluate every selected checkpoint on every registered evaluation set.
for i in checkpoints:

    # Load the adapter stored for this checkpoint on top of the base model.
    # NOTE(review): every iteration passes the same `model` object to
    # `PeftModel.from_pretrained` — confirm with the installed PEFT version
    # that loading a new adapter onto an already-wrapped base model is supported.
    lora_model = PeftModel.from_pretrained(model=model, model_id=f"{ckp_path}/checkpoint-{i}/")
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

    trainer = Seq2SeqTrainer(
        lora_model,
        args=train_args,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=valid_data,
        compute_metrics=compute_metrics
    )

    lora_model.cuda()
    lora_model.eval()

    # evaluate on each dataset and save to file
    # inefficient to save at each checkpoint as it overwrites the file
    # but can be stopped without losing progress

    # WARNING: results are not necessarily ordered by key if partial results were loaded
    for eval_key in eval_sets.keys():

        print(f"Evaluating {eval_key} dataset...")

        predictions = trainer.predict(datasets[eval_key], **gen_config)

        if KEEP_PREDICTIONS:
            predictions_dict[eval_key][f"{i}"] = predictions.predictions

        eval_sets[eval_key][f"{i}"] = predictions.metrics

        with open(f"{eval_path}/{eval_key}_results.json", "w") as f:
            json.dump(eval_sets[eval_key], f)

        # Release the prediction output here rather than after the loop: when no
        # evaluation set is registered `predictions` is never bound, and a
        # trailing `del predictions` would raise NameError.
        del predictions

    print("Checkpoint ", i, " DONE")
    del lora_model, trainer, data_collator


## Sample examination
Translate a given dataset with the chosen checkpoint to get its sacreBLEU score and examine translation quality on samples of the data.

In [None]:
# Checkpoint and evaluation set to inspect.
checkpoint = 25000
# NOTE: this rebinds the notebook-level name `dataset` (previously the loaded
# training dataset) to the key of the evaluation set.
dataset = "flores_dev"
data = datasets[dataset]

# Predictions are only cached by the evaluation loop when KEEP_PREDICTIONS is
# enabled, and only for the checkpoints that loop actually evaluated.
cached_predictions = predictions_dict.get(dataset, {}) if KEEP_PREDICTIONS else {}

if f"{checkpoint}" in cached_predictions:
    # Reuse what the evaluation loop already produced.
    predictions = cached_predictions[f"{checkpoint}"]
    metrics = eval_sets[dataset][f"{checkpoint}"]

else:
    # Nothing cached for this checkpoint/dataset: generate the translations now.
    lora_model = PeftModel.from_pretrained(model=model, model_id=f"{ckp_path}/checkpoint-{checkpoint}/")
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

    trainer = Seq2SeqTrainer(
        lora_model,
        args=train_args,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=valid_data,
        compute_metrics=compute_metrics
    )
    lora_model.cuda()
    lora_model.eval()

    outputs = trainer.predict(data, **gen_config)
    predictions = outputs.predictions
    metrics = outputs.metrics

In [None]:
# Show the metrics obtained for the selected checkpoint and evaluation set.
print("Metrics: ", metrics)
# Turn the generated token ids back into text, dropping special tokens.
# NOTE(review): assumes `predictions` contains only valid token ids (no -100
# padding values) — confirm for the installed transformers version.
predictions_decode = tokenizer.batch_decode(predictions, skip_special_tokens=True)

In [None]:
from textwrap import wrap

def print_pairs(dataset, generation, sample=5):
    """Print a random sample of source / reference / generated sentence triples.

    Args:
        dataset: Object supporting ``len()`` and column access through
            ``dataset["source"]`` and ``dataset["target"]``, each returning a
            sequence indexable by row number.
        generation: Sequence of generated sentences, aligned row by row with
            ``dataset``.
        sample: Maximum number of rows to print. Values larger than the
            dataset are clamped to its size; values below zero print nothing.

    Raises:
        AssertionError: If ``dataset`` and ``generation`` differ in length.
    """
    assert len(dataset) == len(generation), "Invalid combination!"

    n_rows = len(dataset)
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more rows than the dataset holds.
    n_shown = max(0, min(sample, n_rows))

    # Look each column up once instead of on every iteration; column access
    # can be expensive for large datasets.
    sources = dataset['source']
    targets = dataset['target']

    def _wrapped(label, text):
        # Wrap to 100 columns and indent continuation lines.
        return "\n\t\t\t".join(wrap(f"\t{label}{text}", width=100))

    sample_ids = random.sample(range(n_rows), n_shown)
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            _wrapped("Original:  ", sources[sid]),
            _wrapped("Target:    ", targets[sid]),
            _wrapped("Generated: ", generation[sid]),
            sep="\n"
        )
        print("\n")

# Show ten randomly chosen examples from the selected evaluation set.
print_pairs(dataset=data, generation=predictions_decode, sample=10)