Persist IterableDataset epoch in workers #6710

Open · wants to merge 4 commits into main
33 changes: 27 additions & 6 deletions src/datasets/iterable_dataset.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from functools import partial
 from itertools import cycle, islice
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import numpy as np
 import pyarrow as pa
@@ -26,6 +26,9 @@
 from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs


+if TYPE_CHECKING:
+    import torch
+
 logger = get_logger(__name__)

 Key = Union[int, str]
@@ -1188,6 +1191,18 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls):
             cls.__bases__ += (torch.utils.data.IterableDataset,)


+def _maybe_share_with_torch_persistent_workers(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
+    if config.TORCH_AVAILABLE:
+        import torch
+
+        if isinstance(value, torch.Tensor):
+            return value.share_memory_()
+        else:
+            return torch.tensor(value).share_memory_()
+    else:
+        return value
+
+
 class IterableDataset(DatasetInfoMixin):
     """A Dataset backed by an iterable."""

Expand Down Expand Up @@ -1220,8 +1235,8 @@ def __init__(
self._formatting = formatting
self._shuffling = shuffling
self._distributed = distributed
self._epoch = 0
self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)

def __repr__(self):
@@ -1232,18 +1247,24 @@ def __getstate__(self):

     def __setstate__(self, d):
         self.__dict__ = d
+        # Re-add torch shared memory, since shared memory is not always kept when pickling
+        self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
         # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)

     def _head(self, n=5):
         return _examples_to_batch(list(self.take(n)))

+    @property
+    def epoch(self) -> int:
+        return int(self._epoch)
+
     def _effective_generator(self):
-        if self._shuffling and self._epoch == 0:
+        if self._shuffling and self.epoch == 0:
             return self._shuffling.generator
         elif self._shuffling:
-            # Create effective seed using self._epoch (we subtract in order to avoid overflow in long_scalars)
-            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self._epoch
+            # Create effective seed using self.epoch (we subtract in order to avoid overflow in long_scalars)
+            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch
             effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
             return np.random.default_rng(effective_seed)
         else:
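For reference, the per-epoch reshuffling above comes from deriving a fresh seed from the base generator and the current epoch; the arithmetic in `_effective_generator` amounts to a modular offset of a base seed. A standalone sketch of that derivation, where the function name and signature are assumptions for illustration and not part of the PR:

```python
from copy import deepcopy

import numpy as np


def effective_rng(base: np.random.Generator, epoch: int) -> np.random.Generator:
    # Draw a reproducible base seed from a copy of the generator, offset it by the
    # epoch, and wrap back into [0, 2**63) so the result stays a valid seed.
    if epoch == 0:
        return base
    seed = (int(deepcopy(base).integers(0, 1 << 63)) - epoch) % (1 << 63)
    return np.random.default_rng(seed)
```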
@@ -1836,7 +1857,7 @@ def shuffle(
         )

     def set_epoch(self, epoch: int):
-        self._epoch = epoch
+        self._epoch += epoch - self._epoch  # update torch value in shared memory in-place

     def skip(self, n) -> "IterableDataset":
         """
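The heart of the change is that `set_epoch` now mutates the epoch value in place instead of rebinding the attribute: when torch is available, `self._epoch` is a scalar tensor living in shared memory, so an update made in the main process stays visible to DataLoader workers that were started earlier and kept alive with `persistent_workers=True`. Below is a minimal, self-contained sketch of that mechanism, independent of `datasets`; the `EpochEcho` class is purely illustrative and not part of this PR:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class EpochEcho(IterableDataset):
    """Toy dataset that yields its current epoch, to show persistent workers see in-place updates."""

    def __init__(self):
        # Scalar tensor moved to shared memory *before* the DataLoader creates its workers.
        self._epoch = torch.tensor(0).share_memory_()

    def set_epoch(self, epoch: int):
        # In-place add keeps the same shared storage, so worker processes observe the change.
        self._epoch += epoch - int(self._epoch)

    def __iter__(self):
        yield int(self._epoch)


if __name__ == "__main__":
    ds = EpochEcho()
    dl = DataLoader(ds, num_workers=1, persistent_workers=True)
    print(list(dl))  # expected: [tensor([0])]
    ds.set_epoch(1)
    print(list(dl))  # expected: [tensor([1])] -- the long-lived worker picked up the new value
```

Rebinding the attribute instead (e.g. `self._epoch = torch.tensor(epoch)`) would break this, since each worker would still hold a reference to the old shared tensor.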
22 changes: 22 additions & 0 deletions tests/test_iterable_dataset.py
@@ -1493,6 +1493,28 @@ def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset):
     assert len(out) == DEFAULT_N_EXAMPLES


+@require_torch
+def test_iterable_dataset_persists_epoch_in_torch_workers(dataset: IterableDataset):
+    from torch.utils.data import DataLoader
+
+    dataset = dataset.shuffle(seed=42)
+    dataloader = DataLoader(dataset, num_workers=1, persistent_workers=True)
+    epoch0 = list(dataloader)
+    assert list(dataloader) == epoch0
+    dataset.set_epoch(1)
+    assert list(dataloader) != epoch0
+
+    # Make sure pickle works even with torch objects in shared memory
+    dataset_copy: IterableDataset = pickle.loads(pickle.dumps(dataset))
+    dataloader = DataLoader(dataset_copy, num_workers=1, persistent_workers=True)
+    epoch1 = list(dataloader)
+    assert list(dataloader) == epoch1
+    dataset.set_epoch(2)  # this should not affect the copy
+    assert list(dataloader) == epoch1
+    dataset_copy.set_epoch(2)
+    assert list(dataloader) != epoch1
+
+
 @pytest.mark.parametrize("n", [0, 2, int(1e10)])
 def test_iterable_dataset_skip(dataset: IterableDataset, n):
     skip_dataset = dataset.skip(n)
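From a user's point of view, the test above mirrors the usual training-loop pattern that this PR makes work with persistent workers: call `set_epoch` once per epoch so a shuffled `IterableDataset` is reshuffled differently each epoch. A rough usage sketch; the toy in-memory dataset and the batch/worker sizes are illustrative assumptions only:

```python
from datasets import Dataset
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # Toy sharded dataset standing in for a real streaming dataset
    # (e.g. load_dataset(..., streaming=True)); sizes here are illustrative.
    ds = Dataset.from_dict({"x": list(range(100))}).to_iterable_dataset(num_shards=4)
    ds = ds.shuffle(seed=42, buffer_size=100)

    loader = DataLoader(ds, batch_size=10, num_workers=2, persistent_workers=True)

    for epoch in range(3):
        ds.set_epoch(epoch)  # persistent workers pick up the new epoch via shared memory
        for batch in loader:
            pass  # training step goes here
```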