In [None]:
import random
import os
os.environ["HF_HOME"] = r"./.cache"

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
# from transformers.utils import logging
# logging.set_verbosity_info()
import evaluate

In [None]:
BASE_MODEL_NAME = "facebook/mbart-large-50"
SOURCE_LANG = "en"  # "en" -> evaluate en->ja, "ja" -> evaluate ja->en

# mBART-50 language codes of the two languages this notebook supports.
MBART_LANG_CODES = {"en": "en_XX", "ja": "ja_XX"}
if SOURCE_LANG not in MBART_LANG_CODES:
    raise ValueError(
        f"Unsupported SOURCE_LANG {SOURCE_LANG!r}; expected one of {sorted(MBART_LANG_CODES)}"
    )
# The target is whichever language of the pair is not the source.
TARGET_LANG = "ja" if SOURCE_LANG == "en" else "en"

model = MBartForConditionalGeneration.from_pretrained(BASE_MODEL_NAME)
tokenizer = MBart50TokenizerFast.from_pretrained(
    BASE_MODEL_NAME,
    src_lang=MBART_LANG_CODES[SOURCE_LANG],
    tgt_lang=MBART_LANG_CODES[TARGET_LANG],
)

In [None]:
from utils.dataset import EnJaDatasetMaker

# Load the "<src>-<tgt>-final" dataset for the chosen direction and keep its
# train and validation splits.
dataset = EnJaDatasetMaker.load_dataset(f"{SOURCE_LANG}-{TARGET_LANG}-final")
train_data, valid_data = dataset["train"], dataset["valid"]

In [None]:
metric = evaluate.load("sacrebleu")

# sacreBLEU's default tokenizer relies on whitespace, which Japanese text lacks,
# so Japanese outputs are scored with the MeCab-based "ja-mecab" tokenizer;
# every other target uses sacreBLEU's default tokenizer.
SACREBLEU_TOKENIZE = "ja-mecab" if TARGET_LANG == "ja" else None


def compute_metrics(preds):
    """Decode generated and reference token ids and score them with sacreBLEU.

    Args:
        preds: ``(prediction_ids, label_ids)`` pair of integer arrays as passed
            by the trainer (an ``EvalPrediction``). Either array may contain
            -100 where the trainer padded it.

    Returns:
        The dict returned by the ``sacrebleu`` metric; the trainer prefixes its
        keys with the split name (e.g. ``"score"`` becomes ``"eval_score"``).
    """
    preds_ids, labels_ids = preds

    # -100 is not a token id and cannot be decoded; swap it for the pad token,
    # which skip_special_tokens drops. Work on copies so the trainer's arrays
    # are left unmodified.
    preds_ids = preds_ids.copy()
    preds_ids[preds_ids == -100] = tokenizer.pad_token_id
    labels_ids = labels_ids.copy()
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id

    predictions = tokenizer.batch_decode(preds_ids, skip_special_tokens=True)
    # sacreBLEU takes one list of references per prediction (a single one here).
    references = [
        [reference]
        for reference in tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    ]

    extra_args = {"tokenize": SACREBLEU_TOKENIZE} if SACREBLEU_TOKENIZE is not None else {}
    return metric.compute(references=references, predictions=predictions, **extra_args)

In [None]:
# Decoding settings the trainer uses when it generates (predict_with_generate=True).
# Built here, before `train_args`, so this cell does not depend on `gen_config`
# having been created by the "Sample examination" cell further down; otherwise
# Restart & Run All fails with a NameError.
# `length_penalty` / `early_stopping` stay at their defaults: they only affect
# beam search (num_beams > 1), and recent Seq2SeqTrainer versions reject a
# generation config whose validate() warns about beam-only flags.
gen_config = GenerationConfig(
    no_repeat_ngram_size=4,
    num_beams=1,
    max_length=256,
    min_length=0,
    # mBART has a real pad token, so eos does not need to double as padding.
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    # A fresh GenerationConfig leaves this unset; take the base model's value so
    # generate() does not fall back to bos_token_id.
    decoder_start_token_id=model.config.decoder_start_token_id,
    do_sample=False,
    # With num_beams=1 and do_sample=False, top_k > 1 together with
    # penalty_alpha > 0 selects contrastive search rather than greedy decoding.
    penalty_alpha=0.2,
    top_k=10,
)

train_args = Seq2SeqTrainingArguments(
    report_to="wandb",
    # NOTE(review): the base model is facebook/mbart-large-50 despite the "base" in the run name.
    run_name=f"{SOURCE_LANG}-{TARGET_LANG}-mBART-base",
    num_train_epochs=3,

    logging_strategy="steps",
    logging_steps=1,  # * 4, 2, 1

    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=5_000,  # * 20_000, 10_000, 5_000
    prediction_loss_only=False,
    predict_with_generate=True,
    generation_config=gen_config,

    # NOTE(review): the evaluation cells below load checkpoints from
    # ./.ckp_{SOURCE_LANG}_{TARGET_LANG}/, not from this directory.
    output_dir="./.ckp/",
    save_strategy="steps",
    save_steps=5_000,  # * 20_000, 10_000, 5_000
    save_total_limit=20,
    # The best checkpoint is chosen by `metric_for_best_model` (the sacreBLEU
    # "score" returned by compute_metrics), not by the loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_score",
    greater_is_better=True,

    optim="adamw_torch",
    warmup_steps=875,  # 3500, 1750, 875
    learning_rate=3e-5,  # 3e-5, 5e-5
    bf16=True,  # bf16, qint 8 ???

    # Batch examples of similar length together; lengths come from the dataset's
    # "length" column.
    group_by_length=True,
    length_column_name="length",

    # torch_compile=True,
    label_smoothing_factor=0.2,  # 0.1, 0.2

    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,  # * 1, 2, 4
    # eval_accumulation_steps=4, # ???
)

In [None]:
from utils.dataset import Flores

# Map the language-specific sentence columns onto the generic "source"/"target"
# names, then apply the mapper from
# EnJaDatasetMaker._get_map_compute_mBART_tokenization (presumably it adds the
# mBART token ids; see utils.dataset).
flores_dev_columns = {
    f"{SOURCE_LANG}_sentence": "source",
    f"{TARGET_LANG}_sentence": "target",
}
flores_dev_raw = Flores.load("dev").rename_columns(flores_dev_columns)
flores_dev_data = flores_dev_raw.map(
    EnJaDatasetMaker._get_map_compute_mBART_tokenization(tokenizer=tokenizer)
)

In [None]:
# The "test" scores must come from FLORES' held-out split, which is published as
# "devtest". Loading "dev" here would make flores_test_results an exact copy of
# flores_dev_results.
# NOTE(review): confirm that `Flores.load` names that split "devtest".
flores_test_data = Flores.load("devtest").rename_columns({f"{SOURCE_LANG}_sentence": "source", f"{TARGET_LANG}_sentence": "target"})

flores_test_data = flores_test_data.map(
    EnJaDatasetMaker._get_map_compute_mBART_tokenization(tokenizer=tokenizer)
)

In [None]:
from utils.dataset import WMTvat

# Same preparation as for FLORES: rename the per-language sentence columns to
# "source"/"target", then apply the shared mBART tokenization mapper.
wmt_columns = {
    f"{SOURCE_LANG}_sentence": "source",
    f"{TARGET_LANG}_sentence": "target",
}
wmt_raw = WMTvat.load(f"{SOURCE_LANG}-{TARGET_LANG}").rename_columns(wmt_columns)
wmt_data = wmt_raw.map(
    EnJaDatasetMaker._get_map_compute_mBART_tokenization(tokenizer=tokenizer)
)

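## Checkpoint evaluation
Compute sacreBLEU scores for every checkpoint saved during training on the FLORES dev and test sets and on the WMT set. Per-checkpoint metrics are written to JSON files after each checkpoint.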

In [None]:
import gc
import json

# Where the training run saved its LoRA checkpoints, and the global steps at
# which a checkpoint exists (every 5_000 steps up to and including 50_000).
CHECKPOINT_DIR = f"./.ckp_{SOURCE_LANG}_{TARGET_LANG}"
CHECKPOINT_STEPS = range(5_000, 55_000, 5_000)

flores_dev_results = {}
flores_test_results = {}
wmt_results = {}

# (evaluation dataset, dict collecting its per-checkpoint metrics, JSON output file)
evaluation_runs = [
    (flores_dev_data, flores_dev_results, f"{SOURCE_LANG}_{TARGET_LANG}_flores_dev_results.json"),
    (flores_test_data, flores_test_results, f"{SOURCE_LANG}_{TARGET_LANG}_flores_test_results.json"),
    (wmt_data, wmt_results, f"{SOURCE_LANG}_{TARGET_LANG}_wmt_results.json"),
]

for step in CHECKPOINT_STEPS:
    # PEFT injects the LoRA layers into the model it is given *in place*, so
    # wrapping the shared `model` on every iteration would keep re-adapting an
    # already-adapted model. Start each checkpoint from a freshly loaded copy of
    # the same base model instead (this also leaves `model` untouched).
    base_model = MBartForConditionalGeneration.from_pretrained(model.config.name_or_path)
    lora_model = PeftModel.from_pretrained(
        model=base_model,
        model_id=f"{CHECKPOINT_DIR}/checkpoint-{step}/",
    )
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

    trainer = Seq2SeqTrainer(
        lora_model,
        args=train_args,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=valid_data,
        compute_metrics=compute_metrics,
    )

    lora_model.cuda()
    lora_model.eval()

    for eval_dataset, results, results_path in evaluation_runs:
        results[f"checkpoint_{step}"] = trainer.predict(eval_dataset).metrics
        # Rewrite the file after every checkpoint so partial results survive an
        # interrupted sweep.
        with open(results_path, "w") as f:
            json.dump(results, f)

    # Drop this checkpoint's model before loading the next one so its memory can
    # be reused instead of keeping two copies of the network alive.
    del trainer, data_collator, lora_model, base_model
    gc.collect()

    print(f"Checkpoint {step} DONE")


## Sample examination
Translate a given dataset with a chosen checkpoint to compute its sacreBLEU score and inspect translation quality on a few sampled sentences.

In [None]:
checkpoint = 50000

# PEFT injects LoRA layers into the model it is given in place, and `model` may
# already have been wrapped by an earlier cell (e.g. the checkpoint sweep). Load
# a fresh copy of the same base model so this cell behaves the same regardless
# of what ran before it.
base_model = MBartForConditionalGeneration.from_pretrained(model.config.name_or_path)
lora_model = PeftModel.from_pretrained(
    model=base_model,
    model_id=f"./.ckp_{SOURCE_LANG}_{TARGET_LANG}/checkpoint-{checkpoint}/",
)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

trainer = Seq2SeqTrainer(
    lora_model,
    args=train_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=valid_data,
    compute_metrics=compute_metrics,
)

In [None]:
def set_decoder_configuration(gc: GenerationConfig):
    """Apply this notebook's decoding settings to ``gc`` in place and return it.

    With ``num_beams=1`` and ``do_sample=False``, setting ``top_k > 1`` together
    with ``penalty_alpha > 0`` makes transformers use contrastive search rather
    than greedy decoding. ``length_penalty`` and ``early_stopping`` only affect
    beam search, so they are inert while ``num_beams == 1``.

    Args:
        gc: generation config to modify.

    Returns:
        The same ``gc`` object, modified.
    """
    gc.no_repeat_ngram_size = 4
    gc.length_penalty = 2.0  # beam search only (num_beams > 1)
    gc.num_beams = 1
    gc.max_length = 256
    gc.min_length = 0
    gc.early_stopping = True  # beam search only (num_beams > 1)
    # mBART has a dedicated pad token, so pad finished sequences with it rather
    # than with eos (that workaround is only needed for models without a pad token).
    gc.pad_token_id = tokenizer.pad_token_id
    gc.bos_token_id = tokenizer.bos_token_id
    gc.eos_token_id = tokenizer.eos_token_id
    # A fresh GenerationConfig leaves this unset, in which case generate() may
    # fall back to bos_token_id; use the base model's decoder start token instead.
    gc.decoder_start_token_id = model.config.decoder_start_token_id
    gc.do_sample = False
    gc.penalty_alpha = 0.2
    gc.top_k = 10
    return gc

gen_config = GenerationConfig()
gen_config = set_decoder_configuration(gen_config)

In [None]:
lora_model.cuda()
lora_model.eval()

# Dataset to translate and examine; `print_pairs` below samples from the same object.
data = valid_data

# Pass the decoding settings by keyword: the second positional parameter of
# `Seq2SeqTrainer.predict` is `ignore_keys`, so `predict(data, gen_config)`
# would never hand them to generate().
predictions = trainer.predict(data, generation_config=gen_config)

print("Metrics: ", predictions.metrics)

# The trainer pads generated sequences of different lengths with -100 when it
# concatenates batches; -100 is not a token id, so map it to the pad token
# (dropped by skip_special_tokens) before decoding.
prediction_ids = predictions.predictions.copy()
prediction_ids[prediction_ids == -100] = tokenizer.pad_token_id
predictions_decode = tokenizer.batch_decode(prediction_ids, skip_special_tokens=True)

In [None]:
from textwrap import wrap

def print_pairs(dataset, generation, sample=5):
    """Print randomly sampled source / target / generated triples.

    Args:
        dataset: object indexable by column name with ``"source"`` and
            ``"target"`` columns (e.g. a Hugging Face ``Dataset``), row-aligned
            with ``generation``.
        generation: decoded model outputs, one per row of ``dataset``.
        sample: number of rows to show; capped at ``len(dataset)``.

    Raises:
        AssertionError: if ``dataset`` and ``generation`` differ in length.
    """
    assert len(dataset) == len(generation), "Invalid combination!"

    # Look each column up once: indexing a dataset by column name may
    # materialise the whole column, so doing it per sample repeats that work.
    sources = dataset["source"]
    targets = dataset["target"]

    sample_ids = random.sample(range(len(dataset)), min(sample, len(dataset)))
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            "\n\t\t\t".join(wrap(f"\tOriginal:  {sources[sid]}", width=100)),
            "\n\t\t\t".join(wrap(f"\tTarget:    {targets[sid]}", width=100)),
            "\n\t\t\t".join(wrap(f"\tGenerated: {generation[sid]}", width=100)),
            sep="\n",
        )
        print("\n")

# Show three randomly chosen source / reference / model-output triples from `data`.
print_pairs(dataset=data, generation=predictions_decode, sample=3)