Translation converts a sequence of text from one language to another. It is one of several tasks you can formulate as a sequence-to-sequence problem - a powerful framework for returning some output from an input, like translation or summarization. Translation systems are commonly used to translate text between languages, but they can also be used for speech or some combination in between, like text-to-speech or speech-to-text.

This guide shows how to:
1. Finetune T5 on the English-French subset of the OPUS Books dataset to translate English text to French.
2. Use the finetuned model for inference.

# Libraries

In [None]:
pip install transformers datasets evaluate sacrebleu

In [None]:
import torch
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Prefer the Apple-silicon GPU (MPS), then CUDA, and fall back to the CPU so
# the notebook still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
elif torch.cuda.is_available():
    mps_device = torch.device("cuda")
else:
    mps_device = torch.device("cpu")

# Data Load

In [None]:
# Download the English-French configuration of OPUS Books and hold out 20% of
# its "train" split as a test set.
books = load_dataset("opus_books", "en-fr")["train"].train_test_split(test_size=0.2)
# Show the first training example
books["train"][0]

# Preprocessing

In [None]:
# Load the tokenizer that belongs to the T5-small checkpoint; the same
# tokenizer is used for both the English inputs and the French targets.
checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:

# Some models capable of multiple NLP tasks require prompting for a specific
# task, so every input is prefixed with a prompt telling T5 to translate.
# source_lang / target_lang are the keys that select each side of a
# translation pair in preprocess_function.
source_lang = "en"
target_lang = "fr"
prefix = "translate English to French: "


def preprocess_function(examples):
    """Tokenize a batch of English-French translation pairs.

    Each English sentence is prefixed with the task prompt and used as the
    model input; the matching French sentence is passed to the tokenizer as
    ``text_target``. Sequences are truncated to at most 128 tokens.
    """
    pairs = examples["translation"]
    english_prompts = [prefix + pair[source_lang] for pair in pairs]
    french_targets = [pair[target_lang] for pair in pairs]

    return tokenizer(
        english_prompts,
        text_target=french_targets,
        max_length=128,
        truncation=True,
    )


In [None]:
# Tokenize the train and test splits; batched=True passes preprocess_function
# a batch of examples per call instead of a single example.
tokenized_books = books.map(preprocess_function, batched=True)

In [None]:
# Create a batch of examples using DataCollatorForSeq2Seq, which dynamically
# pads the sentences to the longest length in a batch during collation.
# NOTE(review): `model` receives the checkpoint name (a string), not the model
# object that is loaded in the Training section - confirm this is intended.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

# Training

In [None]:
# Load the pretrained sequence-to-sequence model from the same checkpoint as
# the tokenizer and move it onto the device created in the Libraries section.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model.to(mps_device)

In [None]:
# Define training hyperparameters in Seq2SeqTrainingArguments.
# The only required parameter is output_dir, which specifies where to save the model.
# NB: set push_to_hub=True (and sign in to Hugging Face) to upload the model.
# With evaluation_strategy="epoch" the Trainer evaluates the SacreBLEU metric
# at the end of every epoch. save_strategy is not set, so checkpoints follow
# the Trainer's default save schedule; save_total_limit caps how many are kept.
training_args = Seq2SeqTrainingArguments(
    output_dir="opus_translation_model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    # fp16 mixed precision is only accepted on CUDA devices; requesting it on
    # MPS or CPU makes argument validation fail, so enable it conditionally.
    fp16=torch.cuda.is_available(),
)

# Pass the training arguments to Seq2SeqTrainer...
# along with the model, dataset, tokenizer, data collator, and compute_metrics function.
# NOTE(review): compute_metrics (and the metric it uses) are defined in the
# Evaluation section further down, so those cells must be run before this one.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_books["train"],
    eval_dataset=tokenized_books["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Call train on the trainer object to fine-tune the model
trainer.train()

# Evaluation

In [None]:
# Load the SacreBLEU metric; compute_metrics calls metric.compute() to score
# the decoded predictions against the decoded references.
metric = evaluate.load("sacrebleu")

In [None]:
# create a function that passes model predictions and labels to compute() to calculate the SacreBLEU score
def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Returns a ``(predictions, references)`` tuple in which every reference is
    wrapped in its own single-element list, i.e. one reference per prediction.
    """
    return (
        [text.strip() for text in preds],
        [[reference.strip()] for reference in labels],
    )


def compute_metrics(eval_preds):
    """Compute SacreBLEU and the mean generated length for an evaluation run.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may be a tuple, in which case its first element
            is used as the generated token ids.

    Returns:
        A dict with ``"bleu"`` (the SacreBLEU score) and ``"gen_len"`` (mean
        number of non-padding tokens per prediction), rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is the ignore index used as padding; the tokenizer cannot decode
    # it, so swap it for the real pad token in predictions as well as labels.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Generated length counts the tokens that are not padding in each prediction.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result