From 917fdf5acdaadd47cea359e45f17bde11cc734a4 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Sat, 23 Oct 2021 00:55:39 -0400 Subject: [PATCH] Add preprocessing num workers for GLUE --- examples/pytorch/text-classification/run_glue.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 356d0fa0b057..c6b2e56f6237 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -106,6 +106,10 @@ class DataTrainingArguments: overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) pad_to_max_length: bool = field( default=True, metadata={ @@ -423,7 +427,12 @@ def preprocess_function(examples): result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] return result - datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache) + datasets = datasets.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) if training_args.do_train: if "train" not in datasets: raise ValueError("--do_train requires a train dataset")