In [None]:
import os
os.environ["HF_HOME"] = r"./.cache"
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    Seq2SeqTrainer, DataCollatorForSeq2Seq, GenerationConfig, Seq2SeqTrainingArguments
from peft import PeftModel
from datasets import Dataset
from utils.dataset import EnJaDatasetMaker, EnJaBackTranslation
from utils.metric import SacreBleu

In [None]:
# Translation direction. "en" selects en -> ja; any other value selects ja -> en.
SOURCE_LANG = "en"
TARGET_LANG = "ja" if SOURCE_LANG == "en" else "en"

# Base mBART-50 checkpoint.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# Tokenizer configured with the mBART language codes matching the chosen direction.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50",
    src_lang="en_XX" if SOURCE_LANG == "en" else "ja_XX",
    tgt_lang="ja_XX" if SOURCE_LANG == "en" else "en_XX",
)

In [None]:
# Beam-search generation settings; passed to the training arguments and to the
# back-translation helper.
gen_config = GenerationConfig(
    no_repeat_ngram_size=4,
    length_penalty=1.0,
    num_beams=3,
    max_length=256,
    min_length=0,
    early_stopping=True,
    # mBART ships a dedicated <pad> token, so pad with it instead of reusing EOS
    # (the EOS-as-pad trick is only needed for models without a pad token, e.g. GPT-2).
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    # mBART-50 decodes starting from EOS and expects the target language code as
    # the first generated token. Both are set explicitly because this config is
    # built from scratch rather than loaded from the checkpoint.
    decoder_start_token_id=tokenizer.eos_token_id,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang),
)

# Trainer arguments, grouped by concern. Trailing comments list alternative
# values noted by the author.
train_args = Seq2SeqTrainingArguments(
    # --- bookkeeping ---
    output_dir="./.ckp/",
    report_to="none",

    # --- schedule / optimisation ---
    num_train_epochs=3,
    optim="adamw_torch",
    learning_rate=3e-5,  # alternatives: 3e-5, 5e-5
    warmup_steps=875,  # alternatives: 3500, 1750, 875
    label_smoothing_factor=0.2,  # alternatives: 0.1, 0.2
    bf16=True,

    # --- batching ---
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,  # alternatives: 1, 2, 4
    group_by_length=True,
    length_column_name="length",

    # --- logging ---
    logging_strategy="steps",
    logging_steps=1,  # alternatives: 4, 2, 1

    # --- evaluation (generation-based) ---
    evaluation_strategy="steps",
    eval_steps=5_000,  # alternatives: 20_000, 10_000, 5_000
    prediction_loss_only=False,
    predict_with_generate=True,
    generation_config=gen_config,

    # --- checkpointing ---
    save_strategy="steps",
    save_steps=5_000,  # alternatives: 20_000, 10_000, 5_000
    save_total_limit=20,
    load_best_model_at_end=True,
    # best checkpoint = highest "eval_score"
    # NOTE(review): presumably produced by the compute_metrics callback — confirm
    metric_for_best_model="eval_score",
    greater_is_better=True,
)

In [None]:
# Attach the fine-tuned LoRA adapter to the base model.
# NOTE(review): the checkpoint directory is hard-coded to the en->ja run and does
# not follow SOURCE_LANG — confirm before switching direction.
lora_model = PeftModel.from_pretrained(model, r"./.ckp_en_ja_hq_news/checkpoint-21000/")

data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)
metric = SacreBleu.get_mBART_metric(tokenizer=tokenizer, target_language=TARGET_LANG)

# The trainer is used purely as a prediction/generation wrapper, so no train or
# eval dataset is attached (both default to None).
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=train_args,
    data_collator=data_collator,
    compute_metrics=metric,
)

In [None]:
# Load the training split for the chosen direction and keep at most the first
# 1000 rows. The cap is clamped to the split size so a smaller dataset does not
# make `select` fail on out-of-range indices.
train_split: Dataset = EnJaDatasetMaker.load_dataset(f"{SOURCE_LANG}-{TARGET_LANG}-hq-news")["train"]
data: Dataset = train_split.select(range(min(1000, len(train_split))))
data

In [None]:
# Add an ID column so ties in length are broken deterministically.
# Guarded so the cell can be re-run: `data` is rebound in place, and adding an
# "id" column that already exists would fail.
if "id" not in data.column_names:
    data = data.add_column("id", list(range(len(data))))
# Sort by length (then id) for efficient dynamic padding.
data = data.sort(column_names=["length", "id"])

In [None]:
# Output file is named target-source because it holds the back-translated direction.
bt_file_name = f"{TARGET_LANG}-{SOURCE_LANG}-hq-news-mBART-bt.csv"

# Run back-translation over `data` using the wrapped trainer.
# NOTE(review): presumably processes `chunk_size` rows at a time and writes to out_dir — confirm in utils.dataset.
EnJaBackTranslation.create_mBART_backtranslation(
    trainer,
    data,
    SOURCE_LANG,
    tokenizer,
    gen_config=gen_config,
    chunk_size=100,
    out_dir="./data-bt",
    out_name=bt_file_name,
)