In [1]:
from transformers import MBartModel, MBartTokenizer, MBartConfig
from transformers import AutoTokenizer
import multiprocessing
from datasets import concatenate_datasets, load_dataset, load_from_disk, DownloadMode
from tqdm import tqdm
from torch.utils import data
from typing import List, Dict

In [2]:
def tokenize(examples: List[Dict[str, str]], **kwargs):
    """Tokenize the source and target sides of a batch of translation pairs.

    Both sides go through the same tokenizer call: truncation is enabled,
    padding is set to ``'max_length'`` and PyTorch tensors are requested, so
    every encoded sequence is ``max_length`` tokens long.

    Args:
        examples: batch of translation records; each record maps a language
            key (e.g. ``'en'``) to the sentence in that language.
        **kwargs:
            tokenizer: the tokenizer to call on each list of sentences
                (required).
            lang: key of the target-language sentence in each record
                (required).
            src_lang: key of the source-language sentence in each record.
                Optional, defaults to ``'en'``.
            max_length: length every sequence is truncated/padded to.
                Optional, defaults to ``tokenizer.model_max_length // 2``.

    Returns:
        A dict with two entries, ``'src_ids'`` and ``'trg_ids'``, holding the
        ``input_ids`` of the source and target sentences respectively.

    Raises:
        KeyError: if ``tokenizer`` or ``lang`` is missing from ``kwargs``, or
            if a record lacks the source or target language key.
    """
    tokenizer: MBartTokenizer = kwargs['tokenizer']
    trg_lang: str = kwargs['lang']
    src_lang: str = kwargs.get('src_lang', 'en')
    max_length = kwargs.get('max_length')
    if max_length is None:
        # Default budget: half of the model's maximum sequence length per side.
        max_length = tokenizer.model_max_length // 2

    def encode(sentences: List[str]):
        # Single definition of the tokenizer settings so both sides of the
        # pair are guaranteed to be encoded identically.
        encoded = tokenizer(sentences, return_special_tokens_mask=False, truncation=True,
                            max_length=max_length, padding='max_length', return_tensors='pt')
        return encoded['input_ids']

    src_sentences: List[str] = [pair[src_lang] for pair in examples]
    trg_sentences: List[str] = [pair[trg_lang] for pair in examples]

    # NOTE(review): targets are encoded with the same tokenizer state as the
    # sources, so any language-specific special tokens the tokenizer adds are
    # those of its current source-language setting — confirm this is intended
    # for the downstream model.
    return {'src_ids': encode(src_sentences), 'trg_ids': encode(trg_sentences)}

In [2]:
# Language pair to process, written as "<source>-<target>".
lang_pair = "en-de"
# Per-pair cache location: the dash in the pair id becomes an underscore.
cache_dir = "/data/n.dallanoce/cc_" + "_".join(lang_pair.split("-"))
# Load the train split of the CCMatrix pair, skipping verification checks.
dataset = load_dataset(
    "yhavinga/ccmatrix",
    lang_pair,
    split='train',
    cache_dir=cache_dir,
    ignore_verifications=True,
)

Found cached dataset ccmatrix (/data/n.dallanoce/cc_en_de/yhavinga___ccmatrix/en-de/1.0.0/5f733aeea277b2b1bb792442ba120c0f7f4b1c7288897051bdf1e9865fe77b93)


In [4]:
# Pretrained mBART tokenizer, configured with English as the source language.
tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-cc25",
    src_lang="en_XX",
)

In [6]:
# The target language is the second component of the "<source>-<target>" id.
trg_lng = lang_pair.split("-")[1]
# Tokenize in batches across 64 worker processes; only the 'translation'
# column is handed to `tokenize`, together with the tokenizer and target key.
dataset = dataset.map(
    tokenize,
    batched=True,
    input_columns=['translation'],
    fn_kwargs=dict(tokenizer=tokenizer, lang=trg_lng),
    num_proc=64,
)

                                                                  

