From 2f8ed44a6d6fff59cdef3005cc19513117d18198 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 28 Dec 2023 15:00:57 -0500 Subject: [PATCH] Support multiplexing with a limited number of open streams (#1248) * Support multiplexing with a limited number of open streams * Make the documentation clearer * Remove ``len`` support from infinite mux iterator * Add "trng" support to mux() as well --- lhotse/lazy.py | 176 +++++++++++++++++- lhotse/utils.py | 8 + ...bles.py => test_multiplexing_iterables.py} | 87 ++++++++- 3 files changed, 265 insertions(+), 6 deletions(-) rename test/{test_multipexing_iterables.py => test_multiplexing_iterables.py} (57%) diff --git a/lhotse/lazy.py b/lhotse/lazy.py index b7526700d..d5bd55528 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -1,8 +1,9 @@ import random +import secrets import types import warnings from functools import partial -from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union from lhotse.serialization import ( LazyMixin, @@ -11,7 +12,13 @@ extension_contains, open_best, ) -from lhotse.utils import Pathlike, fastcopy, is_module_available, streaming_shuffle +from lhotse.utils import ( + Pathlike, + build_rng, + fastcopy, + is_module_available, + streaming_shuffle, +) T = TypeVar("T") @@ -60,7 +67,8 @@ def mux( *manifests, stop_early: bool = False, weights: Optional[List[Union[int, float]]] = None, - seed: int = 0, + seed: Union[int, Literal["trng"]] = 0, + max_open_streams: Optional[int] = None, ): """ Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time. @@ -82,6 +90,50 @@ def mux( ) ) + @classmethod + def infinite_mux( + cls, + *manifests, + weights: Optional[List[Union[int, float]]] = None, + seed: Union[int, Literal["trng"]] = 0, + max_open_streams: Optional[int] = None, + ): + """ + Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time. + Unlike ``mux()``, this method allows to limit the number of max open sub-iterators at any given time. + + To enable this, it performs 2-stage sampling. + First, it samples with replacement the set of iterators ``I`` to construct a subset ``I_sub`` + of size ``max_open_streams``. + Then, for each iteration step, it samples an iterator ``i`` from ``I_sub``, + fetches the next item from it, and yields it. + Once ``i`` becomes exhausted, it is replaced with a new iterator ``j`` sampled from ``I_sub``. + + .. caution:: Do not use this method with inputs that are infinitely iterable as they will + silently break the multiplexing property by only using a subset of the input iterables. + + .. caution:: This method is not recommended for multiplexing for a small amount of iterations, + as it may be much less accurate than ``mux()`` depending on the number of open streams, + iterable sizes, and the random seed. + + :param manifests: iterables to be multiplexed. + They can be either lazy or eager, but the resulting manifest will always be lazy. + :param weights: an optional weight for each iterable, affects the probability of it being sampled. + The weights are uniform by default. + If lengths are known, it makes sense to pass them here for uniform distribution of + items in the expectation. + :param seed: the random seed, ensures deterministic order across multiple iterations. + :param max_open_streams: the number of iterables that can be open simultaneously at any given time. + """ + return cls( + LazyInfiniteApproximateMultiplexer( + *manifests, + weights=weights, + seed=seed, + max_open_streams=max_open_streams, + ) + ) + def shuffle( self, rng: Optional[random.Random] = None, @@ -296,7 +348,7 @@ def __init__( *iterators: Iterable, stop_early: bool = False, weights: Optional[List[Union[int, float]]] = None, - seed: int = 0, + seed: Union[int, Literal["trng"]] = 0, ) -> None: self.iterators = list(iterators) self.stop_early = stop_early @@ -314,7 +366,7 @@ def __init__( assert len(self.iterators) == len(self.weights) def __iter__(self): - rng = random.Random(self.seed) + rng = build_rng(self.seed) iters = [iter(it) for it in self.iterators] exhausted = [False for _ in range(len(iters))] @@ -348,6 +400,120 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) +class LazyInfiniteApproximateMultiplexer(ImitatesDict): + """ + A variant of :class:`.LazyIteratorMultiplexer` that allows to control the number of + iterables that are simultaneously open. + + It is useful for large-scale data sets where opening multiple file handles in + many processes leads to exhaustion of the operating system resources. + + If the data sets are sharded, it is recommended to pass each shard as a separate iterator + when creating objects of this class. It is OK to assign a dataset-level weight to each shard + (e.g., if a dataset has a weight of 0.5, assign weight 0.5 to each of its shards). + + There are several differences between this class and :class:`.LazyIteratorMultiplexer`: + * Objects of this class are infinite iterators. + * We hold a list of ``max_open_streams`` open iterators at any given time. + This list is filled by sampling input iterators with replacement. + + These differences are necessary to guarantee the weighted sampling property. + If we did not sample with replacement or make it infinite, we would simply + exhaust highly-weighted iterators towards the beginning of each "epoch" + and keep sampling only lowly-weighted iterators towards the end of each "epoch". + """ + + def __init__( + self, + *iterators: Iterable, + stop_early: bool = False, + weights: Optional[List[Union[int, float]]] = None, + seed: Union[int, Literal["trng"]] = 0, + max_open_streams: Optional[int] = None, + ) -> None: + self.iterators = list(iterators) + self.stop_early = stop_early + self.seed = seed + self.max_open_streams = max_open_streams + if max_open_streams is None or max_open_streams > len(self.iterators): + self.max_open_streams = len(self.iterators) + + assert len(self.iterators) > 0 + self.weights = weights + if weights is None: + self.weights = [1] * len(self.iterators) + assert len(self.iterators) == len(self.weights) + assert ( + self.max_open_streams is None or self.max_open_streams >= 1 + ), f"{self.max_open_streams=}" + + def __iter__(self): + """ + Assumptions + - we have N streams but can only open M at the time (M < N) + - the streams are finite + - each stream needs to be "short" to ensure the mux property + - each stream may be interpreted as a shard belonging to some larger group of streams + (e.g. multiple shards of a given dataset). + """ + rng = build_rng(self.seed) + + def shuffled_streams(): + # Create an infinite iterable of our streams. + # Assume N is "small" enough that shuffling it will be quick + # + # we need to incorporate weights into shuffling here + # and sample iterators with replacement. + # consider it0=[shard00, shard01] with weight 0.95 + # and it1=[shard10, shard11] with weight 0.05 + # so we have 4 streams [shard{01}{01}] + # if we just shuffle randomly and sample without replacement + # per each "epoch" (epoch = 4 shards) then we would have + # ignored the weights because we'll just exhaust it0 shards + # towards the beginning of an "epoch" and then keep yielding + # from it1 shards until the epoch is finished and we can sample + # from it0 again... + zipped_iter_weights = list(zip(self.iterators, self.weights)) + while True: + yield rng.choices(zipped_iter_weights, self.weights, k=1)[0] + + # Initialize an infinite sequence of finite streams. + # It is sampled with weights and replacement from ``self.iterators``, + # which are of length N. + stream_source = shuffled_streams() + + # Sample the first M active streams to be multiplexed. + # As streams get depleted, we will replace them with + # new streams sampled from the stream source. + active_streams = [] + active_weights = [] + stream_indexes = list(range(self.max_open_streams)) + for _ in range(self.max_open_streams): + sampled_stream, sampled_weight = next(stream_source) + active_streams.append(iter(sampled_stream)) + active_weights.append(sampled_weight) + + # The actual multiplexing loop. + while True: + # Select a stream from the currently active streams. + # We actually sample an index so that we know which position + # to replace if a stream is exhausted. + stream_pos = rng.choices(stream_indexes, weights=active_weights, k=1)[0] + selected = active_streams[stream_pos] + try: + # Sample from the selected stream. + item = next(selected) + yield item + except StopIteration: + # The selected stream is exhausted. Replace it with another one, + # and return a sample from the newly opened stream. + sampled_stream, sampled_weight = next(stream_source) + active_streams[stream_pos] = iter(sampled_stream) + active_weights[stream_pos] = sampled_weight + item = next(active_streams[stream_pos]) + yield item + + class LazyShuffler(ImitatesDict): """ A wrapper over an iterable that enables lazy shuffling. diff --git a/lhotse/utils.py b/lhotse/utils.py index 584d862b4..80558e0b7 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -6,6 +6,7 @@ import math import os import random +import secrets import sys import urllib import uuid @@ -1092,3 +1093,10 @@ def type_cast_value(self, ctx, value): def is_torchaudio_available() -> bool: return is_module_available("torchaudio") + + +def build_rng(seed: Union[int, Literal["trng"]]) -> random.Random: + if seed == "trng": + return secrets.SystemRandom() + else: + return random.Random(seed) diff --git a/test/test_multipexing_iterables.py b/test/test_multiplexing_iterables.py similarity index 57% rename from test/test_multipexing_iterables.py rename to test/test_multiplexing_iterables.py index 63e2daf32..21b7f2a4f 100644 --- a/test/test_multipexing_iterables.py +++ b/test/test_multiplexing_iterables.py @@ -1,7 +1,7 @@ import pickle from lhotse import CutSet -from lhotse.lazy import LazyIteratorMultiplexer +from lhotse.lazy import LazyInfiniteApproximateMultiplexer, LazyIteratorMultiplexer from lhotse.testing.dummies import DummyManifest @@ -86,3 +86,88 @@ def test_multiplexer_with_cuts_pickling(): mux_rec = pickle.loads(data) assert list(mux) == list(mux_rec) + + +def test_multiplexer_max_open_streams(): + mux = LazyInfiniteApproximateMultiplexer( + range(3), + range(10, 13), + range(100, 103), + seed=1, + max_open_streams=2, + ) + + it = iter(mux) + samples = [] + for _ in range(9): + samples.append(next(it)) + + # Remember we are sampling with replacement when using + # max_open_streams. Here, the following streams were picked: + # stream2 [100-102], + # stream0 [0-2] + # stream0 [0-2] + # stream1 [10-12] + assert samples == [100, 0, 1, 2, 101, 102, 0, 10, 11] + + +def test_multiplexer_max_open_streams_1(): + mux = LazyInfiniteApproximateMultiplexer( + range(3), + range(10, 13), + range(100, 103), + seed=1, + max_open_streams=1, + ) + + it = iter(mux) + samples = [] + for _ in range(9): + samples.append(next(it)) + + # When max_open_streams=1, mux is reduced to a chain + assert samples == [0, 1, 2, 10, 11, 12, 0, 1, 2] + + +def test_cut_set_infinite_mux(): + cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3) + cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13) + cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103) + + cuts_mux = CutSet.infinite_mux( + cuts1, cuts2, cuts3, seed=0, max_open_streams=2, weights=[0.6, 0.3, 0.1] + ) + + def cid(i: int) -> str: + return f"dummy-mono-cut-{i:04d}" + + samples = [] + for i, cut in enumerate(cuts_mux): + if i == 20: + break + samples.append(cut) + + expected_ids = ( + 10, + 11, + 10, + 12, + 11, + 0, + 1, + 12, + 2, + 10, + 0, + 1, + 2, + 100, + 11, + 12, + 101, + 0, + 1, + 2, + ) + + assert [c.id for c in samples] == [cid(i) for i in expected_ids]