In [None]:
import random
import os
os.environ["HF_HOME"] = r"./.cache"

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from transformers.utils import logging
# logging.set_verbosity_info()
import evaluate

In [None]:
# Translation direction. Only the source language is chosen by hand; any value
# other than "en" selects Japanese -> English.
source_lng = "ja"
target_lng = "ja" if source_lng == "en" else "en"

# mBART-50 language codes handed to the tokenizer for the chosen direction.
_src_code, _tgt_code = ("en_XX", "ja_XX") if source_lng == "en" else ("ja_XX", "en_XX")

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50",
    src_lang=_src_code,
    tgt_lang=_tgt_code,
)

In [None]:
def print_trainable_parameters(model):
    """
    Print how many parameters the model has and how many of them are trainable.

    One line is written to stdout with the total and trainable parameter
    counts, their sizes in MB (bytes / 1024**2) and the trainable share in
    percent.

    Args:
        model: any object exposing ``named_parameters()`` whose parameters
            provide ``nelement()``, ``element_size()`` and ``requires_grad``.
    """
    trainable_count, trainable_bytes = 0, 0
    total_count, total_bytes = 0, 0
    for _, param in model.named_parameters():
        count = param.nelement()
        size_in_bytes = count * param.element_size()
        total_count += count
        total_bytes += size_in_bytes
        if param.requires_grad:
            trainable_count += count
            trainable_bytes += size_in_bytes

    # A model without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_count / total_count if total_count else 0.0
    print(
        f"Total params: {total_count:12,} ({(total_bytes / 1024**2):7,.1f}MB) | "
        f"Trainable params: {trainable_count:12,} ({(trainable_bytes / 1024**2):7,.1f}MB) [{trainable_pct:3.1f}%]"
    )

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# Sub-modules passed to PEFT as `modules_to_save` (kept trainable outside the
# LoRA adapters).
modules_to_save = [
    "final_layer_norm",
    "self_attn_layer_norm",
    "layer_norm",
    "layernorm_embedding",
    "embed_positions",
]

# Sub-modules that receive LoRA adapters.
target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "out_proj",
    "fc1",
    "fc2",
    "shared",
]

# NOTE(review): `TaskType` is imported but no `task_type` is passed here;
# confirm whether `TaskType.SEQ_2_SEQ_LM` was intended.
config = LoraConfig(
    bias="lora_only",
    lora_alpha=16,
    lora_dropout=0.1,
    modules_to_save=modules_to_save,
    r=16,
    target_modules=target_modules,
)

# Wrap the base model with the adapters and report how much of it will train.
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

# Batches are assembled by the seq2seq collator bound to the wrapped model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

In [None]:
from utils.dataset import SnowSimplified

# Number of leading rows used for training; everything after is validation.
_train_rows = 50_000

# Map the language-specific sentence columns onto direction-neutral names.
_column_names = {
    f"{source_lng}_sentence": "source",
    f"{target_lng}_sentence": "target",
}
data = SnowSimplified.load().rename_columns(_column_names)

train_data = data.select(range(_train_rows))
valid_data = data.select(range(_train_rows, len(data)))

In [None]:
def compute_tokenization(sample):
    """
    Tokenize one source/target pair and attach the model inputs to the sample.

    Adds the columns ``length`` (size of the second dimension of the encoded
    source), ``input_ids``, ``attention_mask`` and ``labels`` (each flattened
    to one dimension) and returns the same sample object.
    """
    encoded = tokenizer(
        sample["source"],
        text_target=sample["target"],
        return_tensors="pt",
    )
    new_columns = {
        "length": encoded.input_ids.shape[1],
        "input_ids": encoded.input_ids.flatten(),
        "attention_mask": encoded.attention_mask.flatten(),
        "labels": encoded.labels.flatten(),
    }
    for column, value in new_columns.items():
        sample[column] = value
    return sample


# Tokenize both splits in a single process.
train_data = train_data.map(compute_tokenization, num_proc=1)
valid_data = valid_data.map(compute_tokenization, num_proc=1)

In [None]:
metric = evaluate.load("sacrebleu")

def compute_metrics(preds):
    """
    Score generated translations against the references with SacreBLEU.

    Args:
        preds: a ``(predictions, label_ids)`` pair of token-id arrays, as
            passed by the trainer to ``compute_metrics``.

    Returns:
        The dictionary produced by ``metric.compute``. When the target
        language is Japanese the ``ja-mecab`` tokenizer is requested.
    """
    preds_ids, labels_ids = preds
    # Some models hand back a tuple of outputs; the token ids come first.
    if isinstance(preds_ids, tuple):
        preds_ids = preds_ids[0]

    # -100 is the ignore/padding index. It can appear in the labels and also
    # in predictions that were padded to a common length across batches, and
    # it is not a valid token id, so replace it before decoding. Work on
    # copies so the caller's arrays are left untouched.
    pad_id = tokenizer.pad_token_id
    preds_ids = preds_ids.copy()
    preds_ids[preds_ids == -100] = pad_id
    labels_ids = labels_ids.copy()
    labels_ids[labels_ids == -100] = pad_id

    predictions = tokenizer.batch_decode(preds_ids, skip_special_tokens=True)
    # SacreBLEU expects a list of references per prediction.
    references = [
        [reference]
        for reference in tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    ]

    # Japanese needs a dedicated tokenizer; otherwise use the metric's default.
    extra_args = {"tokenize": "ja-mecab"} if target_lng == "ja" else {}
    return metric.compute(references=references, predictions=predictions, **extra_args)

