In [None]:
!pip install evaluate sacrebleu

In [None]:
from transformers import (AutoTokenizer, DataCollatorForSeq2Seq,
AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer)
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import notebook_login

import numpy as np
import evaluate

In [None]:
# Authenticate with the Hugging Face Hub up front: the trainers below are created
# with push_to_hub=True and trainer.push_to_hub() is called after each run.
notebook_login()

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of translation pairs into model inputs and labels.

    Reads the module-level PREFIX, SRC_LANG, TGT_LANG, MAX_LENGTH and tokenizer.

    Args:
        examples: batch with a "translation" column; each entry is a mapping
            that holds the source text under SRC_LANG and the target text
            under TGT_LANG.

    Returns:
        The tokenizer output for the prefixed source sentences, with the
        tokenized targets stored under "labels". Inputs and labels are padded
        or truncated to MAX_LENGTH; padded label positions are set to -100.
    """
    inputs = [PREFIX + example[SRC_LANG] for example in examples["translation"]]
    targets = [example[TGT_LANG] for example in examples["translation"]]
    model_inputs = tokenizer(inputs
                             , text_target = targets
                             , max_length = MAX_LENGTH
                             , padding = "max_length"
                             , truncation = True)

    # padding="max_length" fills the labels with pad_token_id, which the loss
    # does not ignore. Replace those positions with -100 so padding does not
    # contribute to the training loss (compute_metrics maps -100 back to the
    # pad token before decoding).
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label]
        for label in model_inputs["labels"]
    ]
    return model_inputs


def postprocess_text(preds, labels):
    """Trim whitespace and shape references the way the metric call expects.

    Args:
        preds: decoded prediction strings.
        labels: decoded reference strings, one per prediction.

    Returns:
        A tuple ``(preds, labels)``: predictions as a flat list of stripped
        strings, and references as a list of single-element lists so that
        each prediction is paired with a list of references.
    """
    return (
        [hypothesis.strip() for hypothesis in preds],
        [[reference.strip()] for reference in labels],
    )


def compute_metrics(eval_preds):
    """Score generated translations against the references.

    Reads the module-level tokenizer and metric, and calls postprocess_text.

    Args:
        eval_preds: pair ``(predictions, labels)`` of token-id arrays;
            predictions may arrive as a tuple, in which case only its first
            element is used.

    Returns:
        ``{"sacrebleu": score}`` where score is the "score" field of the
        metric result.
    """
    generated, references = eval_preds
    if isinstance(generated, tuple):
        generated = generated[0]

    def _to_text(token_ids):
        # -100 marks ignored positions and cannot be decoded; swap it for the
        # pad token, which skip_special_tokens then drops.
        decodable = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(decodable, skip_special_tokens=True)

    hypotheses, targets = postprocess_text(_to_text(generated), _to_text(references))
    bleu = metric.compute(predictions=hypotheses, references=targets)
    return {"sacrebleu": bleu["score"]}

In [None]:
# Experiment configuration.
MODEL_CHECKPOINT = "t5-base"  # pretrained checkpoint for both tokenizer and model
SRC_LANG = "en"  # key of the source sentence inside each "translation" record
TGT_LANG = "de"  # key of the target sentence inside each "translation" record
PREFIX = "translate English to German: "  # prepended to every source sentence in preprocess_function
MAX_LENGTH = 128  # tokenizer max_length and the model's generation max_new_tokens
NUM_EPOCHS = 5  # num_train_epochs and save_total_limit in get_trainer

In [None]:
# Build the held-out test set by concatenating the "dev" and "devtest" splits of
# the FLORES English-German configuration into {"en": ..., "de": ...} records.
flores = load_dataset("facebook/flores", "eng_Latn-deu_Latn")
src_sentences = flores["dev"]["sentence_eng_Latn"] + flores["devtest"]["sentence_eng_Latn"]
tgt_sentences = flores["dev"]["sentence_deu_Latn"] + flores["devtest"]["sentence_deu_Latn"]
pairs = [{"en": src, "de": tgt} for src, tgt in zip(src_sentences, tgt_sentences)]
num_test_rows = flores["dev"].num_rows + flores["devtest"].num_rows

test_data = Dataset.from_dict({"id": list(range(num_test_rows)), "translation": pairs})

# Only test_data is needed from here on; drop the temporaries.
del flores, src_sentences, tgt_sentences, pairs, num_test_rows

In [None]:
# Synthetic parallel corpus used as the training split of the second run.
syn_train_data = load_dataset("jaymanvirk/synthetic_parallel_corpora")
# Its row count sets how many OPUS Books pairs are sampled, so both runs train
# on the same number of examples.
NUM_SAMPLES = syn_train_data["train"].num_rows

In [None]:
# Take a fixed-seed shuffle of the OPUS Books train split and keep the first
# NUM_SAMPLES pairs.
opus_train_data = (
    load_dataset("opus_books", lang1=TGT_LANG, lang2=SRC_LANG)["train"]
    .shuffle(seed=0)
    .select(range(NUM_SAMPLES))
)

In [None]:
# First run: train on the OPUS Books sample, evaluate on the FLORES-derived test set.
split_data = DatasetDict({"train": opus_train_data, "test": test_data})

In [None]:
# Tokenizer matching MODEL_CHECKPOINT; used by preprocess_function and compute_metrics.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# Load the pretrained seq2seq model and allow generation to emit up to
# MAX_LENGTH new tokens.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
model.generation_config.max_new_tokens = MAX_LENGTH

In [None]:
# Tokenize the train and test splits in batches with preprocess_function.
token_data = split_data.map(preprocess_function, batched=True)

In [None]:
# Batch collator for seq2seq training. Pass the model object, not the checkpoint
# name: the collator only uses `model` to call prepare_decoder_input_ids_from_labels,
# which a plain string does not provide, so a string is silently ignored.
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer
                                       , model = model)

In [None]:
# sacreBLEU scorer used by compute_metrics.
metric = evaluate.load("sacrebleu")

In [None]:
def get_trainer(
    output_dir=None,
    model=None,
    token_data=None,
    tokenizer=None,
    data_collator=None,
    compute_metrics=None,
):
    """Assemble a Seq2SeqTrainer with this notebook's fixed hyperparameters.

    Args:
        output_dir: passed to Seq2SeqTrainingArguments as output_dir.
        model: model to train and evaluate.
        token_data: mapping with tokenized "train" and "test" splits; "train"
            becomes the training set and "test" the evaluation set.
        tokenizer: tokenizer handed to the trainer.
        data_collator: batch collator handed to the trainer.
        compute_metrics: callback that scores evaluation predictions.

    Returns:
        The configured Seq2SeqTrainer.
    """
    # NOTE(review): `evaluation_strategy` may be renamed to `eval_strategy` in
    # newer transformers releases - confirm against the installed version.
    hyperparameters = dict(
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        weight_decay=0.01,
        save_total_limit=NUM_EPOCHS,
        num_train_epochs=NUM_EPOCHS,
        predict_with_generate=True,  # evaluate on generated text so BLEU can be computed
        fp16=True,
        push_to_hub=True,
        report_to="none",
    )

    return Seq2SeqTrainer(
        model=model,
        args=Seq2SeqTrainingArguments(output_dir=output_dir, **hyperparameters),
        train_dataset=token_data["train"],
        eval_dataset=token_data["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

In [None]:
# First run: trainer for fine-tuning on the OPUS Books sample.
output_dir = "t5_base_fine_tuned_opus_books_en_de"
trainer = get_trainer(
    output_dir=output_dir,
    model=model,
    token_data=token_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Baseline: score the model on the test split before fine-tuning on OPUS Books.
trainer.evaluate()

In [None]:
# Fine-tune on the OPUS Books training split.
trainer.train()

In [None]:
# Score again on the same test split to measure the effect of fine-tuning.
trainer.evaluate()

In [None]:
# Push the OPUS Books fine-tuned model to the Hugging Face Hub, tagged as a translation model.
trainer.push_to_hub(tags="translation", commit_message="Training complete")

In [None]:
# Second run: train on the synthetic corpus, evaluate on the same test set as before.
split_data = DatasetDict({"train": syn_train_data["train"], "test": test_data})

In [None]:
# Tokenize the synthetic training split and the test split with preprocess_function.
token_data = split_data.map(preprocess_function, batched=True)

In [None]:
# Start the synthetic-corpus run from a fresh pretrained checkpoint. The `model`
# object used so far was already fine-tuned on OPUS Books by the earlier
# trainer.train() call; reusing it would measure OPUS + synthetic training
# rather than synthetic training alone.
trainer = None  # drop the previous trainer so its model and optimizer state can be garbage-collected
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
model.generation_config.max_new_tokens = MAX_LENGTH
# Bind the collator to the freshly loaded model instead of the previous one.
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model)

output_dir = "t5_base_fine_tuned_synthetic_en_de"
trainer = get_trainer(output_dir = output_dir, model = model, token_data = token_data
                , tokenizer = tokenizer, data_collator = data_collator
                , compute_metrics = compute_metrics)

In [None]:
# Score the model on the test split before training on the synthetic corpus.
trainer.evaluate()

In [None]:
# Fine-tune on the synthetic parallel corpus.
trainer.train()

In [None]:
# Score again on the same test split after training on the synthetic corpus.
trainer.evaluate()

In [None]:
# Push the synthetic-corpus fine-tuned model to the Hugging Face Hub, tagged as a translation model.
trainer.push_to_hub(tags="translation", commit_message="Training complete")