In [None]:
import os
os.environ["HF_HOME"] = r"./.cache"
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, \
    Seq2SeqTrainer, DataCollatorForSeq2Seq, GenerationConfig, Seq2SeqTrainingArguments
from peft import PeftModel
from datasets import Dataset
from utils.dataset import EnJaDatasetMaker, EnJaBackTranslation
from utils.metric import SacreBleu

In [None]:
# Translation direction of this run: "en" (en -> ja) or "ja" (ja -> en).
SOURCE_LANG = "en"

# source language -> (target language, mBART-50 source code, mBART-50 target code)
_LANG_PAIRS = {
    "en": ("ja", "en_XX", "ja_XX"),
    "ja": ("en", "ja_XX", "en_XX"),
}

# Fail fast on a mistyped language instead of silently falling back to ja -> en;
# this runs before the (expensive) model load on purpose.
if SOURCE_LANG not in _LANG_PAIRS:
    raise ValueError(f"SOURCE_LANG must be 'en' or 'ja', got {SOURCE_LANG!r}")

TARGET_LANG, _SRC_CODE, _TGT_CODE = _LANG_PAIRS[SOURCE_LANG]

# Base model; the adapter checkpoint is applied on top of it in a later cell.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# Tokenizer configured for the selected source/target language codes.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50", src_lang=_SRC_CODE, tgt_lang=_TGT_CODE
)

In [None]:
# Decoding options; passed as `gen_config` to the back-translation call further down.
gen_config = dict(
    # output length cap and beam-search stopping
    max_length=256,
    early_stopping=True,
    # repetition / length control
    no_repeat_ngram_size=4,
    length_penalty=1.0,
    # search width
    num_beams=5,
    # Alternative decoding options, currently disabled:
    #   num_beam_groups=5, diversity_penalty=0.5,
    #   do_sample=True, penalty_alpha=0.6, top_k=4
)

# Arguments for the Seq2SeqTrainer, which this notebook only uses as a
# prediction/generation wrapper.
train_args = Seq2SeqTrainingArguments(
    # --- output / reporting ---
    output_dir="./ckp",
    report_to="none",
    # --- prediction behaviour ---
    predict_with_generate=True,
    prediction_loss_only=False,
    per_device_eval_batch_size=4,
    # --- numerics ---
    bf16=True,
    label_smoothing_factor=0.2,
    # --- batching by the dataset's "length" column ---
    group_by_length=True,
    length_column_name="length",
)

In [None]:
DATASET_NAME = "mixed-250k+bt-250k"
CHECKPOINT = 25_000

# "test" split of the dataset for the current translation direction.
test_split = EnJaDatasetMaker.load_dataset(f"{SOURCE_LANG}-{TARGET_LANG}-{DATASET_NAME}")["test"]

# A running row index gives every example a stable identity, so rows of equal
# length keep a consistent order after sorting.
row_ids = [*range(len(test_split))]

# Order by length (ties broken by id) so that neighbouring rows need little
# padding when they are batched together.
data: Dataset = test_split.add_column("id", row_ids).sort(column_names=["length", "id"])

# Directory of the adapter checkpoint to evaluate.
# NOTE(review): this reads from "./.ckp" while train_args.output_dir is "./ckp" —
# confirm the differing directory names are intended.
adapter_dir = f"./.ckp/{SOURCE_LANG}-{TARGET_LANG}-{DATASET_NAME}/checkpoint-{CHECKPOINT}"

# Wrap the base model with the saved adapter weights.
lora_model = PeftModel.from_pretrained(model, adapter_dir)

# Collator and BLEU metric are built for the adapted model / target language.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=lora_model)
metric = SacreBleu.get_mBART_metric(tokenizer=tokenizer, target_language=TARGET_LANG)

# The trainer only serves as a convenient prediction/generation driver here;
# `compute_metrics=metric` is left disabled.
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=train_args,
    data_collator=data_collator,
)