In [None]:
MAX_LENGTH = 128
# Misspelled original name, kept so existing references keep working.
MAX_LENGHT = MAX_LENGTH


def set_decoder_configuration(gc: GenerationConfig):
    """
    Fill a generation config with the beam-search settings used for evaluation.

    Args:
        gc: the ``GenerationConfig`` to update in place.

    Returns:
        The same ``gc`` object, for convenience.
    """
    gc.no_repeat_ngram_size = 3
    gc.length_penalty = 2.0
    gc.num_beams = 3
    gc.max_length = MAX_LENGTH * 2
    gc.min_length = 0
    gc.early_stopping = True
    gc.pad_token_id = tokenizer.eos_token_id
    gc.bos_token_id = tokenizer.bos_token_id
    gc.eos_token_id = tokenizer.eos_token_id

    # A freshly built GenerationConfig carries none of the checkpoint's
    # defaults. mBART-50 decodes starting from EOS and expects the target
    # language code as the first generated token, so set both explicitly.
    gc.decoder_start_token_id = tokenizer.eos_token_id
    tgt_lang = getattr(tokenizer, "tgt_lang", None)
    if tgt_lang is not None:
        gc.forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    return gc


gen_config = GenerationConfig()
gen_config = set_decoder_configuration(gen_config)

In [None]:
train_args = Seq2SeqTrainingArguments(
    # Experiment tracking.
    report_to="wandb",
    run_name="testing-lora-1",

    # Schedule and precision.
    num_train_epochs=50,
    optim="adamw_torch",
    bf16=True,
    gradient_accumulation_steps=1,
    auto_find_batch_size=True,

    # Batching: the datasets handed to the trainer are already stripped of
    # their text columns, so column pruning is switched off.
    remove_unused_columns=False,
    group_by_length=True,

    # Logging every 10 steps.
    logging_strategy="steps",
    logging_steps=10,

    # Checkpoints every 1000 steps, keeping at most 4.
    output_dir="./.ckp/",
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=4,

    # Evaluation once per epoch, scored on generated text.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    generation_config=gen_config,
)

In [None]:
# Columns that are not model inputs and are dropped before training.
_non_model_columns = ["source", "target", "length"]

trainer = Seq2SeqTrainer(
    lora_model,
    args=train_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_data.remove_columns(_non_model_columns),
    eval_dataset=valid_data.remove_columns(_non_model_columns),
)

In [None]:
# Switch the adapter-wrapped model to training mode, then run the fine-tuning
# loop configured by `train_args`.
lora_model.train()
trainer.train()

In [None]:
# Move the model to the GPU and put it in evaluation mode before scoring.
lora_model.cuda()
lora_model.eval()

# Generate and score both splits with the text columns removed.
_text_columns = ["source", "target", "length"]
train_out = trainer.predict(train_data.remove_columns(_text_columns))
valid_out = trainer.predict(valid_data.remove_columns(_text_columns))

for _split_label, _split_out in (("Train:", train_out), ("Valid:", valid_out)):
    print(_split_label, _split_out.metrics)

In [None]:
def _decode_generated(token_ids):
    """
    Decode an array of generated token ids into a list of strings.

    Predictions padded to a common length across batches contain -100, which
    is not a valid token id; it is replaced by the pad token on a copy of the
    array before decoding.
    """
    token_ids = token_ids.copy()
    token_ids[token_ids == -100] = tokenizer.pad_token_id
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)


train_decode = _decode_generated(train_out.predictions)
valid_decode = _decode_generated(valid_out.predictions)

In [None]:
def print_pairs(dataset, generation, sample=5):
    """
    Print a few randomly chosen source / target / generated triples.

    Args:
        dataset: object with a length and ``"source"`` / ``"target"`` columns
            reachable through ``dataset[column][row]``.
        generation: generated texts, one per dataset row and in the same order.
        sample: how many rows to show; capped at the number of rows available.

    Raises:
        AssertionError: if ``dataset`` and ``generation`` differ in length.
    """
    assert len(dataset) == len(generation), "Invalid combination!"

    # Read each text column once instead of on every loop iteration.
    sources = dataset["source"]
    targets = dataset["target"]

    # random.sample raises ValueError when asked for more items than exist.
    sample_size = min(sample, len(dataset))
    sample_ids = random.sample(range(len(dataset)), sample_size)
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            f"\tOriginal:  {sources[sid]}\n"
            f"\tTarget:    {targets[sid]}\n"
            f"\tGenerated: {generation[sid]}\n"
        )

# Show three random validation examples next to their generated output.
print_pairs(valid_data, valid_decode, sample=3)