-
Notifications
You must be signed in to change notification settings - Fork 25.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added max_sample_ arguments #10551
Added max_sample_ arguments #10551
Changes from 6 commits
c2acb5d
85f2dec
2f99439
34f1b23
0474538
b9d9b6f
7cf59c4
b1f5323
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -114,6 +114,21 @@ class DataTrainingArguments: | |||||||
default=None, | ||||||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, | ||||||||
) | ||||||||
max_train_samples: Optional[int] = field( | ||||||||
default=None, | ||||||||
metadata={ | ||||||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this " | ||||||||
"value if set." | ||||||||
}, | ||||||||
) | ||||||||
max_val_samples: Optional[int] = field( | ||||||||
default=None, | ||||||||
metadata={ | ||||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this " | ||||||||
"value if set." | ||||||||
}, | ||||||||
) | ||||||||
|
||||||||
block_size: Optional[int] = field( | ||||||||
default=None, | ||||||||
metadata={ | ||||||||
|
@@ -346,19 +361,37 @@ def group_texts(examples): | |||||||
# | ||||||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information: | ||||||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map | ||||||||
lm_datasets = tokenized_datasets.map( | ||||||||
group_texts, | ||||||||
batched=True, | ||||||||
num_proc=data_args.preprocessing_num_workers, | ||||||||
load_from_cache_file=not data_args.overwrite_cache, | ||||||||
) | ||||||||
|
||||||||
if training_args.do_train: | ||||||||
if "train" not in tokenized_datasets: | ||||||||
raise ValueError("--do_train requires a train dataset") | ||||||||
train_dataset = tokenized_datasets["train"].map( | ||||||||
group_texts, | ||||||||
batched=True, | ||||||||
num_proc=data_args.preprocessing_num_workers, | ||||||||
load_from_cache_file=not data_args.overwrite_cache, | ||||||||
) | ||||||||
if data_args.max_train_samples is not None: | ||||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples)) | ||||||||
|
||||||||
if training_args.do_eval: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
please add a new line between the ifs so that they don't mesh together (same in all other scripts). |
||||||||
if "validation" not in tokenized_datasets: | ||||||||
raise ValueError("--do_eval requires a validation dataset") | ||||||||
eval_dataset = tokenized_datasets["validation"].map( | ||||||||
group_texts, | ||||||||
batched=True, | ||||||||
num_proc=data_args.preprocessing_num_workers, | ||||||||
load_from_cache_file=not data_args.overwrite_cache, | ||||||||
) | ||||||||
if data_args.max_val_samples is not None: | ||||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) | ||||||||
|
||||||||
# Initialize our Trainer | ||||||||
trainer = Trainer( | ||||||||
model=model, | ||||||||
args=training_args, | ||||||||
train_dataset=lm_datasets["train"] if training_args.do_train else None, | ||||||||
eval_dataset=lm_datasets["validation"] if training_args.do_eval else None, | ||||||||
train_dataset=train_dataset if training_args.do_train else None, | ||||||||
eval_dataset=eval_dataset if training_args.do_eval else None, | ||||||||
tokenizer=tokenizer, | ||||||||
# Data collator will default to DataCollatorWithPadding, so we change it. | ||||||||
data_collator=default_data_collator, | ||||||||
|
@@ -377,24 +410,28 @@ def group_texts(examples): | |||||||
|
||||||||
metrics = train_result.metrics | ||||||||
|
||||||||
max_train_samples = ( | ||||||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) | ||||||||
) | ||||||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset)) | ||||||||
|
||||||||
trainer.log_metrics("train", metrics) | ||||||||
trainer.save_metrics("train", metrics) | ||||||||
trainer.save_state() | ||||||||
|
||||||||
# Evaluation | ||||||||
results = {} | ||||||||
if training_args.do_eval: | ||||||||
logger.info("*** Evaluate ***") | ||||||||
|
||||||||
eval_output = trainer.evaluate() | ||||||||
|
||||||||
perplexity = math.exp(eval_output["eval_loss"]) | ||||||||
results["perplexity"] = perplexity | ||||||||
metrics = trainer.evaluate() | ||||||||
|
||||||||
trainer.log_metrics("eval", results) | ||||||||
trainer.save_metrics("eval", results) | ||||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) | ||||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) | ||||||||
perplexity = math.exp(metrics["eval_loss"]) | ||||||||
metrics["perplexity"] = perplexity | ||||||||
|
||||||||
return results | ||||||||
trainer.log_metrics("eval", metrics) | ||||||||
trainer.save_metrics("eval", metrics) | ||||||||
|
||||||||
|
||||||||
def _mp_fn(index): | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,6 +143,20 @@ class DataTrainingArguments: | |
"If False, will pad the samples dynamically when batching to the maximum length in the batch." | ||
}, | ||
) | ||
max_train_samples: Optional[int] = field( | ||
default=None, | ||
metadata={ | ||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this " | ||
"value if set." | ||
}, | ||
) | ||
max_val_samples: Optional[int] = field( | ||
default=None, | ||
metadata={ | ||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this " | ||
"value if set." | ||
}, | ||
) | ||
|
||
def __post_init__(self): | ||
if self.dataset_name is None and self.train_file is None and self.validation_file is None: | ||
|
@@ -358,12 +372,29 @@ def group_texts(examples): | |
# | ||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information: | ||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map | ||
tokenized_datasets = tokenized_datasets.map( | ||
group_texts, | ||
batched=True, | ||
num_proc=data_args.preprocessing_num_workers, | ||
load_from_cache_file=not data_args.overwrite_cache, | ||
) | ||
if training_args.do_train: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment on this script too. |
||
if "train" not in tokenized_datasets: | ||
raise ValueError("--do_train requires a train dataset") | ||
train_dataset = tokenized_datasets["train"].map( | ||
group_texts, | ||
batched=True, | ||
num_proc=data_args.preprocessing_num_workers, | ||
load_from_cache_file=not data_args.overwrite_cache, | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
if data_args.max_train_samples is not None: | ||
train_dataset = train_dataset.select(range(data_args.max_train_samples)) | ||
|
||
if training_args.do_eval: | ||
if "validation" not in tokenized_datasets: | ||
raise ValueError("--do_eval requires a validation dataset") | ||
eval_dataset = tokenized_datasets["validation"].map( | ||
group_texts, | ||
batched=True, | ||
num_proc=data_args.preprocessing_num_workers, | ||
load_from_cache_file=not data_args.overwrite_cache, | ||
) | ||
if data_args.max_val_samples is not None: | ||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) | ||
|
||
# Data collator | ||
data_collator = DataCollatorForPermutationLanguageModeling( | ||
|
@@ -376,8 +407,8 @@ def group_texts(examples): | |
trainer = Trainer( | ||
model=model, | ||
args=training_args, | ||
train_dataset=tokenized_datasets["train"] if training_args.do_train else None, | ||
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None, | ||
train_dataset=train_dataset if training_args.do_train else None, | ||
eval_dataset=eval_dataset if training_args.do_eval else None, | ||
tokenizer=tokenizer, | ||
data_collator=data_collator, | ||
) | ||
|
@@ -394,24 +425,28 @@ def group_texts(examples): | |
trainer.save_model() # Saves the tokenizer too for easy upload | ||
metrics = train_result.metrics | ||
|
||
max_train_samples = ( | ||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) | ||
) | ||
metrics["train_samples"] = min(max_train_samples, len(train_dataset)) | ||
|
||
trainer.log_metrics("train", metrics) | ||
trainer.save_metrics("train", metrics) | ||
trainer.save_state() | ||
|
||
# Evaluation | ||
results = {} | ||
if training_args.do_eval: | ||
logger.info("*** Evaluate ***") | ||
|
||
eval_output = trainer.evaluate() | ||
|
||
perplexity = math.exp(eval_output["eval_loss"]) | ||
results["perplexity"] = perplexity | ||
metrics = trainer.evaluate() | ||
|
||
trainer.log_metrics("eval", results) | ||
trainer.save_metrics("eval", results) | ||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) | ||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) | ||
perplexity = math.exp(metrics["eval_loss"]) | ||
metrics["perplexity"] = perplexity | ||
|
||
return results | ||
trainer.log_metrics("eval", metrics) | ||
trainer.save_metrics("eval", metrics) | ||
|
||
|
||
def _mp_fn(index): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this map can be done as before (deleted lines 349 to 354 in the diff), since it's the same for training and validation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @sgugger,
so we should do it like below
and we simply select samples for train and validation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. It avoids duplicating the same code this way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did this for almost all the examples; I thought the preprocessing would only be done if it was required.
Shall I make these changes for all the examples, or only the ones mentioned here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the other examples, you are doing the select before doing the map (to avoid preprocessing all the dataset), so it's not possible to group all the preprocessing together. I think this only applies to the three scripts in language_modeling.