In [None]:
import os
import random

# Set before importing any Hugging Face library so the cache location takes effect.
os.environ["HF_HOME"] = r"./.cache"

import evaluate
import numpy as np
from peft import LoraConfig, get_peft_model, TaskType
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from transformers.trainer_utils import get_last_checkpoint
# from transformers.utils import logging
# logging.set_verbosity_info()

from utils.dataset import EnJaDatasetMaker

In [None]:
# Pretrained mBART-50 checkpoint that will be wrapped with LoRA adapters further down.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# Translation direction: "en" translates English -> Japanese; any other value
# is treated as Japanese -> English.
SOURCE_LANG = "en"
TARGET_LANG = "ja" if SOURCE_LANG == "en" else "en"

# The tokenizer gets the mBART language codes matching the chosen direction.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50",
    src_lang="en_XX" if SOURCE_LANG == "en" else "ja_XX",
    tgt_lang="ja_XX" if TARGET_LANG == "ja" else "en_XX",
)

In [None]:
def print_trainable_parameters(model):
    """
    Print the parameter count and memory footprint of a model.

    Reports the total and the trainable (``requires_grad``) parameters, each as
    a count and a size in MB, plus the trainable share as a percentage.

    Args:
        model: any object exposing ``named_parameters()`` whose parameters
            provide ``nelement()``, ``element_size()`` and ``requires_grad``.

    Returns:
        None. The summary is written to stdout.
    """
    total_count = total_bytes = 0
    trainable_count = trainable_bytes = 0
    for _, param in model.named_parameters():
        count = param.nelement()
        size = count * param.element_size()
        total_count += count
        total_bytes += size
        if param.requires_grad:
            trainable_count += count
            trainable_bytes += size

    # A model without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_count / total_count if total_count else 0.0
    print(
        f"Total params: {total_count:12,} ({(total_bytes / 1024**2):7,.1f}MB) | "
        f"Trainable params: {trainable_count:12,} ({(trainable_bytes / 1024**2):7,.1f}MB) [{trainable_pct:3.1f}%]"
    )

In [None]:
# Modules listed here are passed to LoraConfig as `modules_to_save`.
modules_to_save = [
    "final_layer_norm",
    "self_attn_layer_norm",
    "layer_norm",
    "layernorm_embedding",
    "embed_positions",
]

# Modules listed here are passed to LoraConfig as `target_modules`.
target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "out_proj",
    "fc1",
    "fc2",
    "shared",
    "lm_head",
]

config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=8,  # anything is fine (simply tune lr)
    lora_dropout=0.1,
    target_modules=target_modules,
    modules_to_save=modules_to_save,  # may modify the base model
    bias="lora_only",  # may modify the base model ("lora_only" or "all")
)

# Wrap the base model with the adapters and report how much of it is trainable.
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

# Collator used by the trainer to batch examples for the wrapped model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

In [None]:
# Load the dataset named after the translation direction (e.g. "en-ja-final")
# and pick out its training and validation splits.
# `EnJaDatasetMaker` is imported in the notebook's import cell.
dataset = EnJaDatasetMaker.load_dataset(f"{SOURCE_LANG}-{TARGET_LANG}-final")
train_data = dataset["train"]
valid_data = dataset["valid"]

In [None]:
metric = evaluate.load("sacrebleu")

# Japanese output is scored with the MeCab-based tokenizer; every other target
# language uses the metric's default tokenization.
BLEU_TOKENIZE = "ja-mecab" if TARGET_LANG == "ja" else None


def _decode_ids(token_ids):
    """Decode a batch of token ids to text, dropping special tokens.

    Positions holding -100 (the ignore/padding marker) are replaced by the pad
    token id first, because -100 is not a valid vocabulary id. The input array
    is left untouched.
    """
    token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)


def compute_metrics(preds):
    """Compute sacreBLEU for a batch of generated sequences.

    Args:
        preds: a pair ``(prediction_ids, label_ids)`` of integer id arrays.
            Either array may contain -100 as padding.

    Returns:
        Whatever ``metric.compute`` returns for the decoded predictions and
        references (the training arguments select the best checkpoint via the
        "score" entry, logged as "eval_score").
    """
    preds_ids, labels_ids = preds

    predictions = _decode_ids(preds_ids)
    # The metric expects a list of references per prediction.
    references = [[reference] for reference in _decode_ids(labels_ids)]

    tokenize_kwargs = {} if BLEU_TOKENIZE is None else {"tokenize": BLEU_TOKENIZE}
    return metric.compute(
        references=references,
        predictions=predictions,
        **tokenize_kwargs,
    )

In [None]:
def set_decoder_configuration(gc: "GenerationConfig"):
    """Fill a generation config with the beam-search settings used for evaluation.

    Besides the search hyper-parameters, this sets the special-token ids from
    the notebook's ``tokenizer``. A freshly constructed config carries none of
    the checkpoint's generation defaults, so the decoder start token and the
    forced target-language token are set explicitly here.

    Args:
        gc: the generation config to update in place.

    Returns:
        The same ``gc`` object, for call chaining.
    """
    # Beam search settings.
    gc.no_repeat_ngram_size = 4
    gc.length_penalty = 2.0
    gc.num_beams = 3
    gc.max_length = 256
    gc.min_length = 0
    gc.early_stopping = True

    # Use the tokenizer's own pad token; fall back to EOS only if it has none.
    pad_token_id = tokenizer.pad_token_id
    gc.pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_token_id
    gc.bos_token_id = tokenizer.bos_token_id
    gc.eos_token_id = tokenizer.eos_token_id

    # mBART decoders start from EOS and then emit the target language code.
    # Without these two settings generation would start from `bos_token_id`.
    # NOTE(review): assumes the training labels begin with the target language
    # code (the tokenizer's default target formatting) — confirm against the dataset.
    gc.decoder_start_token_id = tokenizer.eos_token_id
    gc.forced_bos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang)
    return gc

