In [46]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
import numpy as np
import torch

In [3]:
# English/French sentence pairs from the KDE4 corpus; the loader returns a DatasetDict.
# NOTE(review): the loader warns that `trust_remote_code=True` will be mandatory in the
# next major `datasets` release — revisit when upgrading.
raw_datasets = load_dataset(
    "kde4",
    lang1="en",
    lang2="fr",
)
print(raw_datasets)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 210173
    })
})


In [4]:
# Hold out 20% of the single "train" split; the fixed seed keeps the partition reproducible.
split_datasets = raw_datasets["train"].train_test_split(
    train_size=0.8,
    seed=20,
)
print(split_datasets)

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 168138
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 42035
    })
})


In [5]:
# Rename the held-out split to "validation" (the name later cells use).
# Guarded so that re-running this cell does not raise KeyError once "test" is gone.
if "test" in split_datasets:
    split_datasets["validation"] = split_datasets.pop("test")

In [6]:
def flatten_translation(examples):
    """Split batched ``translation`` dicts into parallel "en" and "fr" columns.

    Used with ``Dataset.map(..., batched=True)``: ``examples["translation"]`` is a
    list of dicts keyed by language code, and the result maps each language code
    to the list of its sentences in batch order.
    """
    pairs = examples["translation"]
    return {lang: [pair[lang] for pair in pairs] for lang in ("en", "fr")}

# Replace the nested "translation" dicts (and "id") with flat "en"/"fr" string columns.
equivalent_datasets = split_datasets.map(
    flatten_translation,
    batched=True,
    remove_columns=["id", "translation"],
)
print(equivalent_datasets)

DatasetDict({
    train: Dataset({
        features: ['en', 'fr'],
        num_rows: 168138
    })
    validation: Dataset({
        features: ['en', 'fr'],
        num_rows: 42035
    })
})


In [7]:
model_checkpoint = "Helsinki-NLP/opus-mt-en-fr"
# `return_tensors` is an option of each `tokenizer(...)` call, not of loading; it was
# previously passed here and had no effect (encodings below are plain Python lists).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)



In [8]:
# Index the row first: `dataset["en"][0]` can load the entire column just to read one entry.
first_pair = equivalent_datasets["train"][0]
en_sentence = first_pair["en"]
fr_sentence = first_pair["fr"]

# Tokenize source and target together; `text_target` puts the target ids under "labels".
inputs = tokenizer(en_sentence, text_target=fr_sentence)
print(inputs)
print(tokenizer.decode(inputs["input_ids"]))
print(tokenizer.encode(en_sentence))
print(tokenizer.decode(inputs["labels"]))

{'input_ids': [1232, 13572, 7823, 9, 0], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [22181, 10691, 412, 9, 1232, 21332, 0]}
Web Shortcuts</s>
[1232, 13572, 7823, 9, 0]
Raccourcis WebComment</s>


In [17]:
max_length = 128


def preprocess_function(examples):
    """Tokenize a batch of flat "en"/"fr" columns for seq2seq training.

    Relies on the notebook-level ``tokenizer`` and ``max_length``. Sources and
    targets are encoded in a single call via ``text_target``, so the returned
    encoding carries ``input_ids``, ``attention_mask`` and ``labels``.
    """
    return tokenizer(
        examples["en"],
        text_target=examples["fr"],
        max_length=max_length,
        truncation=True,
    )

# Spot-check idea: call preprocess_function on a small slice of equivalent_datasets["train"].
# Drop the raw "en"/"fr" text so only tokenizer outputs remain.
tokenized_datasets_eq = equivalent_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=equivalent_datasets["train"].column_names,
)

Map: 100%|██████████| 168138/168138 [00:44<00:00, 3768.36 examples/s]
Map: 100%|██████████| 42035/42035 [00:11<00:00, 3803.06 examples/s]


In [16]:
max_length = 128


def preprocess_function2(examples):
    """Tokenize a batch taken straight from the raw nested ``translation`` column.

    Each element of ``examples["translation"]`` is a dict with "en" and "fr"
    entries; both sides are encoded in one call (targets via ``text_target``)
    using the notebook-level ``tokenizer`` and ``max_length``.
    """
    pairs = examples["translation"]
    return tokenizer(
        [pair["en"] for pair in pairs],
        text_target=[pair["fr"] for pair in pairs],
        max_length=max_length,
        truncation=True,
    )

# Spot-check idea: call preprocess_function2 on a small slice of split_datasets["train"].
# Drop "id" and the nested "translation" column so only tokenizer outputs remain.
tokenized_datasets = split_datasets.map(
    preprocess_function2,
    batched=True,
    remove_columns=split_datasets["train"].column_names,
)

Map: 100%|██████████| 168138/168138 [00:47<00:00, 3532.26 examples/s]
Map: 100%|██████████| 42035/42035 [00:11<00:00, 3538.98 examples/s]


In [20]:
max_length = 128


def preprocess_function3(examples):
    """Tokenize sources and targets in two separate calls, then attach labels.

    The source call yields ``input_ids``/``attention_mask``; the target-only call
    (``text_target`` with no source text) yields target ``input_ids``, which are
    stored under "labels". Uses the notebook-level ``tokenizer`` and ``max_length``.
    """
    encoded = tokenizer(examples["en"], max_length=max_length, truncation=True)
    # Target-side tokenization: with only `text_target`, the ids come back as "input_ids".
    target_encoding = tokenizer(text_target=examples["fr"], max_length=max_length, truncation=True)
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

# Drop the raw "en"/"fr" text so only tokenizer outputs remain.
tokenized_datasets_3 = equivalent_datasets.map(
    preprocess_function3,
    batched=True,
    remove_columns=equivalent_datasets["train"].column_names,
)

Map: 100%|██████████| 168138/168138 [00:44<00:00, 3802.50 examples/s]
Map: 100%|██████████| 42035/42035 [00:10<00:00, 3924.09 examples/s]


In [21]:
print(tokenized_datasets["train"][0])
print(tokenized_datasets_eq["train"][0])
print(tokenized_datasets_3["train"][0])

# Machine-check what the prints show by eye: all three preprocessing variants must
# produce identical features. Only a small prefix of each split is compared to stay cheap.
_n_check = 8
for _split in ("train", "validation"):
    _reference = tokenized_datasets[_split][:_n_check]
    assert tokenized_datasets_eq[_split][:_n_check] == _reference, f"preprocess_function differs on {_split}"
    assert tokenized_datasets_3[_split][:_n_check] == _reference, f"preprocess_function3 differs on {_split}"

{'input_ids': [1232, 13572, 7823, 9, 0], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [22181, 10691, 412, 9, 1232, 21332, 0]}
{'input_ids': [1232, 13572, 7823, 9, 0], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [22181, 10691, 412, 9, 1232, 21332, 0]}
{'input_ids': [1232, 13572, 7823, 9, 0], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [22181, 10691, 412, 9, 1232, 21332, 0]}


In [25]:
# Load the weights for the same checkpoint as the tokenizer. Passing `model` to the
# collator is what makes it emit `decoder_input_ids` (see the batch keys printed later).
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [44]:
# Collate training rows 1-3 to inspect which keys the seq2seq collator produces.
sample_features = [tokenized_datasets["train"][idx] for idx in range(1, 4)]
batch = data_collator(sample_features)
print(batch.keys())

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])


In [None]:
# The `evaluate` library loads metrics with `evaluate.load`; `load_metric` belonged to the
# older `datasets` API and does not exist on `evaluate` (it raised AttributeError here).
metric = evaluate.load("sacrebleu")

def compute_metrics(eval_preds): 
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)