In [1]:
import random
import os
# Point the Hugging Face cache at a project-local directory. This is set before
# the `transformers` import below — presumably so the library picks the
# location up at import time; TODO confirm.
os.environ["HF_HOME"] = r"./.cache"

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    GenerationConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model, TaskType
# Uncomment the next two lines for verbose transformers logging.
# from transformers.utils import logging
# logging.set_verbosity_info()
import evaluate

In [2]:
# Translation direction: "en" means English -> Japanese; any other value
# selects Japanese -> English.
SOURCE_LANG = "en"
TARGET_LANG = "ja" if SOURCE_LANG == "en" else "en"

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# The tokenizer is told both language codes for the chosen direction.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50",
    src_lang="en_XX" if SOURCE_LANG == "en" else "ja_XX",
    tgt_lang=f"{TARGET_LANG}_XX",
)

In [3]:
def print_trainable_parameters(model):
    """
    Prints the total and trainable parameter counts of the model.

    For both groups the element count and the memory footprint in MB
    (element count * element size) are reported, followed by the share of
    trainable parameters in percent.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where ``param`` provides ``nelement()``,
            ``element_size()`` and ``requires_grad``.
    """
    trainable_count = trainable_bytes = 0
    total_count = total_bytes = 0
    for _, param in model.named_parameters():
        count = param.nelement()
        size = count * param.element_size()
        total_count += count
        total_bytes += size
        if param.requires_grad:
            trainable_count += count
            trainable_bytes += size

    # A model without parameters would otherwise divide by zero.
    trainable_pct = 100 * trainable_count / total_count if total_count else 0.0
    print(
        f"Total params: {total_count:12,} ({(total_bytes / 1024**2):7,.1f}MB) | "
        f"Trainable params: {trainable_count:12,} ({(trainable_bytes / 1024**2):7,.1f}MB) [{trainable_pct:3.1f}%]"
    )

In [4]:
# Modules listed here are handed to LoraConfig as `modules_to_save`.
modules_to_save = [
    "final_layer_norm",
    "self_attn_layer_norm",
    "layer_norm",
    "layernorm_embedding",
    "embed_positions",
]

# Modules that receive LoRA adapters.
target_modules = [
    "q_proj", "k_proj", "v_proj", "out_proj",  # attention projections
    "fc1", "fc2",                              # feed-forward layers
    "shared", "lm_head",                       # embedding / output head
]

config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=8,  # anything is fine (simply tune lr)
    lora_dropout=0.1,
    target_modules=target_modules,
    modules_to_save=modules_to_save,  # may modify the base model
    bias="all",  # may modify the base model ("lora_only" or "all")
)

lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

Total params:  621,454,432 (2,370.7MB) | Trainable params:   13,077,600 (   49.9MB) [2.1%]


In [5]:
from utils.dataset import EnJaDatasetMaker, EnJaDatasetSample, WikiCorpus
# Build (or load from cache) a 1000-sample dataset drawn from WikiCorpus.
dataset = EnJaDatasetMaker.prepare_dataset(
    "wiki-corpus-test-1",
    [
        # lower is inclusive, upper is exclusive (0, 32) -> [0, 31]
        EnJaDatasetSample(WikiCorpus, 1000, (64, 128)),
    ],
    source_language=SOURCE_LANG,
    model_type="mBART",
    tokenizer=tokenizer,
    num_proc=8,
    seed=42
)

# 70/30 split by position: first 700 rows for training, remaining 300 for validation.
train_data = dataset.select(range(700))
valid_data = dataset.select(range(700, 1000))

# Bind the collator to the same PEFT-wrapped model that is handed to the
# trainer, instead of the underlying base model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)

skipped: loaded dataset with id="wiki-corpus-test-1" from existing cache.


In [6]:
metric = evaluate.load("sacrebleu")

# Japanese output is scored with sacreBLEU's MeCab tokenizer; every other
# target language uses the metric's default tokenization.
_BLEU_KWARGS = {"tokenize": "ja-mecab"} if TARGET_LANG == "ja" else {}


def _decode_ids(token_ids):
    """Decode a batch of token ids to text, ignoring ``-100`` padding.

    ``-100`` is not a valid vocabulary id and cannot be decoded, so it is
    swapped for the pad token on a copy (the caller's array is left untouched).
    Special tokens are stripped from the decoded text.
    """
    token_ids = token_ids.copy()
    token_ids[token_ids == -100] = tokenizer.pad_token_id
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)


def compute_metrics(preds):
    """Compute sacreBLEU for one evaluation pass.

    Args:
        preds: pair ``(prediction_ids, label_ids)`` as passed by the trainer.
            ``prediction_ids`` may itself be a tuple, in which case its first
            element is used.

    Returns:
        The dictionary returned by ``metric.compute``.
    """
    preds_ids, labels_ids = preds
    if isinstance(preds_ids, tuple):
        preds_ids = preds_ids[0]

    predictions = _decode_ids(preds_ids)
    # sacreBLEU expects a list of reference lists (one list per prediction).
    references = [[reference] for reference in _decode_ids(labels_ids)]

    return metric.compute(
        references=references,
        predictions=predictions,
        **_BLEU_KWARGS,
    )

