diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 356d0fa0b057..c6b2e56f6237 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -106,6 +106,10 @@ class DataTrainingArguments: overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) pad_to_max_length: bool = field( default=True, metadata={ @@ -423,7 +427,12 @@ def preprocess_function(examples): result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] return result - datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache) + datasets = datasets.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) if training_args.do_train: if "train" not in datasets: raise ValueError("--do_train requires a train dataset")