Enable seed randomization in dynamic samplers #1278

Merged (2 commits) on Jan 31, 2024
68 changes: 67 additions & 1 deletion lhotse/dataset/dataloading.py
@@ -1,6 +1,11 @@
import os
import random
import secrets
from functools import partial
from typing import Callable, Optional
from typing import Callable, Literal, Optional, Union

import torch
from torch import distributed as dist

from lhotse.utils import fix_random_seed

@@ -57,3 +62,64 @@
    # because DataLoader workers did not initialize torch.distributed.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)


def resolve_seed(seed: Union[int, Literal["trng", "randomized"]]) -> int:
    """
    Resolves the special values of random seed supported in Lhotse.

    If it's an integer, we'll just return it.

    If it's "trng", we'll use the ``secrets`` module to generate a random seed
    using a true RNG (to the extent supported by the OS).

    If it's "randomized", we'll check whether we're in a dataloading worker of ``torch.utils.data.DataLoader``.
    If we are, we expect that the result of :func:`lhotse.dataset.dataloading.make_worker_init_fn` was passed
    to its ``worker_init_fn`` argument, in which case we'll return a special seed exclusive to that worker.
    If we are not in a dataloading worker (or ``num_workers`` was set to ``0``), we'll return Python's ``random``
    module global seed.
    """
    if isinstance(seed, int):
        return seed

    if seed == "randomized":
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # not in a dataloader sub-process: get python global random seed
            return random.getstate()[1][0]
        else:
            # in a dataloader sub-process: read out the seed we assigned to it
            assert LHOTSE_PROCESS_SEED in os.environ, (
                "Requested seed='randomized' for shuffling shards differently "
                "on each DataLoader node and worker, "
                "but lhotse.dataset.dataloading.worker_init_fn was not called."
            )
            return int(os.environ[LHOTSE_PROCESS_SEED])

    if seed == "trng":
        return secrets.randbelow(2**32)

    raise ValueError(
        f"Unexpected type or value of seed: {type(seed)=} {seed=}. "
        f"Supported values are: int, 'trng', and 'randomized'."
    )
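For orientation, a brief usage sketch of ``resolve_seed``. The no-argument ``make_worker_init_fn()`` call is an assumption about its defaults; only the function's existence is confirmed by the docstring above.

from torch.utils.data import DataLoader

from lhotse.dataset.dataloading import make_worker_init_fn, resolve_seed

# Integers pass through unchanged.
assert resolve_seed(42) == 42

# "trng" draws a fresh seed from the OS true RNG on every call,
# so two calls almost certainly differ.
a, b = resolve_seed("trng"), resolve_seed("trng")

# "randomized" only yields per-worker seeds when the DataLoader's workers
# were initialized via make_worker_init_fn, which exports LHOTSE_PROCESS_SEED:
#
#   dloader = DataLoader(
#       dataset,
#       num_workers=4,
#       worker_init_fn=make_worker_init_fn(),  # argument defaults assumed
#   )
#
# Outside any worker, resolve_seed("randomized") falls back to Python's
# global random module seed.
seed = resolve_seed("randomized")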


def get_world_size() -> int:
    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1


def get_rank() -> int:
    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0
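A small sketch of the resolution order these helpers implement: environment variable first, then an initialized ``torch.distributed`` process group, then a single-process default.

import os

from lhotse.dataset.dataloading import get_rank, get_world_size

# With no WORLD_SIZE/RANK set and torch.distributed uninitialized,
# the helpers report a single-process setup.
os.environ.pop("WORLD_SIZE", None)
os.environ.pop("RANK", None)
assert get_world_size() == 1 and get_rank() == 0

# Environment variables take precedence over torch.distributed state.
os.environ["WORLD_SIZE"], os.environ["RANK"] = "4", "2"
assert get_world_size() == 4 and get_rank() == 2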
18 changes: 17 additions & 1 deletion lhotse/dataset/iterable_dataset.py
@@ -2,6 +2,8 @@

import torch

from lhotse import CutSet
from lhotse.dataset.dataloading import get_rank, get_world_size
from lhotse.dataset.sampling.base import CutSampler


@@ -93,9 +95,23 @@ def __iter__(self):

    def __next__(self) -> dict:
        try:
            return self.dataset[next(self._sampler_iter)]
            sampled = next(self._sampler_iter)
            self._update_dataloading_info(sampled)
            return self.dataset[sampled]
        except StopIteration:
            if self.auto_increment_epoch:
                self.set_epoch(self.epoch + 1)
            self._sampler_iter = None
            raise

    def _update_dataloading_info(self, cuts: CutSet) -> None:
        rank = get_rank()
        world_size = get_world_size()
        for c in cuts:
            # dataloading_info is attached by the sampler to each cut;
            # we need to update it here because, with iterable datasets,
            # samplers typically act as if rank=0 and world_size=1,
            # and data de-duplication / per node+worker shuffling
            # happens elsewhere.
            c.dataloading_info["rank"] = rank
            c.dataloading_info["world_size"] = world_size
23 changes: 21 additions & 2 deletions lhotse/dataset/sampling/base.py
@@ -3,8 +3,9 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from math import isclose
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union

import torch
from torch import distributed as dist
from torch.utils.data import Sampler

@@ -57,7 +58,7 @@
        drop_last: bool = False,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        seed: Union[int, Literal["randomized", "trng"]] = 0,
    ) -> None:
        """
        :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
@@ -325,6 +326,7 @@
        self._log_diagnostics(selected)
        for tfn in self._transforms:
            selected = tfn(selected)
        attach_dataloading_info(selected, rank=self.rank, world_size=self.world_size)
        return selected

    def _log_diagnostics(self, batch: Union[CutSet, Tuple[CutSet, ...]]) -> None:
@@ -347,6 +349,23 @@
    return inner


def attach_dataloading_info(cuts: CutSet, rank: int, world_size: int) -> None:
    """
    Attaches diagnostic info about dataloading to each cut under the ``dataloading_info`` custom field.
    This information contains the rank, world_size, and worker_id.
    If the training is not distributed, rank and world_size are 0 and 1, respectively.
    If the ``num_workers`` argument in DataLoader was 0, worker_id is None.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        worker_id = None
    else:
        worker_id = worker_info.id
    info = {"rank": rank, "world_size": world_size, "worker_id": worker_id}
    for cut in cuts:
        cut.dataloading_info = info
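A minimal sketch of the resulting cut metadata; ``dummy_cut`` is assumed to be the test fixture from ``lhotse.testing.dummies``. Note that every cut receives a reference to the same ``info`` dict, which is why the in-place updates in ``IterableDatasetWrapper._update_dataloading_info`` are visible through all cuts of a batch at once.

from lhotse import CutSet
from lhotse.dataset.sampling.base import attach_dataloading_info
from lhotse.testing.dummies import dummy_cut  # assumed test helper

cuts = CutSet.from_cuts([dummy_cut(0), dummy_cut(1)])
attach_dataloading_info(cuts, rank=0, world_size=1)
for cut in cuts:
    # Outside a DataLoader worker, get_worker_info() is None -> worker_id is None.
    assert cut.dataloading_info == {"rank": 0, "world_size": 1, "worker_id": None}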


@dataclass
class TimeConstraint:
"""
9 changes: 6 additions & 3 deletions lhotse/dataset/sampling/dynamic.py
@@ -8,13 +8,15 @@
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

from lhotse import CutSet, Seconds
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import (
    CutSampler,
    EpochDiagnostics,
@@ -77,7 +79,7 @@ def __init__(
        quadratic_duration: Optional[Seconds] = None,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        strict=None,
    ) -> None:
        """
@@ -183,7 +185,8 @@ def __iter__(self) -> "DynamicCutSampler":
        # or we are iterating the same epoch again, in which case setting more steps
        # than are actually available per epoch would have broken the checkpoint restoration.
        self.diagnostics.reset_current_epoch()
        self.rng = random.Random(self.seed + self.epoch)
        seed = resolve_seed(self.seed)
        self.rng = random.Random(seed + self.epoch)
        # Initiate iteration
        self.cuts_iter = [iter(cs) for cs in self.cuts]
        # Optionally shuffle
@@ -193,7 +196,7 @@
                # so that they are reproducible.
                streaming_shuffle(
                    cs,
                    rng=random.Random(self.seed + self.epoch),
                    rng=random.Random(seed + self.epoch),
                    bufsize=self.shuffle_buffer_size,
                )
                for cs in self.cuts_iter
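To make the new behavior concrete, a hedged end-to-end sketch: with ``seed='randomized'`` plus a ``worker_init_fn`` from ``make_worker_init_fn``, each node and worker resolves a different seed at ``__iter__`` time and therefore shuffles differently. The dataset class and manifest path are illustrative stand-ins, and ``make_worker_init_fn``'s defaults are assumed.

from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import DynamicCutSampler, IterableDatasetWrapper, K2SpeechRecognitionDataset
from lhotse.dataset.dataloading import make_worker_init_fn

cuts = CutSet.from_file("cuts.jsonl.gz")  # illustrative manifest path
sampler = DynamicCutSampler(cuts, max_duration=200.0, shuffle=True, seed="randomized")
dloader = DataLoader(
    IterableDatasetWrapper(dataset=K2SpeechRecognitionDataset(), sampler=sampler),
    batch_size=None,
    num_workers=4,
    worker_init_fn=make_worker_init_fn(),  # exports LHOTSE_PROCESS_SEED per worker
)

The same ``seed`` values are accepted by ``DynamicBucketingSampler`` below.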
9 changes: 6 additions & 3 deletions lhotse/dataset/sampling/dynamic_bucketing.py
@@ -10,6 +10,7 @@
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
@@ -20,14 +21,15 @@

from lhotse import CutSet, Seconds
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import (
    CutSampler,
    EpochDiagnostics,
    SamplingDiagnostics,
    TimeConstraint,
)
from lhotse.dataset.sampling.dynamic import DurationBatcher, Filter
from lhotse.utils import ifnone, streaming_shuffle
from lhotse.utils import ifnone


class DynamicBucketingSampler(CutSampler):
@@ -86,7 +88,7 @@ def __init__(
        quadratic_duration: Optional[Seconds] = None,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        seed: Union[int, Literal["randomized", "trng"]] = 0,
        strict=None,
        shuffle_buffer_size=None,
    ) -> None:
@@ -223,7 +225,8 @@ def _fast_forward(self):
    def __iter__(self) -> "DynamicBucketingSampler":
        if self._just_restored_state:
            return self
        self.rng = random.Random(self.seed + self.epoch)
        seed = resolve_seed(self.seed)
        self.rng = random.Random(seed + self.epoch)
        # Why reset the current epoch?
        # Either we are iterating the epoch for the first time and it's a no-op,
        # or we are iterating the same epoch again, in which case setting more steps
23 changes: 1 addition & 22 deletions lhotse/dataset/sampling/stateless.py
@@ -1,15 +1,14 @@
import logging
import os
import random
from pathlib import Path
from typing import Callable, Dict, Generator, Iterable, Optional, Sequence, Tuple, Union

import torch
import torch.distributed as dist
from cytoolz import compose_left

from lhotse import CutSet, Seconds
from lhotse.cut.set import deserialize_cut
from lhotse.dataset.dataloading import get_rank, get_world_size
from lhotse.dataset.sampling.base import SamplingDiagnostics
from lhotse.lazy import Dillable
from lhotse.serialization import decode_json_line
@@ -314,23 +313,3 @@ def _process(self, manifest: Path, file_index: Path) -> Tuple[int]:
            print(offsets[-1], file=index_f)
            line = cuts_f.readline()
        return tuple(offsets)


def get_world_size() -> int:
    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1


def get_rank() -> int:
    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0
21 changes: 2 additions & 19 deletions lhotse/shar/readers/lazy.py
@@ -17,7 +17,7 @@
import torch

from lhotse.cut import Cut
from lhotse.dataset.dataloading import LHOTSE_PROCESS_SEED
from lhotse.dataset.dataloading import LHOTSE_PROCESS_SEED, resolve_seed
from lhotse.lazy import (
    ImitatesDict,
    LazyIteratorChain,
@@ -226,24 +226,7 @@ def _maybe_shuffle_shards(self, shards: List) -> List:
        if self.shuffle_shards:
            shards = shards.copy()

            seed = self.seed

            if seed == "randomized":
                worker_info = torch.utils.data.get_worker_info()
                if worker_info is None:
                    # not in a dataloader sub-process: get python global random seed
                    seed = random.getstate()[1][0]
                else:
                    # in a dataloader sub-process: read out the seed we assigned to it
                    assert LHOTSE_PROCESS_SEED in os.environ, (
                        "Requested seed='randomized' for shuffling shards differently "
                        "on each DataLoader node and worker, "
                        "but lhotse.dataset.dataloading.worker_init_fn was not called."
                    )
                    seed = int(os.environ[LHOTSE_PROCESS_SEED])

            if seed == "trng":
                seed = secrets.randbelow(2**32)
            seed = resolve_seed(self.seed)

            if self.stateful_shuffle:
                seed += self.epoch
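With the inline logic replaced by ``resolve_seed``, shard shuffling in the shar reader accepts the same seed values as the samplers. A hedged sketch follows; the constructor arguments shown are assumptions based on the attributes this method uses.

from lhotse.shar import LazySharIterator  # assumed import path

cuts = LazySharIterator(
    in_dir="data/shar",      # illustrative shar directory
    shuffle_shards=True,
    seed="randomized",       # resolved via resolve_seed on each iteration
    stateful_shuffle=True,   # the epoch is added to the resolved seed
)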