diff --git a/lhotse/dataset/dataloading.py b/lhotse/dataset/dataloading.py
index edc0af4d6..04c530f01 100644
--- a/lhotse/dataset/dataloading.py
+++ b/lhotse/dataset/dataloading.py
@@ -1,6 +1,11 @@
 import os
+import random
+import secrets
 from functools import partial
-from typing import Callable, Optional
+from typing import Callable, Literal, Optional, Union
+
+import torch
+from torch import distributed as dist
 
 from lhotse.utils import fix_random_seed
 
@@ -57,3 +62,64 @@ def worker_init_fn(
     # because DataLoader workers did not initialize torch.distributed.
     os.environ["RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
+
+
+def resolve_seed(seed: Union[int, Literal["trng", "randomized"]]) -> int:
+    """
+    Resolves the special values of random seed supported in Lhotse.
+
+    If it's an integer, we'll just return it.
+
+    If it's "trng", we'll use the ``secrets`` module to generate a random seed
+    using a true RNG (to the extent supported by the OS).
+
+    If it's "randomized", we'll check whether we're in a dataloading worker of ``torch.utils.data.DataLoader``.
+    If we are, we expect that it was passed the result of :func:`lhotse.dataset.dataloading.make_worker_init_fn`
+    into its ``worker_init_fn`` argument, in which case we'll return a special seed exclusive to that worker.
+    If we are not in a dataloading worker (or ``num_workers`` was set to ``0``), we'll return Python's ``random``
+    module's global seed.
+    """
+    if isinstance(seed, int):
+        return seed
+
+    if seed == "randomized":
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            # not in a dataloader sub-process: get Python's global random seed
+            return random.getstate()[1][0]
+        else:
+            # in a dataloader sub-process: read out the seed we assigned to it
+            assert LHOTSE_PROCESS_SEED in os.environ, (
+                "Requested seed='randomized' for shuffling shards differently "
+                "on each DataLoader node and worker, "
+                "but lhotse.dataset.dataloading.worker_init_fn was not called."
+            )
+            return int(os.environ[LHOTSE_PROCESS_SEED])
+
+    if seed == "trng":
+        return secrets.randbelow(2**32)
+
+    raise ValueError(
+        f"Unexpected type or value of seed: {type(seed)=} {seed=}. "
+        f"Supported values are: int, 'trng', and 'randomized'."
+    )
+
+
+def get_world_size() -> int:
+    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank() -> int:
+    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
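
For context, a minimal sketch (illustrative, not part of the patch) of how the three seed modes behave when called in the main process:

    from lhotse.dataset.dataloading import resolve_seed

    assert resolve_seed(42) == 42  # plain integers pass through unchanged

    # "trng" asks the OS for true randomness, so two calls (almost surely) differ:
    print(resolve_seed("trng"), resolve_seed("trng"))

    # Outside of a DataLoader worker, "randomized" derives the seed from the
    # global `random` module state; inside a worker it reads the per-worker
    # LHOTSE_PROCESS_SEED value exported by worker_init_fn:
    print(resolve_seed("randomized"))
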
diff --git a/lhotse/dataset/iterable_dataset.py b/lhotse/dataset/iterable_dataset.py
index 9b1a553c9..2879600ab 100644
--- a/lhotse/dataset/iterable_dataset.py
+++ b/lhotse/dataset/iterable_dataset.py
@@ -2,6 +2,8 @@
 
 import torch
 
+from lhotse import CutSet
+from lhotse.dataset.dataloading import get_rank, get_world_size
 from lhotse.dataset.sampling.base import CutSampler
 
 
@@ -93,9 +95,23 @@ def __iter__(self):
     def __next__(self) -> dict:
         try:
-            return self.dataset[next(self._sampler_iter)]
+            sampled = next(self._sampler_iter)
+            self._update_dataloading_info(sampled)
+            return self.dataset[sampled]
         except StopIteration:
             if self.auto_increment_epoch:
                 self.set_epoch(self.epoch + 1)
             self._sampler_iter = None
             raise
+
+    def _update_dataloading_info(self, cuts: CutSet) -> None:
+        rank = get_rank()
+        world_size = get_world_size()
+        for c in cuts:
+            # dataloading_info is attached by the sampler to each cut;
+            # we need to update it here because, with iterable datasets,
+            # samplers typically act as if rank=0 and world_size=1,
+            # and data de-duplication / per node+worker shuffling
+            # happens elsewhere.
+            c.dataloading_info["rank"] = rank
+            c.dataloading_info["world_size"] = world_size
diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py
index fdc6090cb..e839f1c97 100644
--- a/lhotse/dataset/sampling/base.py
+++ b/lhotse/dataset/sampling/base.py
@@ -3,8 +3,9 @@
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from math import isclose
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
 
+import torch
 from torch import distributed as dist
 from torch.utils.data import Sampler
 
@@ -57,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         world_size: Optional[int] = None,
         rank: Optional[int] = None,
-        seed: int = 0,
+        seed: Union[int, Literal["randomized", "trng"]] = 0,
     ) -> None:
         """
         :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
@@ -325,6 +326,7 @@ def __next__(self):
         self._log_diagnostics(selected)
         for tfn in self._transforms:
             selected = tfn(selected)
+        attach_dataloading_info(selected, rank=self.rank, world_size=self.world_size)
         return selected
 
     def _log_diagnostics(self, batch: Union[CutSet, Tuple[CutSet, ...]]) -> None:
@@ -347,6 +349,23 @@ def inner(cut_id: str) -> str:
     return inner
 
 
+def attach_dataloading_info(cuts: CutSet, rank: int, world_size: int) -> None:
+    """
+    Attaches diagnostic info about dataloading to each cut under the ``dataloading_info`` custom field.
+    This information contains the rank, world_size, and worker_id.
+    If the training is not distributed, rank and world_size are 0 and 1, respectively.
+    If the ``num_workers`` argument in DataLoader was 0, worker_id is None.
+    """
+    worker_info = torch.utils.data.get_worker_info()
+    if worker_info is None:
+        worker_id = None
+    else:
+        worker_id = worker_info.id
+    info = {"rank": rank, "world_size": world_size, "worker_id": worker_id}
+    for cut in cuts:
+        cut.dataloading_info = dict(info)  # copy: avoid sharing one mutable dict across cuts
+
+
 @dataclass
 class TimeConstraint:
     """
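
Downstream code can use the new per-cut field to verify how data was sharded. A hypothetical helper (the function name and usage are illustrative, not part of the patch):

    from collections import Counter

    def summarize_provenance(cuts) -> Counter:
        """Counts how many cuts each (rank, worker_id) pair produced."""
        counts = Counter()
        for cut in cuts:
            # dataloading_info is present on every cut sampled after this patch
            info = cut.dataloading_info
            counts[(info["rank"], info["worker_id"])] += 1
        return counts
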
+ """ + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + worker_id = None + else: + worker_id = worker_info.id + info = {"rank": rank, "world_size": world_size, "worker_id": worker_id} + for cut in cuts: + cut.dataloading_info = info + + @dataclass class TimeConstraint: """ diff --git a/lhotse/dataset/sampling/dynamic.py b/lhotse/dataset/sampling/dynamic.py index 2a3a3f6b2..f0cff289a 100644 --- a/lhotse/dataset/sampling/dynamic.py +++ b/lhotse/dataset/sampling/dynamic.py @@ -8,6 +8,7 @@ Generator, Iterable, List, + Literal, Optional, Tuple, Union, @@ -15,6 +16,7 @@ from lhotse import CutSet, Seconds from lhotse.cut import Cut +from lhotse.dataset.dataloading import resolve_seed from lhotse.dataset.sampling.base import ( CutSampler, EpochDiagnostics, @@ -77,7 +79,7 @@ def __init__( quadratic_duration: Optional[Seconds] = None, world_size: Optional[int] = None, rank: Optional[int] = None, - seed: int = 0, + seed: Union[int, Literal["trng", "randomized"]] = 0, strict=None, ) -> None: """ @@ -183,7 +185,8 @@ def __iter__(self) -> "DynamicCutSampler": # or we are iterating the same epoch again, in which case setting more steps # than are actually available per epoch would have broken the checkpoint restoration. self.diagnostics.reset_current_epoch() - self.rng = random.Random(self.seed + self.epoch) + seed = resolve_seed(self.seed) + self.rng = random.Random(seed + self.epoch) # Initiate iteration self.cuts_iter = [iter(cs) for cs in self.cuts] # Optionally shuffle @@ -193,7 +196,7 @@ def __iter__(self) -> "DynamicCutSampler": # so that they are reproducible. streaming_shuffle( cs, - rng=random.Random(self.seed + self.epoch), + rng=random.Random(seed + self.epoch), bufsize=self.shuffle_buffer_size, ) for cs in self.cuts_iter diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index 98c305ed0..caf205664 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -10,6 +10,7 @@ Generator, Iterable, List, + Literal, Optional, Sequence, Tuple, @@ -20,6 +21,7 @@ from lhotse import CutSet, Seconds from lhotse.cut import Cut +from lhotse.dataset.dataloading import resolve_seed from lhotse.dataset.sampling.base import ( CutSampler, EpochDiagnostics, @@ -27,7 +29,7 @@ TimeConstraint, ) from lhotse.dataset.sampling.dynamic import DurationBatcher, Filter -from lhotse.utils import ifnone, streaming_shuffle +from lhotse.utils import ifnone class DynamicBucketingSampler(CutSampler): @@ -86,7 +88,7 @@ def __init__( quadratic_duration: Optional[Seconds] = None, world_size: Optional[int] = None, rank: Optional[int] = None, - seed: int = 0, + seed: Union[int, Literal["randomized", "trng"]] = 0, strict=None, shuffle_buffer_size=None, ) -> None: @@ -223,7 +225,8 @@ def _fast_forward(self): def __iter__(self) -> "DynamicBucketingSampler": if self._just_restored_state: return self - self.rng = random.Random(self.seed + self.epoch) + seed = resolve_seed(self.seed) + self.rng = random.Random(seed + self.epoch) # Why reset the current epoch? 
diff --git a/lhotse/dataset/sampling/stateless.py b/lhotse/dataset/sampling/stateless.py
index f7e89257d..6667242b6 100644
--- a/lhotse/dataset/sampling/stateless.py
+++ b/lhotse/dataset/sampling/stateless.py
@@ -1,15 +1,14 @@
 import logging
-import os
 import random
 from pathlib import Path
 from typing import Callable, Dict, Generator, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
-import torch.distributed as dist
 from cytoolz import compose_left
 
 from lhotse import CutSet, Seconds
 from lhotse.cut.set import deserialize_cut
+from lhotse.dataset.dataloading import get_rank, get_world_size
 from lhotse.dataset.sampling.base import SamplingDiagnostics
 from lhotse.lazy import Dillable
 from lhotse.serialization import decode_json_line
@@ -314,23 +313,3 @@ def _process(self, manifest: Path, file_index: Path) -> Tuple[int]:
             print(offsets[-1], file=index_f)
             line = cuts_f.readline()
         return tuple(offsets)
-
-
-def get_world_size() -> int:
-    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
-    if "WORLD_SIZE" in os.environ:
-        return int(os.environ["WORLD_SIZE"])
-    if dist.is_available() and dist.is_initialized():
-        return dist.get_world_size()
-    else:
-        return 1
-
-
-def get_rank() -> int:
-    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
-    if "RANK" in os.environ:
-        return int(os.environ["RANK"])
-    elif dist.is_available() and dist.is_initialized():
-        return dist.get_rank()
-    else:
-        return 0
diff --git a/lhotse/shar/readers/lazy.py b/lhotse/shar/readers/lazy.py
index 1c2f5d9d8..306aaa675 100644
--- a/lhotse/shar/readers/lazy.py
+++ b/lhotse/shar/readers/lazy.py
@@ -17,7 +17,7 @@
 import torch
 
 from lhotse.cut import Cut
-from lhotse.dataset.dataloading import LHOTSE_PROCESS_SEED
+from lhotse.dataset.dataloading import LHOTSE_PROCESS_SEED, resolve_seed
 from lhotse.lazy import (
     ImitatesDict,
     LazyIteratorChain,
@@ -226,24 +226,7 @@ def _maybe_shuffle_shards(self, shards: List) -> List:
         if self.shuffle_shards:
             shards = shards.copy()
 
-            seed = self.seed
-
-            if seed == "randomized":
-                worker_info = torch.utils.data.get_worker_info()
-                if worker_info is None:
-                    # not in a dataloader sub-process: get python global random seed
-                    seed = random.getstate()[1][0]
-                else:
-                    # in a dataloader sub-process: read out the seed we assigned to it
-                    assert LHOTSE_PROCESS_SEED in os.environ, (
-                        "Requested seed='randomized' for shuffling shards differently "
-                        "on each DataLoader node and worker, "
-                        "but lhotse.dataset.dataloading.worker_init_fn was not called."
-                    )
-                    seed = int(os.environ[LHOTSE_PROCESS_SEED])
-
-            if seed == "trng":
-                seed = secrets.randbelow(2**32)
+            seed = resolve_seed(self.seed)
 
             if self.stateful_shuffle:
                 seed += self.epoch
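
For context, the end-to-end wiring that makes `seed="randomized"` resolve per worker: `make_worker_init_fn` exports RANK, WORLD_SIZE, and LHOTSE_PROCESS_SEED into each DataLoader worker, and `resolve_seed` reads them back. A sketch assuming `cuts`, `my_dataset`, `rank`, and `world_size` are defined by the training script:

    from torch.utils.data import DataLoader
    from lhotse.dataset import DynamicCutSampler, IterableDatasetWrapper, make_worker_init_fn

    sampler = DynamicCutSampler(
        cuts, shuffle=True, max_duration=200.0, seed="randomized",
        # with an iterable dataset the sampler acts as if there were no sharding;
        # per-node/worker diversity comes from the resolved seed instead
        rank=0, world_size=1,
    )
    dloader = DataLoader(
        IterableDatasetWrapper(my_dataset, sampler),
        batch_size=None,
        num_workers=2,
        # exports RANK, WORLD_SIZE, and LHOTSE_PROCESS_SEED into each worker,
        # which resolve_seed("randomized") then reads back
        worker_init_fn=make_worker_init_fn(rank=rank, world_size=world_size),
    )
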
diff --git a/test/dataset/sampling/test_sampling.py b/test/dataset/sampling/test_sampling.py
index 6c6517029..820cd2ad1 100644
--- a/test/dataset/sampling/test_sampling.py
+++ b/test/dataset/sampling/test_sampling.py
@@ -1,6 +1,7 @@
 import math
 import random
 import re
+from collections import Counter
 from copy import deepcopy
 from functools import partial
 from math import isclose
@@ -8,12 +9,15 @@
 from tempfile import NamedTemporaryFile
 
 import pytest
+from torch.utils.data import DataLoader
 
 from lhotse import CutSet
 from lhotse.dataset import (
     CutConcatenate,
     DynamicBucketingSampler,
+    IterableDatasetWrapper,
     RoundRobinSampler,
+    make_worker_init_fn,
     report_padding_ratio_estimate,
 )
 from lhotse.dataset.cut_transforms import concat_cuts
@@ -98,6 +102,86 @@ def test_single_cut_sampler_shuffling(sampler_cls):
     assert [c.id for c in sampled_cuts] != [c.id for c in cut_set]
 
 
+class IdentityDataset:
+    def __getitem__(self, item):
+        return item
+
+
+@pytest.mark.parametrize("sampler_cls", [DynamicCutSampler, DynamicBucketingSampler])
+@pytest.mark.parametrize("seed", [0, "randomized", "trng"])
+def test_shuffle_seed_strategies(sampler_cls, seed):
+    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
+
+    world_size = 2
+    sampled_cuts = []
+    for rank in range(world_size):
+        sampler = sampler_cls(
+            cut_set,
+            shuffle=True,
+            max_duration=10.0,
+            seed=seed,
+            rank=0,
+            world_size=1,
+        )
+        dloader = DataLoader(
+            IterableDatasetWrapper(IdentityDataset(), sampler),
+            num_workers=2,
+            batch_size=None,
+            worker_init_fn=make_worker_init_fn(rank=rank, world_size=world_size),
+        )
+        for batch in dloader:
+            sampled_cuts.extend(batch)
+
+    # Since we're using 2 nodes * 2 workers with an iterable dataset and don't
+    # de-duplicate, we end up with 4 copies of the input data.
+    assert len(sampled_cuts) == 4 * len(cut_set)
+    uniq_ids = Counter()
+    for c in sampled_cuts:
+        uniq_ids[c.id] += 1
+    assert all(v == 4 for v in uniq_ids.values())
+
+    input_ids = list(cut_set.ids)
+    node0_worker0 = [
+        c.id
+        for c in sampled_cuts
+        if c.dataloading_info["worker_id"] == 0 and c.dataloading_info["rank"] == 0
+    ]
+    node0_worker1 = [
+        c.id
+        for c in sampled_cuts
+        if c.dataloading_info["worker_id"] == 1 and c.dataloading_info["rank"] == 0
+    ]
+    node1_worker0 = [
+        c.id
+        for c in sampled_cuts
+        if c.dataloading_info["worker_id"] == 0 and c.dataloading_info["rank"] == 1
+    ]
+    node1_worker1 = [
+        c.id
+        for c in sampled_cuts
+        if c.dataloading_info["worker_id"] == 1 and c.dataloading_info["rank"] == 1
+    ]
+
+    if seed == 0:
+        # When seed=0, ensure each copy is shuffled in the same order (but different from the input).
+        assert node0_worker0 == node0_worker1
+        assert node0_worker0 == node1_worker0
+        assert node0_worker0 == node1_worker1
+        assert node0_worker0 != input_ids
+    else:
+        # Otherwise, we expect each worker to shuffle in a different order.
+        assert node0_worker0 != node0_worker1
+        assert node0_worker0 != node1_worker0
+        assert node0_worker0 != node1_worker1
+        assert node0_worker1 != node1_worker0
+        assert node0_worker1 != node1_worker1
+        assert node1_worker0 != node1_worker1
+        assert node0_worker0 != input_ids
+        assert node0_worker1 != input_ids
+        assert node1_worker0 != input_ids
+        assert node1_worker1 != input_ids
+
+
 def test_single_cut_sampler_time_constraints():
     # The dummy cuts have a duration of 1 second each
     cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
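
A quick sanity check of the copy-count arithmetic asserted in the test above (illustrative values only):

    # With an iterable dataset and no de-duplication, every (rank, worker)
    # pair replays the full sampler output:
    world_size, num_workers, n_cuts = 2, 2, 100
    total_sampled = world_size * num_workers * n_cuts
    assert total_sampled == 400  # matches len(sampled_cuts) == 4 * len(cut_set)
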