diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 227e92fa638d91..1cdd8623e58c98 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -250,6 +250,8 @@ def _get_fsdp_ckpt_kwargs(): if TYPE_CHECKING: import optuna + if is_datasets_available(): + import datasets logger = logging.get_logger(__name__) @@ -287,7 +289,7 @@ class Trainer: The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will default to [`default_data_collator`] if no `tokenizer` is provided, an instance of [`DataCollatorWithPadding`] otherwise. - train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): + train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*): The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. @@ -296,7 +298,7 @@ class Trainer: `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. - eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each dataset prepending the dictionary key to the metric name. @@ -358,8 +360,8 @@ def __init__( model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,