In [7]:
MAX_LENGTH = 256
MAX_LENGHT = MAX_LENGTH  # misspelled legacy name, kept so existing references keep working


def set_decoder_configuration(gc: GenerationConfig):
    """Fill a generation config with the beam-search settings used for evaluation.

    The special-token ids are taken from the module-level ``tokenizer``.
    Because this config is built from scratch, the mBART-50 specific ids that
    normally come with the pretrained checkpoint are set explicitly:
    decoding starts from EOS and the first generated token is forced to the
    tokenizer's target-language code.

    Args:
        gc: the ``GenerationConfig`` to modify in place.

    Returns:
        The same ``gc`` object, for convenience.
    """
    gc.no_repeat_ngram_size = 4
    gc.length_penalty = 2.0
    gc.num_beams = 3
    gc.max_length = MAX_LENGTH
    gc.min_length = 0
    gc.early_stopping = True

    # mBART has a dedicated pad token, so use it rather than reusing EOS.
    gc.pad_token_id = tokenizer.pad_token_id
    gc.bos_token_id = tokenizer.bos_token_id
    gc.eos_token_id = tokenizer.eos_token_id

    # mBART-50 targets look like "[tgt_lang_code] X [eos]" and the decoder is
    # started from EOS; without these two settings generation would start from
    # BOS and would not be steered to the target language.
    gc.decoder_start_token_id = tokenizer.eos_token_id
    gc.forced_bos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang)
    return gc

gen_config = GenerationConfig()
gen_config = set_decoder_configuration(gen_config)

In [8]:
train_args = Seq2SeqTrainingArguments(
    # --- run identity / output ---
    run_name="testing-lora-1",
    report_to="wandb",
    output_dir="./.ckp/",

    # --- schedule and optimisation ---
    num_train_epochs=4,
    optim="adamw_torch",
    bf16=True,
    label_smoothing_factor=0.1,
    gradient_accumulation_steps=2,
    auto_find_batch_size=True,  # batch size is searched instead of being fixed here

    # --- batching ---
    group_by_length=True,
    length_column_name="length",

    # --- logging ---
    logging_strategy="steps",
    logging_steps=5,

    # --- evaluation (runs generation with the config built above) ---
    evaluation_strategy="steps",
    eval_steps=50,
    predict_with_generate=True,
    generation_config=gen_config,

    # --- checkpointing ---
    # NOTE(review): save_steps is far larger than the 176 steps shown in the
    # recorded progress bar — confirm that skipping step checkpoints is intended.
    save_strategy="steps",
    save_steps=10_000,
    save_total_limit=4,
)

In [9]:
# Wire the LoRA-wrapped model, data splits, collator and BLEU metric together.
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=train_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [10]:
# Switch the LoRA-wrapped model to training mode, then run the fine-tuning loop
# configured by `train_args`.
lora_model.train()
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavidboening[0m ([33mdandd[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/176 [00:00<?, ?it/s]

You're using a MBart50TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [None]:
# Move the fine-tuned model to the GPU and switch it to evaluation mode.
lora_model.cuda()
lora_model.eval()

# Predict on the training split first, then on the validation split.
train_out, valid_out = (trainer.predict(split) for split in (train_data, valid_data))

for tag, result in (("Train:", train_out), ("Valid:", valid_out)):
    print(tag, result.metrics)

In [None]:
def _predictions_to_text(prediction_ids):
    """Decode predicted token ids to text, ignoring ``-100`` padding.

    ``-100`` is not a valid vocabulary id and cannot be decoded, so it is
    replaced with the pad token on a copy before decoding.
    """
    prediction_ids = prediction_ids.copy()
    prediction_ids[prediction_ids == -100] = tokenizer.pad_token_id
    return tokenizer.batch_decode(prediction_ids, skip_special_tokens=True)


train_decode = _predictions_to_text(train_out.predictions)
valid_decode = _predictions_to_text(valid_out.predictions)

In [None]:
def print_pairs(dataset, generation, sample=5):
    """Print randomly chosen source / target / generated triples.

    Args:
        dataset: object supporting ``len()`` and column access via
            ``dataset['source']`` and ``dataset['target']``.
        generation: sequence of generated texts, aligned with ``dataset``.
        sample: how many examples to print; capped at ``len(dataset)`` so a
            small dataset does not raise.

    Raises:
        AssertionError: if ``dataset`` and ``generation`` differ in length.
    """
    assert len(dataset) == len(generation), "Invalid combination!"

    # Look each column up once instead of once per printed field.
    sources, targets = dataset['source'], dataset['target']

    sample_ids = random.sample(range(len(dataset)), min(sample, len(dataset)))
    for i, sid in enumerate(sample_ids):
        print(f"Sentence #{i} [id={sid}]")
        print(
            f"\tOriginal:  {sources[sid]}\n"
            f"\tTarget:    {targets[sid]}\n"
            f"\tGenerated: {generation[sid]}\n"
        )

# Spot-check three random validation examples against their generated output.
print_pairs(valid_data, valid_decode, sample=3)