In [None]:
from datasets import load_dataset

In [2]:
# Load the "zh-en" configuration of the "wmt/wmt19" dataset.
# NOTE(review): trust_remote_code=True allows code shipped with the dataset
# repository to run locally — confirm it is still required for this dataset.
raw_datasets = load_dataset("wmt/wmt19", "zh-en", trust_remote_code=True)

In [None]:
# Show the loaded dataset object.
raw_datasets

In [4]:
# Keep only the leading rows of each split (1000 for training, 100 for
# validation) so the rest of the notebook works on a small sample.
for split_name, n_rows in (("train", 1000), ("validation", 100)):
    raw_datasets[split_name] = raw_datasets[split_name].select(range(n_rows))

In [None]:
# Show the dataset again now that the train/validation splits were cut down.
raw_datasets

In [None]:
# Look at the first training example.
raw_datasets["train"][0]

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer published with the Helsinki-NLP/opus-mt-zh-en checkpoint.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")

In [None]:
# Show the source and target language attributes exposed by the tokenizer.
tokenizer.source_lang, tokenizer.target_lang

In [None]:
# Show the tokenizer's repr.
tokenizer

In [11]:
def tokenize_fn(examples):
    """Tokenize a batch of zh→en translation pairs.

    Args:
        examples: A batch whose ``"translation"`` entry is a list of dicts
            holding the source text under ``"zh"`` and the target text under
            ``"en"``.

    Returns:
        Whatever the module-level ``tokenizer`` returns when called on the
        Chinese texts with the English texts passed as ``text_target``.
    """
    pairs = examples['translation']
    inputs = [pair['zh'] for pair in pairs]
    labels = [pair['en'] for pair in pairs]

    # Request truncation explicitly: passing max_length on its own leaves the
    # cut-off to the tokenizer's implicit fallback (and triggers a warning).
    return tokenizer(
        inputs,
        text_target=labels,
        max_length=128,
        truncation=True,
    )


In [None]:
# Tokenize every split in batches; the columns of the raw training split are
# dropped from the result.
tokenized_datasets = raw_datasets.map(
    function=tokenize_fn,
    remove_columns=raw_datasets["train"].column_names,
    batched=True,
)

In [13]:
from transformers import AutoModelForSeq2SeqLM

# Load the seq2seq model from the same checkpoint name used for the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")

In [14]:
from transformers import DataCollatorForSeq2Seq

# Build the seq2seq collator from the tokenizer; the model is handed to it too.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [15]:
# Collate the tokenized training examples at indices 1 and 2 into one batch.
batch = data_collator([tokenized_datasets["train"][i] for i in range(1, 3)])

In [None]:
# Show the collated batch.
batch

In [None]:
# List the fields present in the collated batch.
batch.keys()

In [None]:
# Decode the input ids of the first example in the batch back to text.
tokenizer.decode(batch["input_ids"][0])

In [None]:
# Decode the label ids of the first example in the batch back to text.
tokenizer.decode(batch["labels"][0])

In [None]:
# Decode the decoder input ids of the first example in the batch back to text.
tokenizer.decode(batch["decoder_input_ids"][0])

In [None]:
import evaluate

# Load the "sacrebleu" metric through the evaluate library.
metric = evaluate.load("sacrebleu")

In [None]:
# Try the metric on one hand-written pair. Predictions are a flat list of
# strings; references are a list holding one list of strings per prediction.
predictions = ["This plugin lets you translate web pages between several languages automatically."]
references = [["This plugin allows you to automatically translate web pages between several languages."]]
metric.compute(predictions=predictions, references=references)

In [24]:
import numpy as np


def compute_metrics(eval_preds):
    """Score generated translations against the references.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays. When
            ``predictions`` is a tuple, only its first element is used.

    Returns:
        A dict ``{"bleu": score}`` where ``score`` is the ``"score"`` entry of
        the result returned by the module-level ``metric``.
    """
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a placeholder, not a vocabulary id, so it cannot be decoded.
    # Labels carry it, and predictions may too (presumably when batches of
    # different lengths are padded together) — replace it in both arrays.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing: trim whitespace, and wrap each reference in
    # its own list because the metric is given one list of references per
    # prediction.
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

In [28]:
import torch
from transformers import Seq2SeqTrainingArguments

# Training configuration for the fine-tuning run.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
args = Seq2SeqTrainingArguments(
    "finetuned-zh-to-en",  # plain string: there was nothing for an f-string to format
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    # Mixed precision is only requested when a CUDA device is present, so the
    # cell also runs on CPU-only machines instead of raising.
    fp16=torch.cuda.is_available(),
)

In [29]:
from transformers import Seq2SeqTrainer

# Assemble the trainer from the model, the training arguments, the tokenized
# train/validation splits, the collator, the tokenizer and the metric callback.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Evaluate the model before fine-tuning to get a starting score.
trainer.evaluate(max_length=128)

In [None]:
# Run the fine-tuning loop with the trainer configured above.
trainer.train()