From 4680989e94c5af16b1a73c254375c0e45be47db6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 22 Jan 2024 13:28:53 -0500 Subject: [PATCH 1/8] Simplify the implementation of DurationBatcher. Avoids caching cuts for future re-use. --- lhotse/dataset/sampling/base.py | 5 +- lhotse/dataset/sampling/dynamic.py | 32 +++----- .../sampling/test_dynamic_bucketing.py | 8 +- test/dataset/test_controllable_weights.py | 80 +++++++++++++++++++ 4 files changed, 96 insertions(+), 29 deletions(-) create mode 100644 test/dataset/test_controllable_weights.py diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py index 1b5943a20..8f24df362 100644 --- a/lhotse/dataset/sampling/base.py +++ b/lhotse/dataset/sampling/base.py @@ -383,10 +383,9 @@ def close_to_exceeding(self) -> bool: if self.max_cuts is not None and self.num_cuts >= self.max_cuts: return True - thresh = self.longest_seen - if self.max_duration is not None: - return self.current + thresh >= self.max_duration - 1e-3 # float precision + effective_duration = (self.num_cuts + 1) * self.longest_seen + return effective_duration > self.max_duration return False def reset(self) -> None: diff --git a/lhotse/dataset/sampling/dynamic.py b/lhotse/dataset/sampling/dynamic.py index b8cc50f6a..1e0466461 100644 --- a/lhotse/dataset/sampling/dynamic.py +++ b/lhotse/dataset/sampling/dynamic.py @@ -291,11 +291,8 @@ def detuplify( while True: # Check that we have not reached the end of the dataset. try: - if self.reuse_cuts_buffer: - next_cut_or_tpl = self.reuse_cuts_buffer.popleft() - else: - # If this doesn't raise (typical case), it's not the end: keep processing. - next_cut_or_tpl = next(self.cuts_iter) + # If this doesn't raise (typical case), it's not the end: keep processing. + next_cut_or_tpl = next(self.cuts_iter) except StopIteration: # No more cuts to sample from: if we have a partial batch, # we may output it, unless the user requested to drop it. @@ -315,6 +312,7 @@ def detuplify( raise StopIteration() # Track the duration/frames/etc. constraints. + cuts.append(next_cut_or_tpl) self.time_constraint.add( next_cut_or_tpl[0] if isinstance(next_cut_or_tpl, tuple) @@ -322,25 +320,15 @@ def detuplify( ) # Did we exceed the max_frames and max_cuts constraints? - if not self.time_constraint.exceeded(): - # No - add the next cut to the batch, and keep trying. - cuts.append(next_cut_or_tpl) - else: - # Yes. Do we have at least one cut in the batch? - if cuts: - # Yes. Return the batch, but keep the currently drawn cut for later. - self.reuse_cuts_buffer.append(next_cut_or_tpl) - break - else: - # No. We'll warn the user that the constrains might be too tight, - # and return the cut anyway. + if self.time_constraint.close_to_exceeding(): + # Yes. Finish sampling this batch. + if self.time_constraint.exceeded(): warnings.warn( - "The first cut drawn in batch collection violates " - "the max_frames, max_cuts, or max_duration constraints - " - "we'll return it anyway. " - "Consider increasing max_frames/max_cuts/max_duration." + "We have exceeded the max_duration constraint during sampling. " + "This is likely because max_duration was set to a very low value ~10s, " + "or you're using a CutSet with very long cuts (e.g. 100s of seconds long)." ) - cuts.append(next_cut_or_tpl) + break return detuplify(cuts) diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py index d717961c4..e7337bf88 100644 --- a/test/dataset/sampling/test_dynamic_bucketing.py +++ b/test/dataset/sampling/test_dynamic_bucketing.py @@ -174,8 +174,8 @@ def test_dynamic_bucketing_sampler_precomputed_duration_bins(): # We sampled 5 batches with this RNG, like the following: assert len(batches) == 5 - assert len(batches[0]) == 2 - assert sum(c.duration for c in batches[0]) == 2 + assert len(batches[0]) == 3 + assert sum(c.duration for c in batches[0]) == 4 assert len(batches[1]) == 2 assert sum(c.duration for c in batches[1]) == 3 @@ -186,8 +186,8 @@ def test_dynamic_bucketing_sampler_precomputed_duration_bins(): assert len(batches[3]) == 2 assert sum(c.duration for c in batches[3]) == 3 - assert len(batches[4]) == 2 - assert sum(c.duration for c in batches[4]) == 4 + assert len(batches[4]) == 1 + assert sum(c.duration for c in batches[4]) == 2 def test_dynamic_bucketing_sampler_max_duration_and_max_cuts(): diff --git a/test/dataset/test_controllable_weights.py b/test/dataset/test_controllable_weights.py new file mode 100644 index 000000000..e6b7a7ac3 --- /dev/null +++ b/test/dataset/test_controllable_weights.py @@ -0,0 +1,80 @@ +import torch + +from lhotse import CutSet +from lhotse.dataset import DynamicCutSampler, IterableDatasetWrapper +from lhotse.testing.dummies import DummyManifest + + +class DummyDataset(torch.utils.data.Dataset): + def __getitem__(self, item): + return item + + +def test_mux_with_controllable_weights(): + def mark(val: int): + def _inner(cut): + cut.source = val + return cut + + return _inner + + # 3 infinite iterables + cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).map(mark(0)).repeat() + cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).map(mark(1)).repeat() + cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).map(mark(2)).repeat() + + def assert_sources_are(cuts: CutSet, expected: list[int]): + actual = [c.source for c in cuts] + assert actual == expected + + # TODO: initialize weights + weights = [1, 0, 0] + + muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) + + sampler = DynamicCutSampler(muxd, max_cuts=2) + + # locate the sampler in a sub-process + dloader = torch.utils.data.DataLoader( + dataset=DummyDataset(), + sampler=sampler, + batch_size=None, + num_workers=0, + ) + + dloader = iter(dloader) + b = next(dloader) + assert_sources_are(b, [0, 0]) + + # TODO: set the weight + weights[0] = 0 + weights[1] = 1 + b = next(dloader) + assert_sources_are(b, [1, 1]) + + # TODO: set the weight + weights[1] = 0 + weights[2] = 1 + b = next(dloader) + assert_sources_are(b, [2, 2]) + + +def test_mux_with_controllable_weights_multiprocess(): + return + # # 3 infinite iterables + # cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).repeat() + # cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).repeat() + # cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).repeat() + # + # weights = [1, 1, 1] + # + # muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) + # + # sampler = DynamicCutSampler(muxd, max_cuts=2) + # + # # locate the sampler in a sub-process + # dloader = torch.utils.data.DataLoader( + # dataset=IterableDatasetWrapper(dataset=DummyDataset(), sampler=sampler), + # batch_size=None, + # num_workers=1, + # ) From bbabca5a917b02c224b7b58d8a241fd67ca9f70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 22 Jan 2024 15:41:10 -0500 Subject: [PATCH 2/8] Enable leveraging shared memory for updating mux weights in dataloading subprocesses --- README.md | 1 + lhotse/lazy.py | 14 ++- test/dataset/test_controllable_weights.py | 143 ++++++++++++++++------ 3 files changed, 118 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 0eea7366e..2602ca33c 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,7 @@ Lhotse uses several environment variables to customize it's behavior. They are a - `LHOTSE_AUDIO_DURATION_MISMATCH_TOLERANCE` - used when we load audio from a file and receive a different number of samples than declared in `Recording.num_samples`. This is sometimes necessary because different codecs (or even different versions of the same codec) may use different padding when decoding compressed audio. Typically values up to 0.1, or even 0.3 (second) are still reasonable, and anything beyond that indicates a serious issue. - `LHOTSE_AUDIO_BACKEND` - may be set to any of the values returned from CLI `lhotse list-audio-backends` to override the default behavior of trial-and-error and always use a specific audio backend. - `LHOTSE_AUDIO_LOADING_EXCEPTION_VERBOSE` - when set to `1` we'll emit full exception stack traces when every available audio backend fails to load a given file (they might be very large). +- `LHOTSE_DILL_ENABLED` - when it's set to `1|True|true|yes`, we will enable `dill`-based serialization of `CutSet` and `Sampler` across processes (it's disabled by default even when `dill` is installed). - `LHOTSE_PREPARING_RELEASE` - used internally by developers when releasing a new version of Lhotse. - `TORCHAUDIO_USE_BACKEND_DISPATCHER` - when set to `1` and torchaudio version is below 2.1, we'll enable the experimental ffmpeg backend of torchaudio. - `RANK`, `WORLD_SIZE`, `WORKER`, and `NUM_WORKERS` are internally used to inform Lhotse Shar dataloading subprocesses. diff --git a/lhotse/lazy.py b/lhotse/lazy.py index d5bd55528..67b0b8a66 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -1,5 +1,5 @@ +import os import random -import secrets import types import warnings from functools import partial @@ -186,8 +186,13 @@ class Dillable: If ``dill`` is not installed, it defers to what ``pickle`` does by default. """ + _ENABLED_VALUES = {"1", "True", "true", "yes"} + def __getstate__(self): - if is_module_available("dill"): + if ( + is_module_available("dill") + and os.environ.get("LHOTSE_DILL_ENABLED", "0") in self._ENABLED_VALUES + ): import dill return dill.dumps(self.__dict__) @@ -195,7 +200,10 @@ def __getstate__(self): return self.__dict__ def __setstate__(self, state): - if is_module_available("dill"): + if ( + is_module_available("dill") + and os.environ.get("LHOTSE_DILL_ENABLED", "0") not in self._ENABLED_VALUES + ): import dill self.__dict__ = dill.loads(state) diff --git a/test/dataset/test_controllable_weights.py b/test/dataset/test_controllable_weights.py index e6b7a7ac3..d6f359727 100644 --- a/test/dataset/test_controllable_weights.py +++ b/test/dataset/test_controllable_weights.py @@ -1,3 +1,5 @@ +import numpy as np +import pytest import torch from lhotse import CutSet @@ -10,71 +12,138 @@ def __getitem__(self, item): return item -def test_mux_with_controllable_weights(): - def mark(val: int): - def _inner(cut): - cut.source = val - return cut +def mark(val: int): + def _inner(cut): + cut.source = val + return cut - return _inner + return _inner + + +def assert_sources_are(cuts: CutSet, expected: list[int]): + actual = [c.source for c in cuts] + assert actual == expected + + +@pytest.mark.parametrize("weight_type", [list, np.array, torch.tensor]) +def test_mux_with_controllable_weights(weight_type): + """The sampler and the worker are both in the main process.""" # 3 infinite iterables cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).map(mark(0)).repeat() cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).map(mark(1)).repeat() cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).map(mark(2)).repeat() - def assert_sources_are(cuts: CutSet, expected: list[int]): - actual = [c.source for c in cuts] - assert actual == expected + weights = weight_type([1, 0, 0]) + muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) - # TODO: initialize weights - weights = [1, 0, 0] + dloader = torch.utils.data.DataLoader( + dataset=DummyDataset(), + sampler=DynamicCutSampler(muxd, max_cuts=2), + batch_size=None, + num_workers=0, + ) - muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) + dloader = iter(dloader) + b = next(dloader) + assert_sources_are(b, [0, 0]) - sampler = DynamicCutSampler(muxd, max_cuts=2) + weights[0] = 0 + weights[1] = 1 + b = next(dloader) + assert_sources_are(b, [1, 1]) + + weights[1] = 0 + weights[2] = 1 + b = next(dloader) + assert_sources_are(b, [2, 2]) + + +def test_mux_with_controllable_weights_subprocess_worker(): + """ + The sampler is in the main process but the worker is in a sub-process. + + In general expect a latency of ``prefetch_factor * num_workers`` in the propagation + of weights between the main process and the dataloading subprocesses. + """ + + # 3 infinite iterables + cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).map(mark(0)).repeat() + cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).map(mark(1)).repeat() + cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).map(mark(2)).repeat() + + weights = [1, 0, 0] + muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) - # locate the sampler in a sub-process dloader = torch.utils.data.DataLoader( dataset=DummyDataset(), - sampler=sampler, + sampler=DynamicCutSampler(muxd, max_cuts=2), batch_size=None, - num_workers=0, + num_workers=1, + prefetch_factor=1, ) dloader = iter(dloader) b = next(dloader) assert_sources_are(b, [0, 0]) - # TODO: set the weight weights[0] = 0 weights[1] = 1 b = next(dloader) + assert_sources_are( + b, [0, 0] + ) # prefetch_factor causes one batch with previous weights to be retained + b = next(dloader) assert_sources_are(b, [1, 1]) - # TODO: set the weight weights[1] = 0 weights[2] = 1 b = next(dloader) + assert_sources_are( + b, [1, 1] + ) # prefetch_factor causes one batch with previous weights to be retained + b = next(dloader) assert_sources_are(b, [2, 2]) -def test_mux_with_controllable_weights_multiprocess(): - return - # # 3 infinite iterables - # cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).repeat() - # cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).repeat() - # cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).repeat() - # - # weights = [1, 1, 1] - # - # muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) - # - # sampler = DynamicCutSampler(muxd, max_cuts=2) - # - # # locate the sampler in a sub-process - # dloader = torch.utils.data.DataLoader( - # dataset=IterableDatasetWrapper(dataset=DummyDataset(), sampler=sampler), - # batch_size=None, - # num_workers=1, - # ) +def test_mux_with_controllable_weights_subprocess_sampler_shared_memory(): + """ + The sampler is placed in the dataloading subprocess. + + Note: we are using PyTorch shared memory to share the weight tensor across processes. + + In general expect a latency of ``prefetch_factor * num_workers`` in the propagation + of weights between the main process and the dataloading subprocesses. + """ + + # 3 infinite iterables + cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).map(mark(0)).repeat() + cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).map(mark(1)).repeat() + cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).map(mark(2)).repeat() + + weights = torch.tensor([1, 0, 0]).share_memory_() + assert weights.is_shared() + muxd = CutSet.mux(cuts1, cuts2, cuts3, weights=weights) + + dloader = torch.utils.data.DataLoader( + dataset=IterableDatasetWrapper( + dataset=DummyDataset(), sampler=DynamicCutSampler(muxd, max_cuts=2) + ), + batch_size=None, + num_workers=1, + prefetch_factor=1, + ) + + dloader = iter(dloader) + b = next(dloader) + assert_sources_are(b, [0, 0]) + + weights[0] = 0.0 + weights[1] = 1.0 + b = next(dloader) + assert_sources_are(b, [1, 1]) + + weights[1] = 0.0 + weights[2] = 1.0 + b = next(dloader) + assert_sources_are(b, [2, 2]) From abb1e52e86d588c4aee478512c4086be649c3928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 22 Jan 2024 16:14:42 -0500 Subject: [PATCH 3/8] Initial partial support for infinite_mux --- lhotse/cut/set.py | 6 +- lhotse/lazy.py | 7 ++- test/dataset/test_controllable_weights.py | 70 ++++++++++++++++++++--- 3 files changed, 72 insertions(+), 11 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 09789f67f..eca056d89 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -2439,7 +2439,7 @@ def modify_ids(self, transform_fn: Callable[[str], str]) -> "CutSet": a new string (new cut ID). :return: a new ``CutSet`` with cuts with modified IDs. """ - return self.map(lambda cut: cut.with_id(transform_fn(cut.id))) + return self.map(partial(_with_id, transform_fn=transform_fn)) def fill_supervisions( self, add_empty: bool = True, shrink_ok: bool = False @@ -3265,6 +3265,10 @@ def _add_features_path_prefix_single(cut, path): return cut.with_features_path_prefix(path) +def _with_id(cut, transform_fn): + return cut.with_id(transform_fn(cut.id)) + + def _call(obj, member_fn: str, *args, **kwargs) -> Callable: return getattr(obj, member_fn)(*args, **kwargs) diff --git a/lhotse/lazy.py b/lhotse/lazy.py index 67b0b8a66..9b2e29f6e 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -202,7 +202,7 @@ def __getstate__(self): def __setstate__(self, state): if ( is_module_available("dill") - and os.environ.get("LHOTSE_DILL_ENABLED", "0") not in self._ENABLED_VALUES + and os.environ.get("LHOTSE_DILL_ENABLED", "0") in self._ENABLED_VALUES ): import dill @@ -481,9 +481,10 @@ def shuffled_streams(): # towards the beginning of an "epoch" and then keep yielding # from it1 shards until the epoch is finished and we can sample # from it0 again... - zipped_iter_weights = list(zip(self.iterators, self.weights)) + indexes = list(range(len(self.iterators))) while True: - yield rng.choices(zipped_iter_weights, self.weights, k=1)[0] + selected = rng.choices(indexes, self.weights, k=1)[0] + yield self.iterators[selected], self.weights[selected] # Initialize an infinite sequence of finite streams. # It is sampled with weights and replacement from ``self.iterators``, diff --git a/test/dataset/test_controllable_weights.py b/test/dataset/test_controllable_weights.py index d6f359727..3aeabddb9 100644 --- a/test/dataset/test_controllable_weights.py +++ b/test/dataset/test_controllable_weights.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + import numpy as np import pytest import torch @@ -5,6 +7,7 @@ from lhotse import CutSet from lhotse.dataset import DynamicCutSampler, IterableDatasetWrapper from lhotse.testing.dummies import DummyManifest +from lhotse.testing.random import deterministic_rng class DummyDataset(torch.utils.data.Dataset): @@ -20,13 +23,17 @@ def _inner(cut): return _inner +def random_id(*args): + return str(uuid4()) + + def assert_sources_are(cuts: CutSet, expected: list[int]): actual = [c.source for c in cuts] assert actual == expected @pytest.mark.parametrize("weight_type", [list, np.array, torch.tensor]) -def test_mux_with_controllable_weights(weight_type): +def test_mux_with_controllable_weights(deterministic_rng, weight_type): """The sampler and the worker are both in the main process.""" # 3 infinite iterables @@ -59,7 +66,7 @@ def test_mux_with_controllable_weights(weight_type): assert_sources_are(b, [2, 2]) -def test_mux_with_controllable_weights_subprocess_worker(): +def test_mux_with_controllable_weights_subprocess_worker(deterministic_rng): """ The sampler is in the main process but the worker is in a sub-process. @@ -106,7 +113,9 @@ def test_mux_with_controllable_weights_subprocess_worker(): assert_sources_are(b, [2, 2]) -def test_mux_with_controllable_weights_subprocess_sampler_shared_memory(): +def test_mux_with_controllable_weights_subprocess_sampler_shared_memory( + deterministic_rng, +): """ The sampler is placed in the dataloading subprocess. @@ -138,12 +147,59 @@ def test_mux_with_controllable_weights_subprocess_sampler_shared_memory(): b = next(dloader) assert_sources_are(b, [0, 0]) - weights[0] = 0.0 - weights[1] = 1.0 + weights[:] = torch.tensor([0, 1, 0]) # atomic update + b = next(dloader) + assert_sources_are(b, [1, 1]) + + weights[:] = torch.tensor([0, 0, 1]) # atomic update + b = next(dloader) + assert_sources_are(b, [2, 2]) + + +@pytest.mark.skip( + reason="Infinite mux is not yet fully supported for shared memory weights." +) +def test_infinite_mux_with_controllable_weights_subprocess_sampler_shared_memory( + deterministic_rng, +): + """ + The sampler is placed in the dataloading subprocess. + + Note: we are using PyTorch shared memory to share the weight tensor across processes. + + In general expect a latency of ``prefetch_factor * num_workers`` in the propagation + of weights between the main process and the dataloading subprocesses. + """ + + # 3 infinite iterables + cuts1 = DummyManifest(CutSet, begin_id=0, end_id=3).map(mark(0)) + cuts2 = DummyManifest(CutSet, begin_id=10, end_id=13).map(mark(1)) + cuts3 = DummyManifest(CutSet, begin_id=100, end_id=103).map(mark(2)) + + weights = torch.tensor([1, 0, 0]).share_memory_() + assert weights.is_shared() + # randomize_id is required because infinite_mux may sample the same cut in a mini batch + muxd = CutSet.infinite_mux(cuts1, cuts2, cuts3, weights=weights).modify_ids( + random_id + ) + + dloader = torch.utils.data.DataLoader( + dataset=IterableDatasetWrapper( + dataset=DummyDataset(), sampler=DynamicCutSampler(muxd, max_cuts=2) + ), + batch_size=None, + num_workers=1, + prefetch_factor=1, + ) + + dloader = iter(dloader) + b = next(dloader) + assert_sources_are(b, [0, 0]) + + weights[:] = torch.tensor([0, 1, 0]) # atomic update b = next(dloader) assert_sources_are(b, [1, 1]) - weights[1] = 0.0 - weights[2] = 1.0 + weights[:] = torch.tensor([0, 0, 1]) # atomic update b = next(dloader) assert_sources_are(b, [2, 2]) From 8c53f967a8a550d1d0b3ce76080027ccce2b8fd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Jan 2024 09:58:46 -0500 Subject: [PATCH 4/8] Support most `CutSet` operations without dill; fix tests; infinite_mux works --- lhotse/cut/set.py | 128 ++++++++++++++---- lhotse/lazy.py | 24 ++-- lhotse/testing/fixtures.py | 6 + .../dataset/sampling/test_sampler_pickling.py | 9 +- test/dataset/test_controllable_weights.py | 18 ++- test/test_lazy.py | 3 +- 6 files changed, 150 insertions(+), 38 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index eca056d89..e1a503e29 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -956,7 +956,7 @@ def filter_supervisions( :param predicate: A callable that accepts `SupervisionSegment` and returns bool :return: a CutSet with filtered supervisions """ - return self.map(lambda cut: cut.filter_supervisions(predicate)) + return self.map(partial(_filter_supervisions, predicate=predicate)) def merge_supervisions( self, @@ -982,8 +982,10 @@ def merge_supervisions( ``custom_merge_fn(custom_key, [s.custom[custom_key] for s in sups])`` """ return self.map( - lambda cut: cut.merge_supervisions( - merge_policy=merge_policy, custom_merge_fn=custom_merge_fn + partial( + _merge_supervisions, + merge_policy=merge_policy, + custom_merge_fn=custom_merge_fn, ) ) @@ -1341,7 +1343,8 @@ def pad( duration = max(cut.duration for cut in self) return self.map( - lambda cut: cut.pad( + partial( + _pad, duration=duration, num_frames=num_frames, num_samples=num_samples, @@ -1422,7 +1425,8 @@ def extend_by( :return: a new CutSet instance. """ return self.map( - lambda cut: cut.extend_by( + partial( + _extend_by, duration=duration, direction=direction, preserve_id=preserve_id, @@ -1535,7 +1539,9 @@ def resample(self, sampling_rate: int, affix_id: bool = False) -> "CutSet": cut are going to be present in a single manifest). :return: a modified copy of the ``CutSet``. """ - return self.map(lambda cut: cut.resample(sampling_rate, affix_id=affix_id)) + return self.map( + partial(_resample, sampling_rate=sampling_rate, affix_id=affix_id) + ) def perturb_speed(self, factor: float, affix_id: bool = True) -> "CutSet": """ @@ -1550,7 +1556,7 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> "CutSet": cut are going to be present in a single manifest). :return: a modified copy of the ``CutSet``. """ - return self.map(lambda cut: cut.perturb_speed(factor=factor, affix_id=affix_id)) + return self.map(partial(_perturb_speed, factor=factor, affix_id=affix_id)) def perturb_tempo(self, factor: float, affix_id: bool = True) -> "CutSet": """ @@ -1568,7 +1574,7 @@ def perturb_tempo(self, factor: float, affix_id: bool = True) -> "CutSet": cut are going to be present in a single manifest). :return: a modified copy of the ``CutSet``. """ - return self.map(lambda cut: cut.perturb_tempo(factor=factor, affix_id=affix_id)) + return self.map(partial(_perturb_tempo, factor=factor, affix_id=affix_id)) def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet": """ @@ -1582,9 +1588,7 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet": cut are going to be present in a single manifest). :return: a modified copy of the ``CutSet``. """ - return self.map( - lambda cut: cut.perturb_volume(factor=factor, affix_id=affix_id) - ) + return self.map(partial(_perturb_volume, factor=factor, affix_id=affix_id)) def normalize_loudness( self, target: float, mix_first: bool = True, affix_id: bool = True @@ -1599,8 +1603,11 @@ def normalize_loudness( :return: a modified copy of the current ``CutSet``. """ return self.map( - lambda cut: cut.normalize_loudness( - target=target, mix_first=mix_first, affix_id=affix_id + partial( + _normalize_loudness, + target=target, + mix_first=mix_first, + affix_id=affix_id, ) ) @@ -1612,7 +1619,7 @@ def dereverb_wpe(self, affix_id: bool = True) -> "CutSet": by affixing it with "_wpe". :return: a modified copy of the current ``CutSet``. """ - return self.map(lambda cut: cut.dereverb_wpe(affix_id=affix_id)) + return self.map(partial(_dereverb_wpe, affix_id=affix_id)) def reverb_rir( self, @@ -1643,7 +1650,8 @@ def reverb_rir( """ rir_recordings = list(rir_recordings) if rir_recordings else None return self.map( - lambda cut: cut.reverb_rir( + partial( + _reverb_rir, rir_recording=random.choice(rir_recordings) if rir_recordings else None, normalize_output=normalize_output, early_only=early_only, @@ -1713,25 +1721,25 @@ def drop_features(self) -> "CutSet": """ Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its extracted features. """ - return self.map(lambda cut: cut.drop_features()) + return self.map(_drop_features) def drop_recordings(self) -> "CutSet": """ Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its recordings. """ - return self.map(lambda cut: cut.drop_recording()) + return self.map(_drop_recordings) def drop_supervisions(self) -> "CutSet": """ Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its supervisions. """ - return self.map(lambda cut: cut.drop_supervisions()) + return self.map(_drop_supervisions) def drop_alignments(self) -> "CutSet": """ Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from the alignments present in its supervisions. """ - return self.map(lambda cut: cut.drop_alignments()) + return self.map(_drop_alignments) def compute_and_store_features( self, @@ -2461,7 +2469,7 @@ def fill_supervisions( of calling this method. """ return self.map( - lambda cut: cut.fill_supervision(add_empty=add_empty, shrink_ok=shrink_ok) + partial(_fill_supervision, add_empty=add_empty, shrink_ok=shrink_ok) ) def map_supervisions( @@ -2473,7 +2481,7 @@ def map_supervisions( :param transform_fn: a function that modifies a supervision as an argument. :return: a new, modified CutSet. """ - return self.map(lambda cut: cut.map_supervisions(transform_fn)) + return self.map(partial(_map_supervisions, transform_fn=transform_fn)) def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet": """ @@ -2483,7 +2491,9 @@ def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet": :param transform_fn: a function that accepts a string and returns a string. :return: a new, modified CutSet. """ - return self.map_supervisions(lambda s: s.transform_text(transform_fn)) + return self.map_supervisions( + partial(_transform_text, transform_fn=transform_fn) + ) def __repr__(self) -> str: try: @@ -3269,8 +3279,78 @@ def _with_id(cut, transform_fn): return cut.with_id(transform_fn(cut.id)) -def _call(obj, member_fn: str, *args, **kwargs) -> Callable: - return getattr(obj, member_fn)(*args, **kwargs) +def _fill_supervision(cut, add_empty, shrink_ok): + return cut.fill_supervision(add_empty=add_empty, shrink_ok=shrink_ok) + + +def _map_supervisions(cut, transform_fn): + return cut.map_supervisions(transform_fn) + + +def _transform_text(sup, transform_fn): + return sup.transform_text(transform_fn) + + +def _filter_supervisions(cut, predicate): + return cut.filter_supervisions(predicate) + + +def _merge_supervisions(cut, merge_policy, custom_merge_fn): + return cut.merge_supervisions( + merge_policy=merge_policy, custom_merge_fn=custom_merge_fn + ) + + +def _pad(cut, *args, **kwargs): + return cut.pad(*args, **kwargs) + + +def _extend_by(cut, *args, **kwargs): + return cut.extend_by(*args, **kwargs) + + +def _resample(cut, *args, **kwargs): + return cut.resample(*args, **kwargs) + + +def _perturb_speed(cut, *args, **kwargs): + return cut.perturb_speed(*args, **kwargs) + + +def _perturb_tempo(cut, *args, **kwargs): + return cut.perturb_speed(*args, **kwargs) + + +def _perturb_volume(cut, *args, **kwargs): + return cut.perturb_speed(*args, **kwargs) + + +def _reverb_rir(cut, *args, **kwargs): + return cut.perturb_speed(*args, **kwargs) + + +def _normalize_loudness(cut, *args, **kwargs): + return cut.normalize_loudness(*args, **kwargs) + + +def _dereverb_wpe(cut, *args, **kwargs): + return cut.dereverb_wpe(*args, **kwargs) + + +def _drop_features(cut, *args, **kwargs): + return cut.drop_features(*args, **kwargs) + + +def _drop_recordings(cut, *args, **kwargs): + return cut.drop_recording(*args, **kwargs) + + +def _drop_alignments(cut, *args, **kwargs): + return cut.drop_alignments(*args, **kwargs) + + +def _drop_supervisions(cut, *args, **kwargs): + return cut.drop_supervisions(*args, **kwargs) def _export_to_shar_single( diff --git a/lhotse/lazy.py b/lhotse/lazy.py index 9b2e29f6e..eb5830409 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -494,20 +494,28 @@ def shuffled_streams(): # Sample the first M active streams to be multiplexed. # As streams get depleted, we will replace them with # new streams sampled from the stream source. - active_streams = [] - active_weights = [] + active_streams = [None] * self.max_open_streams + active_weights = [None] * self.max_open_streams stream_indexes = list(range(self.max_open_streams)) - for _ in range(self.max_open_streams): + + def sample_new_stream_at(pos: int) -> None: sampled_stream, sampled_weight = next(stream_source) - active_streams.append(iter(sampled_stream)) - active_weights.append(sampled_weight) + active_streams[pos] = iter(sampled_stream) + active_weights[pos] = sampled_weight + + for stream_pos in range(self.max_open_streams): + sample_new_stream_at(stream_pos) # The actual multiplexing loop. while True: # Select a stream from the currently active streams. # We actually sample an index so that we know which position # to replace if a stream is exhausted. - stream_pos = rng.choices(stream_indexes, weights=active_weights, k=1)[0] + stream_pos = rng.choices( + stream_indexes, + weights=active_weights if sum(active_weights) > 0 else None, + k=1, + )[0] selected = active_streams[stream_pos] try: # Sample from the selected stream. @@ -516,9 +524,7 @@ def shuffled_streams(): except StopIteration: # The selected stream is exhausted. Replace it with another one, # and return a sample from the newly opened stream. - sampled_stream, sampled_weight = next(stream_source) - active_streams[stream_pos] = iter(sampled_stream) - active_weights[stream_pos] = sampled_weight + sample_new_stream_at(stream_pos) item = next(active_streams[stream_pos]) yield item diff --git a/lhotse/testing/fixtures.py b/lhotse/testing/fixtures.py index ae2785f71..be4c28a16 100644 --- a/lhotse/testing/fixtures.py +++ b/lhotse/testing/fixtures.py @@ -4,6 +4,7 @@ from typing import Dict, List import numpy as np +import pytest import torch from lhotse import ( @@ -22,6 +23,11 @@ from lhotse.utils import Seconds, uuid4 +@pytest.fixture() +def with_dill_enabled(): + os.environ["LHOTSE_ENABLE_DILL"] = "1" + + def random_cut_set(n_cuts=100) -> CutSet: sr = 16000 return CutSet.from_cuts( diff --git a/test/dataset/sampling/test_sampler_pickling.py b/test/dataset/sampling/test_sampler_pickling.py index 7668bdf50..8502407c4 100644 --- a/test/dataset/sampling/test_sampler_pickling.py +++ b/test/dataset/sampling/test_sampler_pickling.py @@ -14,6 +14,8 @@ ) from lhotse.dataset.sampling.dynamic import DynamicCutSampler from lhotse.testing.dummies import DummyManifest +from lhotse.testing.fixtures import with_dill_enabled +from lhotse.utils import is_module_available CUTS = DummyManifest(CutSet, begin_id=0, end_id=100) CUTS_MOD = CUTS.modify_ids(lambda cid: cid + "_alt") @@ -120,8 +122,13 @@ def test_sampler_pickling_with_filter(sampler): assert batches_restored[0][0].id == "dummy-mono-cut-0000" +@pytest.mark.xfail( + not is_module_available("dill"), + reason="This test will fail when 'dill' module is not installed as it won't be able to pickle a closure.", + raises=AttributeError, +) @pytest.mark.parametrize("sampler", create_samplers_to_test_filter()) -def test_sampler_pickling_with_filter_local_closure(sampler): +def test_sampler_pickling_with_filter_local_closure(with_dill_enabled, sampler): selected_id = "dummy-mono-cut-0000" diff --git a/test/dataset/test_controllable_weights.py b/test/dataset/test_controllable_weights.py index 3aeabddb9..3538d1a69 100644 --- a/test/dataset/test_controllable_weights.py +++ b/test/dataset/test_controllable_weights.py @@ -156,9 +156,6 @@ def test_mux_with_controllable_weights_subprocess_sampler_shared_memory( assert_sources_are(b, [2, 2]) -@pytest.mark.skip( - reason="Infinite mux is not yet fully supported for shared memory weights." -) def test_infinite_mux_with_controllable_weights_subprocess_sampler_shared_memory( deterministic_rng, ): @@ -196,10 +193,25 @@ def test_infinite_mux_with_controllable_weights_subprocess_sampler_shared_memory b = next(dloader) assert_sources_are(b, [0, 0]) + # Note the latency for several batches. The reason is the following: + # infinite_mux() samples streams with replacement, and at the beginning of the test is samples + # 3x stream #0, which has 3 items each with equal probability. + # It will only start returning items from stream #1 once one of the previous streams is exhausted. weights[:] = torch.tensor([0, 1, 0]) # atomic update b = next(dloader) + assert_sources_are(b, [0, 0]) + b = next(dloader) + assert_sources_are(b, [0, 0]) + b = next(dloader) assert_sources_are(b, [1, 1]) + # The latency strikes again as now we have both streams #0 and #1 open, + # but they have zero weight. It means they will be uniformly sampled until + # one of them is exhausted and a new stream #2 is opened. weights[:] = torch.tensor([0, 0, 1]) # atomic update b = next(dloader) + assert_sources_are(b, [0, 0]) + b = next(dloader) + assert_sources_are(b, [1, 2]) + b = next(dloader) assert_sources_are(b, [2, 2]) diff --git a/test/test_lazy.py b/test/test_lazy.py index e6e18c262..38cfef579 100644 --- a/test/test_lazy.py +++ b/test/test_lazy.py @@ -11,6 +11,7 @@ from lhotse import CutSet, FeatureSet, RecordingSet, SupervisionSet, combine from lhotse.lazy import LazyJsonlIterator from lhotse.testing.dummies import DummyManifest, as_lazy +from lhotse.testing.fixtures import with_dill_enabled from lhotse.utils import fastcopy, is_module_available @@ -235,7 +236,7 @@ def _get_ids(cuts): reason="This test will fail when 'dill' module is not installed as it won't be able to pickle a lambda.", raises=AttributeError, ) -def test_dillable(): +def test_dillable(with_dill_enabled): cuts = DummyManifest(CutSet, begin_id=0, end_id=2) with as_lazy(cuts) as lazy_cuts: lazy_cuts = lazy_cuts.map(lambda c: fastcopy(c, id=c.id + "-random-suffix")) From bb6e19743c340353dee694b69fa3c4836dd5f0d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Jan 2024 10:16:05 -0500 Subject: [PATCH 5/8] Fixes --- docs/getting-started.rst | 2 ++ lhotse/cut/set.py | 6 +++--- lhotse/testing/fixtures.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/getting-started.rst b/docs/getting-started.rst index 1fc030b94..c136faed1 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -125,6 +125,8 @@ Lhotse uses several environment variables to customize it's behavior. They are a * ``LHOTSE_AUDIO_LOADING_EXCEPTION_VERBOSE`` - when set to 1 we'll emit full exception stack traces when every available audio backend fails to load a given file (they might be very large). +* ``LHOTSE_DILL_ENABLED`` - when it's set to ``1|True|true|yes``, we will enable ``dill``-based serialization of ``CutSet`` and ``Sampler`` across processes (it's disabled by default even when ``dill`` is installed). + * ``LHOTSE_PREPARING_RELEASE`` - used internally by developers when releasing a new version of Lhotse. * ``TORCHAUDIO_USE_BACKEND_DISPATCHER`` - when set to 1 and torchaudio version is below 2.1, we'll enable the experimental ffmpeg backend of torchaudio. diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index e1a503e29..cb3a44cd2 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -3318,15 +3318,15 @@ def _perturb_speed(cut, *args, **kwargs): def _perturb_tempo(cut, *args, **kwargs): - return cut.perturb_speed(*args, **kwargs) + return cut.perturb_tempo(*args, **kwargs) def _perturb_volume(cut, *args, **kwargs): - return cut.perturb_speed(*args, **kwargs) + return cut.perturb_volume(*args, **kwargs) def _reverb_rir(cut, *args, **kwargs): - return cut.perturb_speed(*args, **kwargs) + return cut.reverb_rir(*args, **kwargs) def _normalize_loudness(cut, *args, **kwargs): diff --git a/lhotse/testing/fixtures.py b/lhotse/testing/fixtures.py index be4c28a16..d6af36220 100644 --- a/lhotse/testing/fixtures.py +++ b/lhotse/testing/fixtures.py @@ -25,7 +25,7 @@ @pytest.fixture() def with_dill_enabled(): - os.environ["LHOTSE_ENABLE_DILL"] = "1" + os.environ["LHOTSE_DILL_ENABLED"] = "1" def random_cut_set(n_cuts=100) -> CutSet: From e30eb472d25ad760a464c7ed0c6617871a20749f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Jan 2024 11:00:03 -0500 Subject: [PATCH 6/8] Fix meeting simulation test --- lhotse/__init__.py | 1 + lhotse/lazy.py | 47 +++++++++++++++---- lhotse/utils.py | 7 +++ .../meeting_simulation/conversational.py | 4 +- .../meeting_simulation/speaker_independent.py | 5 +- 5 files changed, 54 insertions(+), 10 deletions(-) diff --git a/lhotse/__init__.py b/lhotse/__init__.py index 2a41063fc..82bdb809d 100644 --- a/lhotse/__init__.py +++ b/lhotse/__init__.py @@ -16,6 +16,7 @@ from .cut import CutSet, MonoCut, MultiCut, create_cut_set_eager, create_cut_set_lazy from .features import * from .kaldi import load_kaldi_data_dir +from .lazy import dill_enabled, is_dill_enabled, set_dill_enabled from .manipulation import combine, split_parallelize_combine, to_manifest from .qa import fix_manifests, validate, validate_recordings_and_supervisions from .serialization import load_manifest, load_manifest_lazy, store_manifest diff --git a/lhotse/lazy.py b/lhotse/lazy.py index eb5830409..3f56a1d67 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -2,6 +2,7 @@ import random import types import warnings +from contextlib import contextmanager from functools import partial from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union @@ -189,10 +190,7 @@ class Dillable: _ENABLED_VALUES = {"1", "True", "true", "yes"} def __getstate__(self): - if ( - is_module_available("dill") - and os.environ.get("LHOTSE_DILL_ENABLED", "0") in self._ENABLED_VALUES - ): + if is_dill_enabled(): import dill return dill.dumps(self.__dict__) @@ -200,10 +198,7 @@ def __getstate__(self): return self.__dict__ def __setstate__(self, state): - if ( - is_module_available("dill") - and os.environ.get("LHOTSE_DILL_ENABLED", "0") in self._ENABLED_VALUES - ): + if is_dill_enabled(): import dill self.__dict__ = dill.loads(state) @@ -211,6 +206,42 @@ def __setstate__(self, state): self.__dict__ = state +def is_dill_enabled(_ENABLED_VALUES=frozenset(("1", "True", "true", "yes"))) -> bool: + """Returns bool indicating if dill-based pickling in Lhotse is enabled or not.""" + return ( + is_module_available("dill") + and os.environ.get("LHOTSE_DILL_ENABLED", "0") in _ENABLED_VALUES + ) + + +def set_dill_enabled(value: bool) -> None: + """Enable or disable dill-based pickling in Lhotse.""" + assert is_module_available("dill"), ( + "Cannot enable dill because dill is not installed. " + "Please run 'pip install dill' and try again." + ) + # We use os.environ here so that sub-processes / forks will inherit this value + os.environ["LHOTSE_DILL_ENABLED"] = "1" if value else "0" + + +@contextmanager +def dill_enabled(value: bool): + """ + Context manager that overrides the setting of Lhotse's dill-backed pickling + and restores the previous value after exit. + + Example:: + + >>> import pickle + ... with dill_enabled(True): + ... pickle.dump(CutSet(...).filter(lambda c: c.duration < 5), open("cutset.pickle", "wb")) + """ + previous = is_dill_enabled() + set_dill_enabled(value) + yield + set_dill_enabled(previous) + + class ImitatesDict(Dillable): """ Helper base class for lazy iterators defined below. diff --git a/lhotse/utils.py b/lhotse/utils.py index 80558e0b7..3aa993cd8 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -1100,3 +1100,10 @@ def build_rng(seed: Union[int, Literal["trng"]]) -> random.Random: return secrets.SystemRandom() else: return random.Random(seed) + + +_LHOTSE_DILL_ENABLED = False + + +def is_dill_enabled() -> bool: + return _LHOTSE_DILL_ENABLED or os.environ["LHOTSE_DILL_ENABLED"] diff --git a/lhotse/workflows/meeting_simulation/conversational.py b/lhotse/workflows/meeting_simulation/conversational.py index 94782e46c..899f15e5d 100644 --- a/lhotse/workflows/meeting_simulation/conversational.py +++ b/lhotse/workflows/meeting_simulation/conversational.py @@ -6,7 +6,7 @@ import numpy as np from tqdm import tqdm -from lhotse import RecordingSet, SupervisionSet +from lhotse import RecordingSet, SupervisionSet, dill_enabled from lhotse.cut import CutSet, MixedCut, MixTrack from lhotse.cut.set import mix from lhotse.parallel import parallel_map @@ -88,6 +88,7 @@ def _compute_histogram_dist(self, values: np.ndarray) -> Any: hist, bin_edges = np.histogram(values, bins=100, density=True) return rv_histogram((hist, bin_edges)) + @dill_enabled(True) def fit(self, meetings: Optional[SupervisionSet] = None) -> None: """ Learn the distribution of the meeting parameters from a given dataset. @@ -261,6 +262,7 @@ def _create_mixture( tracks = sorted(tracks, key=lambda x: x.offset) return MixedCut(id=str(uuid4()), tracks=tracks) + @dill_enabled(True) def simulate( self, cuts: CutSet, diff --git a/lhotse/workflows/meeting_simulation/speaker_independent.py b/lhotse/workflows/meeting_simulation/speaker_independent.py index e6d1a14a3..83ecd784a 100644 --- a/lhotse/workflows/meeting_simulation/speaker_independent.py +++ b/lhotse/workflows/meeting_simulation/speaker_independent.py @@ -6,7 +6,8 @@ import numpy as np from tqdm import tqdm -from lhotse import RecordingSet, SupervisionSet +import lhotse +from lhotse import RecordingSet, SupervisionSet, dill_enabled from lhotse.cut import CutSet, MixedCut, MixTrack from lhotse.cut.set import mix from lhotse.parallel import parallel_map @@ -51,6 +52,7 @@ def __init__(self, loc: float = 0.0, scale: float = 2.0): def __repr__(self): return self.__class__.__name__ + f"(loc={self.loc}, scale={self.scale})" + @dill_enabled(True) def fit(self, meetings: Optional[SupervisionSet] = None) -> None: """ Learn the distribution of the meeting parameters from a given dataset. @@ -113,6 +115,7 @@ def _create_mixture( tracks.append(track) return MixedCut(id=str(uuid4()), tracks=tracks) + @dill_enabled(True) def simulate( self, cuts: CutSet, From 1cca2b3a01e60f1f5a2f879fc658b9714626ce0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Jan 2024 11:13:54 -0500 Subject: [PATCH 7/8] make py3.8 happy --- test/dataset/test_controllable_weights.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/dataset/test_controllable_weights.py b/test/dataset/test_controllable_weights.py index 3538d1a69..d5aa2d113 100644 --- a/test/dataset/test_controllable_weights.py +++ b/test/dataset/test_controllable_weights.py @@ -1,3 +1,4 @@ +from typing import List from uuid import uuid4 import numpy as np @@ -27,12 +28,12 @@ def random_id(*args): return str(uuid4()) -def assert_sources_are(cuts: CutSet, expected: list[int]): +def assert_sources_are(cuts: CutSet, expected: List[int]): actual = [c.source for c in cuts] assert actual == expected -@pytest.mark.parametrize("weight_type", [list, np.array, torch.tensor]) +@pytest.mark.parametrize("weight_type", [List, np.array, torch.tensor]) def test_mux_with_controllable_weights(deterministic_rng, weight_type): """The sampler and the worker are both in the main process.""" From 6447d8b57eed70903973c22b8e0080d866ba409e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Jan 2024 11:33:53 -0500 Subject: [PATCH 8/8] fix --- test/dataset/test_controllable_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dataset/test_controllable_weights.py b/test/dataset/test_controllable_weights.py index d5aa2d113..0abdcf1fc 100644 --- a/test/dataset/test_controllable_weights.py +++ b/test/dataset/test_controllable_weights.py @@ -33,7 +33,7 @@ def assert_sources_are(cuts: CutSet, expected: List[int]): assert actual == expected -@pytest.mark.parametrize("weight_type", [List, np.array, torch.tensor]) +@pytest.mark.parametrize("weight_type", [list, np.array, torch.tensor]) def test_mux_with_controllable_weights(deterministic_rng, weight_type): """The sampler and the worker are both in the main process."""