#1:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#2:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#0:   0%|          | 0/3867 [00:00<?, ?ba/s]

#3:   0%|          | 0/3867 [00:00<?, ?ba/s]

     

#4:   0%|          | 0/3867 [00:00<?, ?ba/s]

#5:   0%|          | 0/3867 [00:00<?, ?ba/s]

#6:   0%|          | 0/3867 [00:00<?, ?ba/s]

#7:   0%|          | 0/3867 [00:00<?, ?ba/s]

#8:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#9:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#10:   0%|          | 0/3867 [00:00<?, ?ba/s]

#11:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#13:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#12:   0%|          | 0/3867 [00:00<?, ?ba/s]

#15:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#14:   0%|          | 0/3867 [00:00<?, ?ba/s]

#16:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#18:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#17:   0%|          | 0/3867 [00:00<?, ?ba/s]

#23:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#24:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#21:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#19:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#27:   0%|          | 0/3867 [00:00<?, ?ba/s]

#25:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#22:   0%|          | 0/3867 [00:00<?, ?ba/s]

#20:   0%|          | 0/3867 [00:00<?, ?ba/s]

#26:   0%|          | 0/3867 [00:00<?, ?ba/s]

   

#29:   0%|          | 0/3867 [00:00<?, ?ba/s]

#28:   0%|          | 0/3867 [00:00<?, ?ba/s]

#33:   0%|          | 0/3867 [00:00<?, ?ba/s]

    

#35:   0%|          | 0/3867 [00:00<?, ?ba/s]

#31:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#30:   0%|          | 0/3867 [00:00<?, ?ba/s]

#37:   0%|          | 0/3867 [00:00<?, ?ba/s]

#34:   0%|          | 0/3867 [00:00<?, ?ba/s]

#32:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#38:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#36:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#40:   0%|          | 0/3867 [00:00<?, ?ba/s]

   

#43:   0%|          | 0/3867 [00:00<?, ?ba/s]

#39:   0%|          | 0/3867 [00:00<?, ?ba/s]

#42:   0%|          | 0/3867 [00:00<?, ?ba/s]

   

#44:   0%|          | 0/3867 [00:00<?, ?ba/s]

#46:   0%|          | 0/3867 [00:00<?, ?ba/s]

#41:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#49:   0%|          | 0/3867 [00:00<?, ?ba/s]

#45:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#47:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#48:   0%|          | 0/3867 [00:00<?, ?ba/s]

#54:   0%|          | 0/3867 [00:00<?, ?ba/s]

   

#55:   0%|          | 0/3867 [00:00<?, ?ba/s]

#50:   0%|          | 0/3867 [00:00<?, ?ba/s]

#51:   0%|          | 0/3867 [00:00<?, ?ba/s]

    

#52:   0%|          | 0/3867 [00:00<?, ?ba/s]

#57:   0%|          | 0/3867 [00:00<?, ?ba/s]

#59:   0%|          | 0/3867 [00:00<?, ?ba/s]

#53:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#61:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#56:   0%|          | 0/3867 [00:00<?, ?ba/s]

#58:   0%|          | 0/3867 [00:00<?, ?ba/s]

 

#60:   0%|          | 0/3867 [00:00<?, ?ba/s]

  

#63:   0%|          | 0/3867 [00:00<?, ?ba/s]

#62:   0%|          | 0/3867 [00:00<?, ?ba/s]

In [7]:
# Drop the raw columns now that the tokenized ids have been added
# (presumably leaving only the token-id columns — verify with dataset.column_names).
dataset = dataset.remove_columns(
    ['id', 'score', 'translation']
)

In [8]:
# Persist the tokenized dataset next to the raw cache, with a "_tokenized" suffix.
save_dir = f"{cache_dir}_tokenized"
dataset.save_to_disk(save_dir)

Saving the dataset (0/4059 shards):   0%|          | 0/247470736 [00:00<?, ? examples/s]