diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 3d0b3ce1cf3..c23f45570b4 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from functools import partial
 from itertools import cycle, islice
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import fsspec.asyn
 import numpy as np
@@ -26,6 +26,9 @@
 from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs

+if TYPE_CHECKING:
+    import torch
+
 logger = get_logger(__name__)

 Key = Union[int, str]

@@ -1690,6 +1693,18 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls):
         cls.__bases__ += (torch.utils.data.IterableDataset,)


+def _maybe_share_with_torch_persistent_workers(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
+    if config.TORCH_AVAILABLE:
+        import torch
+
+        if isinstance(value, torch.Tensor):
+            return value.share_memory_()
+        else:
+            return torch.tensor(value).share_memory_()
+    else:
+        return value
+
+
 class IterableDataset(DatasetInfoMixin):
     """A Dataset backed by an iterable."""

@@ -1722,8 +1737,8 @@ def __init__(
         self._formatting = formatting
         self._shuffling = shuffling
         self._distributed = distributed
-        self._epoch = 0
         self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
+        self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
         self._starting_state_dict: Optional[dict] = None
         self._prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
         self._state_dict = self._prepared_ex_iterable._init_state_dict()
@@ -1841,18 +1856,24 @@ def __getstate__(self):
     def __setstate__(self, d):
         self.__dict__ = d
+        # Re-add torch shared memory, since shared memory is not always kept when pickling
+        self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
         # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)

     def _head(self, n=5):
         return _examples_to_batch(list(self.take(n)))

+    @property
+    def epoch(self) -> int:
+        return int(self._epoch)
+
     def _effective_generator(self):
-        if self._shuffling and self._epoch == 0:
+        if self._shuffling and self.epoch == 0:
             return self._shuffling.generator
         elif self._shuffling:
-            # Create effective seed using self._epoch (we subtract in order to avoir overflow in long_scalars)
-            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self._epoch
+            # Create effective seed using self.epoch (we subtract in order to avoid overflow in long_scalars)
+            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch
             effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
             return np.random.default_rng(effective_seed)
         else:
@@ -2465,7 +2486,7 @@ def shuffle(
         )

     def set_epoch(self, epoch: int):
-        self._epoch = epoch
+        self._epoch += epoch - self._epoch  # update torch value in shared memory in-place

     def skip(self, n: int) -> "IterableDataset":
         """
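Note (editorial, not part of the patch): the fix works because a torch.Tensor whose storage has been moved to shared memory with share_memory_() keeps pointing at the same buffer after being handed to a DataLoader worker process, and set_epoch writes into that buffer in place (+=) instead of rebinding self._epoch. The sketch below demonstrates the underlying mechanism with plain torch.multiprocessing; the names read_epoch and parent_wrote and the value 5 are illustrative only.

# Minimal sketch of the shared-memory behavior that
# _maybe_share_with_torch_persistent_workers relies on.
import torch
import torch.multiprocessing as mp


def read_epoch(epoch, parent_wrote, out):
    parent_wrote.wait()  # block until the parent has updated the shared tensor
    out.put(int(epoch))  # the child sees the parent's in-place write


if __name__ == "__main__":
    epoch = torch.tensor(0).share_memory_()  # same call the helper makes
    parent_wrote = mp.Event()
    out = mp.Queue()
    worker = mp.Process(target=read_epoch, args=(epoch, parent_wrote, out))
    worker.start()

    epoch += 5 - epoch  # in-place update, mirroring IterableDataset.set_epoch(5)
    parent_wrote.set()
    assert out.get(timeout=10) == 5  # the long-lived worker observed the new value
    worker.join()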
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 6d1f84b7d6a..8069c748396 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -1641,6 +1641,28 @@ def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset):
     assert len(out) == DEFAULT_N_EXAMPLES


+@require_torch
+def test_iterable_dataset_persists_epoch_in_torch_workers(dataset: IterableDataset):
+    from torch.utils.data import DataLoader
+
+    dataset = dataset.shuffle(seed=42)
+    dataloader = DataLoader(dataset, num_workers=1, persistent_workers=True)
+    epoch0 = list(dataloader)
+    assert list(dataloader) == epoch0
+    dataset.set_epoch(1)
+    assert list(dataloader) != epoch0
+
+    # Make sure pickle works even with torch objects in shared memory
+    dataset_copy: IterableDataset = pickle.loads(pickle.dumps(dataset))
+    dataloader = DataLoader(dataset_copy, num_workers=1, persistent_workers=True)
+    epoch1 = list(dataloader)
+    assert list(dataloader) == epoch1
+    dataset.set_epoch(2)  # this should not affect the copy
+    assert list(dataloader) == epoch1
+    dataset_copy.set_epoch(2)
+    assert list(dataloader) != epoch1
+
+
 @pytest.mark.parametrize("n", [0, 2, int(1e10)])
 def test_iterable_dataset_skip(dataset: IterableDataset, n):
     skip_dataset = dataset.skip(n)
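For reference, this is the usage pattern the fix enables (a sketch; the "allenai/c4" dataset name, worker count, and num_epochs are placeholders, not prescribed by the patch). With persistent_workers=True the DataLoader keeps its worker processes alive across epochs, and set_epoch now reaches them through the shared-memory tensor, so each epoch is shuffled differently.

from datasets import load_dataset
from torch.utils.data import DataLoader

num_epochs = 3

# Streaming mode returns an IterableDataset; shuffle() attaches the seeded
# generator that _effective_generator() re-seeds from the current epoch.
dataset = load_dataset("allenai/c4", "en", split="train", streaming=True).shuffle(seed=42)
dataloader = DataLoader(dataset, num_workers=4, persistent_workers=True)

for epoch in range(num_epochs):
    dataset.set_epoch(epoch)  # propagates to the long-lived workers via shared memory
    for batch in dataloader:
        ...  # training step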