In [None]:
# Generate back-translations for `data` with the adapted model. The output file
# is named target-source (reversed direction), which matches the file name the
# dataset-combination cells read for the opposite direction.
bt_file_name = f"{TARGET_LANG}-{SOURCE_LANG}-ckp-{CHECKPOINT}-bt.csv"

EnJaBackTranslation.create_mBART_backtranslation(
    trainer,
    data,
    SOURCE_LANG,
    tokenizer,
    gen_config=gen_config,
    chunk_size=1_000,
    out_dir="./data-bt",
    out_name=bt_file_name,
)

### Combine datasets: train + back-translation (BT)

In [1]:
import os
os.environ["HF_HOME"] = r"./.cache"
from utils.dataset import EnJaDatasetSample, EnJaDatasetMaker
from transformers import MBart50TokenizerFast
from datasets import concatenate_datasets, DatasetDict

In [2]:
DATASET_NAME = "mixed-250k+bt-250k"
CHECKPOINT = 25_000
# Source language of the dataset being assembled: "en" (en -> ja) or "ja" (ja -> en).
SOURCE_LANG = "ja"

# source language -> (target language, mBART-50 source code, mBART-50 target code)
_LANG_PAIRS = {
    "en": ("ja", "en_XX", "ja_XX"),
    "ja": ("en", "ja_XX", "en_XX"),
}

# Fail fast on a mistyped language instead of silently falling back to ja -> en.
if SOURCE_LANG not in _LANG_PAIRS:
    raise ValueError(f"SOURCE_LANG must be 'en' or 'ja', got {SOURCE_LANG!r}")

TARGET_LANG, _SRC_CODE, _TGT_CODE = _LANG_PAIRS[SOURCE_LANG]

# Tokenizer configured for the selected source/target language codes.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50", src_lang=_SRC_CODE, tgt_lang=_TGT_CODE
)

In [3]:
# Prefix shared by every dataset / file name of this translation direction.
lang_pair = f"{SOURCE_LANG}-{TARGET_LANG}"

# Existing dataset for this direction.
tr_data = EnJaDatasetMaker.load_dataset(f"{lang_pair}-{DATASET_NAME}")

# Back-translated CSV written by the generation step for the chosen checkpoint.
bt_sample = EnJaDatasetSample(
    dataset=f"./data-bt/{lang_pair}-ckp-{CHECKPOINT}-bt.csv",
    nsample=300_000,
    ntokens=(0, 128),
)

# Build a dataset that contains only the back-translated pairs.
bt_data = EnJaDatasetMaker.prepare_dataset(
    f"{lang_pair}-bt-only",
    [bt_sample],
    source_language=SOURCE_LANG,
    model_type="mBART",
    tokenizer=tokenizer,
    num_proc=8,
    seed=123,
    splits=(1, 0.002),  # rescaled to 1
)

Map (num_proc=8):   0%|          | 0/273994 [00:00<?, ? examples/s]

Filter (num_proc=8):   0%|          | 0/273994 [00:00<?, ? examples/s]

sampling: using all data (273970)


Saving the dataset (0/1 shards):   0%|          | 0/273423 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/547 [00:00<?, ? examples/s]

In [4]:
# Merge the original and back-translated data split by split:
# train + train for training, valid + test for validation.
full_tr_data, full_va_data = (
    concatenate_datasets([tr_data[tr_split], bt_data[bt_split]])
    for tr_split, bt_split in (("train", "train"), ("valid", "test"))
)

In [5]:
# Destination of the combined (train + back-translation) dataset.
combined_dir = f"./data-fin/{SOURCE_LANG}-{TARGET_LANG}-ckp-{CHECKPOINT}-bt-500k"

# Bundle both splits and shuffle them with a fixed seed (42) before saving.
full_data = DatasetDict({"train": full_tr_data, "valid": full_va_data}).shuffle(42)
full_data.save_to_disk(combined_dir)

Saving the dataset (0/1 shards):   0%|          | 0/546409 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1088 [00:00<?, ? examples/s]