From 0ae8b2c211ae0fac10c9836efdce220dc4f18fe2 Mon Sep 17 00:00:00 2001
From: Matthew Hoffman
Date: Fri, 5 Apr 2024 19:52:52 -0700
Subject: [PATCH 1/2] Add datasets.Dataset to Trainer's train_dataset and
 eval_dataset type hints

---
 src/transformers/trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 436165b0e3db8..4895c7221f1d2 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -287,7 +287,7 @@ class Trainer:
             The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`.
             Will default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
             [`DataCollatorWithPadding`] otherwise.
-        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
+        train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
             The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
             `model.forward()` method are automatically removed.
 
@@ -296,7 +296,7 @@ class Trainer:
             `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
             manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
             sets the seed of the RNGs used.
-        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
+        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`], `datasets.Dataset`], *optional*):
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
@@ -358,8 +358,8 @@ def __init__(
         model: Union[PreTrainedModel, nn.Module] = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
-        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         image_processor: Optional["BaseImageProcessor"] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,

From 331012d951cedb96a7d43d1414dad2fdc087bd80 Mon Sep 17 00:00:00 2001
From: Matthew Hoffman
Date: Mon, 8 Apr 2024 15:56:50 -0500
Subject: [PATCH 2/2] Add is_datasets_available check for importing datasets
 under TYPE_CHECKING guard

https://github.com/huggingface/transformers/pull/30077/files#r1555939352
---
 src/transformers/trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 7e30dbd1a23bd..1cdd8623e58c9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -250,6 +250,8 @@ def _get_fsdp_ckpt_kwargs():
 
 if TYPE_CHECKING:
     import optuna
 
+    if is_datasets_available():
+        import datasets
 
 
 logger = logging.get_logger(__name__)
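
For readers unfamiliar with the pattern in the second patch: the quoted annotation "datasets.Dataset" combined with an import placed under a TYPE_CHECKING guard keeps `datasets` an optional runtime dependency while still giving static type checkers a real target to resolve. Below is a minimal, self-contained sketch of the same idea; note that `is_datasets_available` here is a simplified stand-in for the `transformers.utils` helper of the same name, and `train` is a hypothetical function used only to illustrate the annotation, not Trainer's actual API.

```python
import importlib.util
from typing import TYPE_CHECKING, Optional, Union

from torch.utils.data import Dataset, IterableDataset


def is_datasets_available() -> bool:
    # Simplified stand-in for transformers.utils.is_datasets_available:
    # True if the optional `datasets` package is installed.
    return importlib.util.find_spec("datasets") is not None


if TYPE_CHECKING:
    # Skipped entirely at runtime (TYPE_CHECKING is False), so `datasets`
    # never has to be installed to import this module. Type checkers treat
    # TYPE_CHECKING as True and analyze the import, which lets them resolve
    # the string annotation "datasets.Dataset" below.
    if is_datasets_available():
        import datasets


def train(
    # Hypothetical signature mirroring the Trainer type hints above.
    train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
) -> None:
    # The forward-reference (quoted) annotation is never evaluated at
    # runtime, so this signature works even when `datasets` is absent.
    ...
```

The string form of the annotation is what makes this safe: a bare `datasets.Dataset` in the signature would raise `NameError` at import time whenever `datasets` is not installed, whereas the quoted name is only inspected by tools such as mypy and pyright.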