9 changes: 6 additions & 3 deletions examples/flax/image-captioning/run_image_captioning_flax.py
@@ -613,7 +613,8 @@ def preprocess_fn(examples, max_target_length, check_image=True):
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# instead of here.)
@@ -646,7 +647,8 @@ def preprocess_fn(examples, max_target_length, check_image=True):
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# instead of here.)
@@ -675,7 +677,8 @@ def preprocess_fn(examples, max_target_length, check_image=True):
raise ValueError("--do_predict requires a test dataset")
predict_dataset = dataset["test"]
if data_args.max_predict_samples is not None:
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# instead of here.)
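Every hunk in this PR makes the same fix: clamp the user-supplied sample cap to the dataset size before calling Dataset.select, so a cap larger than the split no longer produces out-of-range indices. A minimal sketch of the pattern; the clamp_and_select helper and the toy dataset are illustrative, not part of the PR:

from typing import Optional

from datasets import Dataset


def clamp_and_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    """Select at most `max_samples` rows, clamped to the dataset size."""
    if max_samples is None:
        return dataset
    # min() guards against caps larger than the dataset; without it,
    # range(max_samples) would index past the end and select() would raise.
    return dataset.select(range(min(len(dataset), max_samples)))


# Illustrative usage with a tiny in-memory dataset:
ds = Dataset.from_dict({"text": ["a", "b", "c"]})
assert len(clamp_and_select(ds, 10)) == 3  # clamped instead of raising
assert len(clamp_and_select(ds, 2)) == 2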
6 changes: 4 additions & 2 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -527,14 +527,16 @@ def group_texts(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
15 changes: 10 additions & 5 deletions examples/flax/question-answering/run_qa.py
@@ -602,7 +602,8 @@ def prepare_train_features(examples):
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
# We will select samples from the whole data if the argument is specified
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
# Create train features from the dataset
train_dataset = train_dataset.map(
prepare_train_features,
@@ -613,7 +614,8 @@ def prepare_train_features(examples):
)
if data_args.max_train_samples is not None:
# The number of samples might increase during feature creation, so we select only the specified max samples
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
processed_raw_datasets["train"] = train_dataset

# Validation preprocessing
@@ -669,7 +671,8 @@ def prepare_validation_features(examples):
eval_examples = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
# We will select samples from the whole data
- eval_examples = eval_examples.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
+ eval_examples = eval_examples.select(range(max_eval_samples))
# Validation Feature Creation
eval_dataset = eval_examples.map(
prepare_validation_features,
@@ -680,7 +683,8 @@ def prepare_validation_features(examples):
)
if data_args.max_eval_samples is not None:
# During feature creation the number of samples might increase, so we select the required samples again
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
processed_raw_datasets["validation"] = eval_dataset

if training_args.do_predict:
@@ -700,7 +704,8 @@ def prepare_validation_features(examples):
)
if data_args.max_predict_samples is not None:
# During feature creation the number of samples might increase, so we select the required samples again
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
processed_raw_datasets["test"] = predict_dataset
# endregion

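The question-answering scripts apply the cap twice because prepare_train_features and prepare_validation_features chunk long contexts into overlapping spans, so one example can become several features and the count after map() can exceed the cap again. A sketch of that expansion; the tokenizer checkpoint and the lengths are illustrative:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tok(
    ["What is X?"],                   # question
    ["a very long context " * 200],   # context far beyond max_length
    max_length=128,
    stride=32,
    truncation="only_second",
    return_overflowing_tokens=True,
)
# One example became several overlapping features, so a cap applied before
# map() no longer bounds the dataset size afterwards; hence the second select().
print(len(encoded["input_ids"]))  # > 1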
9 changes: 6 additions & 3 deletions examples/flax/summarization/run_summarization_flax.py
@@ -547,7 +547,8 @@ def preprocess_function(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
@@ -563,7 +564,8 @@ def preprocess_function(examples):
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
@@ -579,7 +581,8 @@ def preprocess_function(examples):
raise ValueError("--do_predict requires a test dataset")
predict_dataset = dataset["test"]
if data_args.max_predict_samples is not None:
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
9 changes: 6 additions & 3 deletions examples/pytorch/contrastive-image-text/run_clip.py
@@ -404,7 +404,8 @@ def filter_corrupt_images(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

train_dataset = train_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
@@ -426,7 +427,8 @@ def filter_corrupt_images(examples):
raise ValueError("--do_eval requires a train validation")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

eval_dataset = eval_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
@@ -448,7 +450,8 @@ def filter_corrupt_images(examples):
raise ValueError("--do_predict requires a test dataset")
test_dataset = dataset["test"]
if data_args.max_eval_samples is not None:
- test_dataset = test_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
+ test_dataset = test_dataset.select(range(max_eval_samples))

test_dataset = test_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
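The filter_corrupt_images predicate named in these hunks runs batched: it returns one boolean per example, and Dataset.filter keeps only the rows flagged True. A sketch of such a filter, assuming examples carry image file paths in an image_path column; the column name and the function body are assumptions, not taken from the diff:

from PIL import Image


def filter_corrupt_images(examples):
    # One bool per example; Dataset.filter(batched=True) drops the False rows.
    valid = []
    for path in examples["image_path"]:  # "image_path" is illustrative
        try:
            Image.open(path)  # raises if the file is missing or unreadable
            valid.append(True)
        except Exception:
            valid.append(False)
    return valid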
6 changes: 4 additions & 2 deletions examples/pytorch/language-modeling/run_clm.py
@@ -445,14 +445,16 @@ def group_texts(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
6 changes: 4 additions & 2 deletions examples/pytorch/language-modeling/run_mlm.py
@@ -468,14 +468,16 @@ def group_texts(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = tokenized_datasets["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = tokenized_datasets["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
6 changes: 4 additions & 2 deletions examples/pytorch/language-modeling/run_plm.py
@@ -438,14 +438,16 @@ def group_texts(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = tokenized_datasets["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = tokenized_datasets["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

# Data collator
data_collator = DataCollatorForPermutationLanguageModeling(
6 changes: 4 additions & 2 deletions examples/pytorch/multiple-choice/run_swag.py
@@ -352,7 +352,8 @@ def preprocess_function(examples):
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
@@ -366,7 +367,8 @@ def preprocess_function(examples):
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function,
15 changes: 10 additions & 5 deletions examples/pytorch/question-answering/run_qa.py
@@ -421,7 +421,8 @@ def prepare_train_features(examples):
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
# We will select samples from the whole data if the argument is specified
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
# Create train features from the dataset
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
@@ -434,7 +435,8 @@ def prepare_train_features(examples):
)
if data_args.max_train_samples is not None:
# The number of samples might increase during feature creation, so we select only the specified max samples
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

# Validation preprocessing
def prepare_validation_features(examples):
@@ -489,7 +491,8 @@ def prepare_validation_features(examples):
eval_examples = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
# We will select samples from the whole data
- eval_examples = eval_examples.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
+ eval_examples = eval_examples.select(range(max_eval_samples))
# Validation Feature Creation
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_examples.map(
@@ -502,7 +505,8 @@ def prepare_validation_features(examples):
)
if data_args.max_eval_samples is not None:
# During feature creation the number of samples might increase, so we select the required samples again
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

if training_args.do_predict:
if "test" not in raw_datasets:
@@ -523,7 +527,8 @@ def prepare_validation_features(examples):
)
if data_args.max_predict_samples is not None:
# During feature creation the number of samples might increase, so we select the required samples again
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))

# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
15 changes: 10 additions & 5 deletions examples/pytorch/question-answering/run_qa_beam_search.py
@@ -432,7 +432,8 @@ def prepare_train_features(examples):
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
# Select samples from the dataset; this helps decrease processing time
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
# Create Training Features
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
@@ -445,7 +446,8 @@ def prepare_train_features(examples):
)
if data_args.max_train_samples is not None:
# Select samples from the dataset again, since feature creation might increase the number of features
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))

# Validation preprocessing
def prepare_validation_features(examples):
@@ -519,7 +521,8 @@ def prepare_validation_features(examples):
eval_examples = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
# Select eval samples from the dataset
- eval_examples = eval_examples.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
+ eval_examples = eval_examples.select(range(max_eval_samples))
# Create Features from Eval Dataset
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_examples.map(
@@ -532,7 +535,8 @@ def prepare_validation_features(examples):
)
if data_args.max_eval_samples is not None:
# Select samples from the dataset again, since feature creation might increase the number of samples
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))

if training_args.do_predict:
if "test" not in raw_datasets:
@@ -553,7 +557,8 @@ def prepare_validation_features(examples):
)
if data_args.max_predict_samples is not None:
# During feature creation the number of samples might increase, so we select the required samples again
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))

# Data collator
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data