# Build the generation config used for evaluation and prediction in one step.
gen_config = set_decoder_configuration(GenerationConfig())

In [None]:
# Training configuration for the LoRA-wrapped model.
# NOTE(review): the trailing "* a, b, c" comments look like alternative values that
# were used together with gradient_accumulation_steps = 1, 2, 4 — confirm with the author.
train_args = Seq2SeqTrainingArguments(
    # Experiment tracking.
    report_to="wandb",
    run_name=f"{SOURCE_LANG}-{TARGET_LANG}-mBART-base",
    num_train_epochs=3,

    # Logging cadence.
    logging_strategy="steps",
    logging_steps=1, # * 4, 2, 1

    # Evaluation runs generation with `gen_config` so metrics are computed on generated text.
    evaluation_strategy="steps",
    eval_steps=5_000, # * 20_000, 10_000, 5_000
    prediction_loss_only=False,
    predict_with_generate=True,
    generation_config=gen_config,

    # Checkpointing: saved on the same step interval as evaluation.
    output_dir="./.ckp/",
    save_strategy="steps",
    save_steps=5_000, # * 20_000, 10_000, 5_000
    save_total_limit=20,
    load_best_model_at_end=True, # "best" is decided by metric_for_best_model below
    metric_for_best_model="eval_score",
    greater_is_better=True,

    # Optimisation.
    optim="adamw_torch",
    warmup_steps=875, # 3500, 1750, 875
    learning_rate=3e-5, # 3e-5, 5e-5
    bf16=True, # bf16, qint 8 ???
    
    # Batch examples of similar length together.
    # NOTE(review): assumes the datasets expose a "length" column — confirm.
    group_by_length=True,
    length_column_name="length",

    # torch_compile=True,
    label_smoothing_factor=0.2, # 0.1, 0.2
    
    # 8 examples per device x 4 accumulation steps per optimizer update.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4, # * 1, 2, 4
    # eval_accumulation_steps=4, # ???
)

In [None]:
# Trainer wiring together the LoRA model, the datasets, the collator and the BLEU metric.
trainer = Seq2SeqTrainer(
    lora_model,
    args=train_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=valid_data,
    compute_metrics=compute_metrics
)

# Resume from the newest checkpoint in the output directory when there is one,
# otherwise start from scratch. Passing `resume_from_checkpoint=True` would
# raise on a first run, when no checkpoint exists yet.
last_checkpoint = (
    get_last_checkpoint(train_args.output_dir)
    if os.path.isdir(train_args.output_dir)
    else None
)

lora_model.train()
trainer.train(resume_from_checkpoint=last_checkpoint)

In [None]:
# Keep the model on the device the trainer is configured for instead of
# hard-coding CUDA, which fails on machines without a GPU.
lora_model.to(trainer.args.device)
lora_model.eval()

# train_out = trainer.predict(train_data)
valid_out = trainer.predict(valid_data)

# print("Train: ", train_out.metrics)
print("Valid: ", valid_out.metrics)

# Predictions may be padded with -100, which is not a valid token id and cannot
# be decoded; swap it for the pad token before decoding.
# train_decode = tokenizer.batch_decode(
#     np.where(train_out.predictions != -100, train_out.predictions, tokenizer.pad_token_id),
#     skip_special_tokens=True,
# )
valid_pred_ids = np.where(
    valid_out.predictions != -100, valid_out.predictions, tokenizer.pad_token_id
)
valid_decode = tokenizer.batch_decode(valid_pred_ids, skip_special_tokens=True)

In [None]:
def print_pairs(dataset, generation, sample=5):
    """Print randomly chosen source / target / generated triples side by side.

    Args:
        dataset: sized container whose ``'source'`` and ``'target'`` entries
            are index-aligned sequences of sentences.
        generation: generated sentences, one per dataset row, in dataset order.
        sample: how many rows to show; capped at ``len(dataset)``.

    Returns:
        None. The triples are written to stdout.

    Raises:
        AssertionError: if ``dataset`` and ``generation`` differ in length.
    """
    assert len(dataset) == len(generation), "Invalid combination!"

    # Fetch each column once instead of on every loop iteration.
    sources = dataset['source']
    targets = dataset['target']

    # random.sample raises ValueError when asked for more rows than exist.
    sample = min(sample, len(dataset))
    sample_ids = random.sample(range(len(dataset)), sample)
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            f"\tOriginal:  {sources[sid]}\n"
            f"\tTarget:    {targets[sid]}\n"
            f"\tGenerated: {generation[sid]}\n"
        )

# Show three randomly chosen validation rows next to the decoded model output.
print_pairs(valid_data, valid_decode, sample=3)