From 6ad7390cc6dbc31dca384fbc6f7006e865a82529 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Sat, 2 Mar 2024 12:59:16 +0100 Subject: [PATCH 1/4] persist IterableDataset epoch in workers --- src/datasets/iterable_dataset.py | 28 ++++++++++++++++++++++------ tests/test_iterable_dataset.py | 15 +++++++++++++++ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 31329de9c31..791f46caaf4 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from functools import partial from itertools import cycle, islice -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import pyarrow as pa @@ -26,6 +26,9 @@ from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs +if TYPE_CHECKING: + import torch + logger = get_logger(__name__) Key = Union[int, str] @@ -1188,6 +1191,15 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls): cls.__bases__ += (torch.utils.data.IterableDataset,) +def _maybe_share_with_torch_persistent_workers(value: int) -> Union[int, "torch.Tensor"]: + if config.TORCH_AVAILABLE: + import torch + + return torch.tensor(value).share_memory_() + else: + return value + + class IterableDataset(DatasetInfoMixin): """A Dataset backed by an iterable.""" @@ -1220,7 +1232,7 @@ def __init__( self._formatting = formatting self._shuffling = shuffling self._distributed = distributed - self._epoch = 0 + self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0) self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {} _maybe_add_torch_iterable_dataset_parent_class(self.__class__) @@ -1238,12 +1250,16 @@ def __setstate__(self, d): def _head(self, n=5): return _examples_to_batch(list(self.take(n))) + @property + def epoch(self) -> int: + return int(self._epoch) + def _effective_generator(self): - if self._shuffling and self._epoch == 0: + if self._shuffling and self.epoch == 0: return self._shuffling.generator elif self._shuffling: - # Create effective seed using self._epoch (we subtract in order to avoir overflow in long_scalars) - effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self._epoch + # Create effective seed using self.epoch (we subtract in order to avoir overflow in long_scalars) + effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed return np.random.default_rng(effective_seed) else: @@ -1836,7 +1852,7 @@ def shuffle( ) def set_epoch(self, epoch: int): - self._epoch = epoch + self._epoch += epoch - self._epoch # update torch value in shared memory in-place def skip(self, n) -> "IterableDataset": """ diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index ba94aece8ae..f49dc18f406 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1493,6 +1493,21 @@ def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset): assert len(out) == DEFAULT_N_EXAMPLES +@require_torch +def test_iterable_dataset_persists_epoch_in_torch_workers(): + from torch.utils.data import DataLoader + + num_examples = 10 + num_shards = 4 + ds = Dataset.from_dict({"i": 
range(num_examples)}).to_iterable_dataset(num_shards=num_shards) + ds = ds.shuffle(seed=42) + dataloader = DataLoader(ds, num_workers=2, persistent_workers=True) + epoch0 = list(dataloader) + assert list(dataloader) == epoch0 + ds.set_epoch(1) + assert list(dataloader) != epoch0 + + @pytest.mark.parametrize("n", [0, 2, int(1e10)]) def test_iterable_dataset_skip(dataset: IterableDataset, n): skip_dataset = dataset.skip(n) From 4089a1fa545d993d811fd6110064e14d8d6bca21 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Sat, 2 Mar 2024 13:06:51 +0100 Subject: [PATCH 2/4] more tests --- tests/test_iterable_dataset.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index f49dc18f406..d2733335f02 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1494,19 +1494,25 @@ def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset): @require_torch -def test_iterable_dataset_persists_epoch_in_torch_workers(): +def test_iterable_dataset_persists_epoch_in_torch_workers(dataset: IterableDataset): from torch.utils.data import DataLoader - num_examples = 10 - num_shards = 4 - ds = Dataset.from_dict({"i": range(num_examples)}).to_iterable_dataset(num_shards=num_shards) - ds = ds.shuffle(seed=42) - dataloader = DataLoader(ds, num_workers=2, persistent_workers=True) + dataset = dataset.shuffle(seed=42) + dataloader = DataLoader(dataset, num_workers=1, persistent_workers=True) epoch0 = list(dataloader) assert list(dataloader) == epoch0 - ds.set_epoch(1) + dataset.set_epoch(1) assert list(dataloader) != epoch0 + dataset_copy: IterableDataset = pickle.loads(pickle.dumps(dataset)) + dataloader = DataLoader(dataset_copy, num_workers=1, persistent_workers=True) + epoch1 = list(dataloader) + assert list(dataloader) == epoch1 + dataset.set_epoch(2) # this should not affect the copy + assert list(dataloader) == epoch1 + dataset_copy.set_epoch(2) + assert list(dataloader) != epoch1 + @pytest.mark.parametrize("n", [0, 2, int(1e10)]) def test_iterable_dataset_skip(dataset: IterableDataset, n): From 5cb54c0219f3053b504c9944ee0d5908cc586f6c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Sat, 2 Mar 2024 13:07:37 +0100 Subject: [PATCH 3/4] comment --- tests/test_iterable_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index d2733335f02..b97913049aa 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1504,6 +1504,7 @@ def test_iterable_dataset_persists_epoch_in_torch_workers(dataset: IterableDatas dataset.set_epoch(1) assert list(dataloader) != epoch0 + # Make sure pickle works even with torch objects in shared memory dataset_copy: IterableDataset = pickle.loads(pickle.dumps(dataset)) dataloader = DataLoader(dataset_copy, num_workers=1, persistent_workers=True) epoch1 = list(dataloader) From 3d4cd66d5e369512954521c30e7259590b3fb256 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 5 Mar 2024 17:06:14 +0100 Subject: [PATCH 4/4] re-share memory after pickling --- src/datasets/iterable_dataset.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 791f46caaf4..d21ca42fa9e 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1191,11 +1191,14 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls): cls.__bases__ += 
(torch.utils.data.IterableDataset,)
 
 
-def _maybe_share_with_torch_persistent_workers(value: int) -> Union[int, "torch.Tensor"]:
+def _maybe_share_with_torch_persistent_workers(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
     if config.TORCH_AVAILABLE:
         import torch
 
-        return torch.tensor(value).share_memory_()
+        if isinstance(value, torch.Tensor):
+            return value.share_memory_()
+        else:
+            return torch.tensor(value).share_memory_()
     else:
         return value
 
@@ -1232,8 +1235,8 @@ def __init__(
         self._formatting = formatting
         self._shuffling = shuffling
         self._distributed = distributed
-        self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
         self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
+        self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
 
     def __repr__(self):
@@ -1244,6 +1247,8 @@ def __getstate__(self):
 
     def __setstate__(self, d):
         self.__dict__ = d
+        # Re-add torch shared memory, since shared memory is not always kept when pickling
+        self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
         # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
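
A minimal usage sketch of the behaviour this series enables, assuming a
`datasets` build that includes these four commits and a working torch install.
The toy dataset, the worker/shard counts and the `num_epochs` loop are
illustrative assumptions, not part of the patches; only `to_iterable_dataset`,
`shuffle`, `set_epoch` and the DataLoader arguments shown are the APIs the
patches exercise.

from datasets import Dataset
from torch.utils.data import DataLoader

ds = Dataset.from_dict({"i": list(range(10))}).to_iterable_dataset(num_shards=4)
ds = ds.shuffle(seed=42)

# persistent_workers=True keeps the worker processes (and their copies of `ds`)
# alive across epochs; with these patches the epoch counter lives in a torch
# tensor in shared memory, so set_epoch() in the main process is visible to
# those workers and the shuffling order changes from one epoch to the next.
dataloader = DataLoader(ds, num_workers=2, persistent_workers=True)

num_epochs = 3  # illustrative
for epoch in range(num_epochs):
    ds.set_epoch(epoch)  # updates the shared epoch value, changing the effective shuffle seed
    for batch in dataloader:
        pass  # training step goes here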