In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
raw_datasets = load_dataset("kde4", lang1="en", lang2="fr")

print(raw_datasets)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 210173
    })
})


In [3]:
split_datasets = raw_datasets["train"].train_test_split(train_size=0.8, seed=20)

print(split_datasets)

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 168138
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 42035
    })
})


In [4]:
split_datasets["validation"] = split_datasets.pop("test")



In [5]:
def flatten_translation(examples):
    return {
        "en": [ex["en"] for ex in examples["translation"]],
        "fr": [ex["fr"] for ex in examples["translation"]]
    }

equivalent_datasets = split_datasets.map(flatten_translation, batched=True, remove_columns=["id", "translation"])

print(equivalent_datasets)

DatasetDict({
    train: Dataset({
        features: ['en', 'fr'],
        num_rows: 168138
    })
    validation: Dataset({
        features: ['en', 'fr'],
        num_rows: 42035
    })
})


In [6]:
model_checkpoint = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt")

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [12]:
en_sentence = equivalent_datasets["train"]["en"][0]
fr_sentence = equivalent_datasets["train"]["fr"][0]

inputs = tokenizer(en_sentence, text_target=fr_sentence)
print(inputs)
print(tokenizer.decode(inputs["input_ids"]))
print(tokenizer.encode(en_sentence))
print(tokenizer.decode(inputs["labels"]))

{'input_ids': [1232, 13572, 7823, 9, 0], 'attention_mask': [1, 1, 1, 1, 1], 'labels': [22181, 10691, 412, 9, 1232, 21332, 0]}
Web Shortcuts</s>
[1232, 13572, 7823, 9, 0]
Raccourcis WebComment</s>


In [None]:
max_length = 128
def preprocess_function(examples):
    inputs = examples["en"]
    targets = examples["fr"]
    model_inputs = tokenizer(inputs, text_targets=targets, max_length=max_length, truncation=True)
    return model_inputs

data_check = equivalent_datasets["train"][0:4]
tokenized_datasets_eq = equivalent_datasets.map(preprocess_function, batched=True, remove_columns=equivalent_datasets["train"].column_names)


Keyword arguments {'text_targets': ['Raccourcis WebComment', 'Téléchargez le depuis la section Fichiers (http: / /download. gna. org/ kvpnc/).', 'Texte %1', "K3b nécessite l'installation du programme « & #160; mkisofs & #160; » en version 1.14 (ou supérieure). Les versions antérieures posent des problèmes lors de la création de projets de données."]} not recognized.
Keyword arguments {'text_targets': ['Raccourcis WebComment', 'Téléchargez le depuis la section Fichiers (http: / /download. gna. org/ kvpnc/).', 'Texte %1', "K3b nécessite l'installation du programme « & #160; mkisofs & #160; » en version 1.14 (ou supérieure). Les versions antérieures posent des problèmes lors de la création de projets de données."]} not recognized.
Keyword arguments {'text_targets': ['Raccourcis WebComment', 'Téléchargez le depuis la section Fichiers (http: / /download. gna. org/ kvpnc/).', 'Texte %1', "K3b nécessite l'installation du programme « & #160; mkisofs & #160; » en version 1.14 (ou supérieure). L

{'input_ids': [[1232, 13572, 7823, 9, 0], [35, 723, 647, 373, 45, 928, 71, 37, 4012, 9, 37, 583, 583, 3390, 3, 49, 19015, 3, 57, 309, 74, 1013, 74, 2635, 973, 529, 364, 222, 50, 3, 0], [45629, 0], [526, 602, 226, 895, 71, 1187, 251, 5049, 9, 2368, 9, 1226, 6662, 6426, 34144, 5056, 202, 8101, 1366, 288, 4933, 499, 1013, 3, 0]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [None]:
max_length = 128
def preprocess_function2(examples):
    inputs = [ex["en"] for ex in examples["translation"]]
    targets = [ex["fr"] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, text_targets=targets, max_length=max_length, truncation=True)
    return model_inputs

data_check = split_datasets["train"][0:4]
print(preprocess_function2(data_check))
tokenized_datasets = split_datasets.map(preprocess_function2, batched=True, remove_columns=split_datasets["train"].column_names)

Keyword arguments {'text_targets': ['Raccourcis WebComment', 'Téléchargez le depuis la section Fichiers (http: / /download. gna. org/ kvpnc/).', 'Texte %1', "K3b nécessite l'installation du programme « & #160; mkisofs & #160; » en version 1.14 (ou supérieure). Les versions antérieures posent des problèmes lors de la création de projets de données."]} not recognized.
Keyword arguments {'text_targets': ['Raccourcis WebComment', 'Téléchargez le depuis la section Fichiers (http: / /download. gna. org/ kvpnc/).', 'Texte %1', "K3b nécessite l'installation du programme « & #160; mkisofs & #160; » en version 1.14 (ou supérieure). Les versions antérieures posent des problèmes lors de la création de projets de données."]} not recognized.
Keyword arguments {'text_targets': ['Raccourcis WebComment', 'Téléchargez le depuis la section Fichiers (http: / /download. gna. org/ kvpnc/).', 'Texte %1', "K3b nécessite l'installation du programme « & #160; mkisofs & #160; » en version 1.14 (ou supérieure). L

{'input_ids': [[1232, 13572, 7823, 9, 0], [35, 723, 647, 373, 45, 928, 71, 37, 4012, 9, 37, 583, 583, 3390, 3, 49, 19015, 3, 57, 309, 74, 1013, 74, 2635, 973, 529, 364, 222, 50, 3, 0], [45629, 0], [526, 602, 226, 895, 71, 1187, 251, 5049, 9, 2368, 9, 1226, 6662, 6426, 34144, 5056, 202, 8101, 1366, 288, 4933, 499, 1013, 3, 0]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
