Skip to content

Commit

Permalink
Support multiplexing with a limited number of open streams (#1248)
Browse files Browse the repository at this point in the history
* Support multiplexing with a limited number of open streams

* Make the documentation clearer

* Remove ``len`` support from infinite mux iterator

* Add "trng" support to mux() as well
  • Loading branch information
pzelasko committed Dec 28, 2023
1 parent 8b3dac2 commit 2f8ed44
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 6 deletions.
176 changes: 171 additions & 5 deletions lhotse/lazy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import random
import secrets
import types
import warnings
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union
from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union

from lhotse.serialization import (
LazyMixin,
Expand All @@ -11,7 +12,13 @@
extension_contains,
open_best,
)
from lhotse.utils import Pathlike, fastcopy, is_module_available, streaming_shuffle
from lhotse.utils import (
Pathlike,
build_rng,
fastcopy,
is_module_available,
streaming_shuffle,
)

T = TypeVar("T")

Expand Down Expand Up @@ -60,7 +67,8 @@ def mux(
*manifests,
stop_early: bool = False,
weights: Optional[List[Union[int, float]]] = None,
seed: int = 0,
seed: Union[int, Literal["trng"]] = 0,
max_open_streams: Optional[int] = None,
):
"""
Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
Expand All @@ -82,6 +90,50 @@ def mux(
)
)

@classmethod
def infinite_mux(
cls,
*manifests,
weights: Optional[List[Union[int, float]]] = None,
seed: Union[int, Literal["trng"]] = 0,
max_open_streams: Optional[int] = None,
):
"""
Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
Unlike ``mux()``, this method allows to limit the number of max open sub-iterators at any given time.
To enable this, it performs 2-stage sampling.
First, it samples with replacement the set of iterators ``I`` to construct a subset ``I_sub``
of size ``max_open_streams``.
Then, for each iteration step, it samples an iterator ``i`` from ``I_sub``,
fetches the next item from it, and yields it.
Once ``i`` becomes exhausted, it is replaced with a new iterator ``j`` sampled from ``I_sub``.
.. caution:: Do not use this method with inputs that are infinitely iterable as they will
silently break the multiplexing property by only using a subset of the input iterables.
.. caution:: This method is not recommended for multiplexing for a small amount of iterations,
as it may be much less accurate than ``mux()`` depending on the number of open streams,
iterable sizes, and the random seed.
:param manifests: iterables to be multiplexed.
They can be either lazy or eager, but the resulting manifest will always be lazy.
:param weights: an optional weight for each iterable, affects the probability of it being sampled.
The weights are uniform by default.
If lengths are known, it makes sense to pass them here for uniform distribution of
items in the expectation.
:param seed: the random seed, ensures deterministic order across multiple iterations.
:param max_open_streams: the number of iterables that can be open simultaneously at any given time.
"""
return cls(
LazyInfiniteApproximateMultiplexer(
*manifests,
weights=weights,
seed=seed,
max_open_streams=max_open_streams,
)
)

def shuffle(
self,
rng: Optional[random.Random] = None,
Expand Down Expand Up @@ -296,7 +348,7 @@ def __init__(
*iterators: Iterable,
stop_early: bool = False,
weights: Optional[List[Union[int, float]]] = None,
seed: int = 0,
seed: Union[int, Literal["trng"]] = 0,
) -> None:
self.iterators = list(iterators)
self.stop_early = stop_early
Expand All @@ -314,7 +366,7 @@ def __init__(
assert len(self.iterators) == len(self.weights)

def __iter__(self):
rng = random.Random(self.seed)
rng = build_rng(self.seed)
iters = [iter(it) for it in self.iterators]
exhausted = [False for _ in range(len(iters))]

Expand Down Expand Up @@ -348,6 +400,120 @@ def __add__(self, other) -> "LazyIteratorChain":
return LazyIteratorChain(self, other)


class LazyInfiniteApproximateMultiplexer(ImitatesDict):
"""
A variant of :class:`.LazyIteratorMultiplexer` that allows to control the number of
iterables that are simultaneously open.
It is useful for large-scale data sets where opening multiple file handles in
many processes leads to exhaustion of the operating system resources.
If the data sets are sharded, it is recommended to pass each shard as a separate iterator
when creating objects of this class. It is OK to assign a dataset-level weight to each shard
(e.g., if a dataset has a weight of 0.5, assign weight 0.5 to each of its shards).
There are several differences between this class and :class:`.LazyIteratorMultiplexer`:
* Objects of this class are infinite iterators.
* We hold a list of ``max_open_streams`` open iterators at any given time.
This list is filled by sampling input iterators with replacement.
These differences are necessary to guarantee the weighted sampling property.
If we did not sample with replacement or make it infinite, we would simply
exhaust highly-weighted iterators towards the beginning of each "epoch"
and keep sampling only lowly-weighted iterators towards the end of each "epoch".
"""

def __init__(
self,
*iterators: Iterable,
stop_early: bool = False,
weights: Optional[List[Union[int, float]]] = None,
seed: Union[int, Literal["trng"]] = 0,
max_open_streams: Optional[int] = None,
) -> None:
self.iterators = list(iterators)
self.stop_early = stop_early
self.seed = seed
self.max_open_streams = max_open_streams
if max_open_streams is None or max_open_streams > len(self.iterators):
self.max_open_streams = len(self.iterators)

assert len(self.iterators) > 0
self.weights = weights
if weights is None:
self.weights = [1] * len(self.iterators)
assert len(self.iterators) == len(self.weights)
assert (
self.max_open_streams is None or self.max_open_streams >= 1
), f"{self.max_open_streams=}"

def __iter__(self):
"""
Assumptions
- we have N streams but can only open M at the time (M < N)
- the streams are finite
- each stream needs to be "short" to ensure the mux property
- each stream may be interpreted as a shard belonging to some larger group of streams
(e.g. multiple shards of a given dataset).
"""
rng = build_rng(self.seed)

def shuffled_streams():
# Create an infinite iterable of our streams.
# Assume N is "small" enough that shuffling it will be quick
#
# we need to incorporate weights into shuffling here
# and sample iterators with replacement.
# consider it0=[shard00, shard01] with weight 0.95
# and it1=[shard10, shard11] with weight 0.05
# so we have 4 streams [shard{01}{01}]
# if we just shuffle randomly and sample without replacement
# per each "epoch" (epoch = 4 shards) then we would have
# ignored the weights because we'll just exhaust it0 shards
# towards the beginning of an "epoch" and then keep yielding
# from it1 shards until the epoch is finished and we can sample
# from it0 again...
zipped_iter_weights = list(zip(self.iterators, self.weights))
while True:
yield rng.choices(zipped_iter_weights, self.weights, k=1)[0]

# Initialize an infinite sequence of finite streams.
# It is sampled with weights and replacement from ``self.iterators``,
# which are of length N.
stream_source = shuffled_streams()

# Sample the first M active streams to be multiplexed.
# As streams get depleted, we will replace them with
# new streams sampled from the stream source.
active_streams = []
active_weights = []
stream_indexes = list(range(self.max_open_streams))
for _ in range(self.max_open_streams):
sampled_stream, sampled_weight = next(stream_source)
active_streams.append(iter(sampled_stream))
active_weights.append(sampled_weight)

# The actual multiplexing loop.
while True:
# Select a stream from the currently active streams.
# We actually sample an index so that we know which position
# to replace if a stream is exhausted.
stream_pos = rng.choices(stream_indexes, weights=active_weights, k=1)[0]
selected = active_streams[stream_pos]
try:
# Sample from the selected stream.
item = next(selected)
yield item
except StopIteration:
# The selected stream is exhausted. Replace it with another one,
# and return a sample from the newly opened stream.
sampled_stream, sampled_weight = next(stream_source)
active_streams[stream_pos] = iter(sampled_stream)
active_weights[stream_pos] = sampled_weight
item = next(active_streams[stream_pos])
yield item


class LazyShuffler(ImitatesDict):
"""
A wrapper over an iterable that enables lazy shuffling.
Expand Down
8 changes: 8 additions & 0 deletions lhotse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import math
import os
import random
import secrets
import sys
import urllib
import uuid
Expand Down Expand Up @@ -1092,3 +1093,10 @@ def type_cast_value(self, ctx, value):

def is_torchaudio_available() -> bool:
return is_module_available("torchaudio")


def build_rng(seed: Union[int, Literal["trng"]]) -> random.Random:
if seed == "trng":
return secrets.SystemRandom()
else:
return random.Random(seed)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pickle

from lhotse import CutSet
from lhotse.lazy import LazyIteratorMultiplexer
from lhotse.lazy import LazyInfiniteApproximateMultiplexer, LazyIteratorMultiplexer
from lhotse.testing.dummies import DummyManifest


Expand Down Expand Up @@ -86,3 +86,88 @@ def test_multiplexer_with_cuts_pickling():
mux_rec = pickle.loads(data)

assert list(mux) == list(mux_rec)


def test_multiplexer_max_open_streams():
mux = LazyInfiniteApproximateMultiplexer(
range(3),
range(10, 13),
range(100, 103),
seed=1,
max_open_streams=2,
)

it = iter(mux)
samples = []
for _ in range(9):
samples.append(next(it))

# Remember we are sampling with replacement when using
# max_open_streams. Here, the following streams were picked:
# stream2 [100-102],
# stream0 [0-2]
# stream0 [0-2]
# stream1 [10-12]
assert samples == [100, 0, 1, 2, 101, 102, 0, 10, 11]


def test_multiplexer_max_open_streams_1():
mux = LazyInfiniteApproximateMultiplexer(
range(3),
range(10, 13),
range(100, 103),
seed=1,
max_open_streams=1,
)

it = iter(mux)
samples = []
for _ in range(9):
samples.append(next(it))

# When max_open_streams=1, mux is reduced to a chain
assert samples == [0, 1, 2, 10, 11, 12, 0, 1, 2]


def test_cut_set_infinite_mux():
cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3)
cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13)
cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103)

cuts_mux = CutSet.infinite_mux(
cuts1, cuts2, cuts3, seed=0, max_open_streams=2, weights=[0.6, 0.3, 0.1]
)

def cid(i: int) -> str:
return f"dummy-mono-cut-{i:04d}"

samples = []
for i, cut in enumerate(cuts_mux):
if i == 20:
break
samples.append(cut)

expected_ids = (
10,
11,
10,
12,
11,
0,
1,
12,
2,
10,
0,
1,
2,
100,
11,
12,
101,
0,
1,
2,
)

assert [c.id for c in samples] == [cid(i) for i in expected_ids]

0 comments on commit 2f8ed44

Please sign in to comment.