In [1]:
from datasets import load_dataset

# Load dataset
dataset = load_dataset("textdetox/multilingual_paradetox", cache_dir="../cache")

In [7]:
from datasets import concatenate_datasets
from sklearn.model_selection import train_test_split

language_prompts = {
    "en": "translate from English to English: ",
    "ru": "translate from Russian to Russian: ",
    "uk": "translate from Ukrainian to Ukrainian: ",
    "de": "translate from German to German: ",
    "es": "translate from Spanish to Spanish: ",
    "am": "translate from Amharic to Amharic: ",
    "zh": "translate from Chinese to Chinese: ",
    "ar": "translate from Arabic to Arabic: ",
    "hi": "translate from Hindi to Hindi: ",
}

combined_datasets = {}
for lang, datasett in dataset.items():
    prompt = language_prompts[lang]
    datasett = datasett.map(lambda example: {"input_text": prompt + example["toxic_sentence"], "target_text": example["neutral_sentence"]}, remove_columns=["toxic_sentence", "neutral_sentence"])
    combined_datasets[lang] = datasett

combined_dataset = concatenate_datasets(combined_datasets.values())

datasets = combined_dataset.train_test_split(test_size=0.2, seed=42)

In [8]:
from transformers import UMT5ForConditionalGeneration, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments

model_name = "google/umt5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="../cache")

In [9]:
def tokenize_function(examples):
    return  {"input_ids": tokenizer(examples["input_text"], padding='max_length', truncation=True, max_length=512)["input_ids"], "labels": tokenizer(examples["target_text"], padding='max_length', truncation=True, max_length=128)["input_ids"]}

In [10]:
tokenized_datasets = datasets.map(tokenize_function, remove_columns=["input_text", "target_text"], num_proc=4, batched=True)

In [11]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 2880
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 720
    })
})

In [None]:
tokenized_datasets.save_to_disk("tokenized_datasets")