From 551ab85235e8e79be73178dd2083e62fb4512f7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 3 Apr 2024 13:18:07 -0400 Subject: [PATCH 1/2] Fix randomness in CutMix transform --- lhotse/cut/set.py | 13 ++++++++----- lhotse/dataset/cut_transforms/mix.py | 22 +++++++++++++++++++--- test/dataset/test_cut_transforms.py | 11 +++++++++++ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 5afcb2b48..4f0872f3d 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -1670,7 +1670,7 @@ def mix( snr: Optional[Union[Decibels, Sequence[Decibels]]] = 20, preserve_id: Optional[str] = None, mix_prob: float = 1.0, - seed: Union[int, Literal["trng", "randomized"]] = 42, + seed: Union[int, Literal["trng", "randomized"], random.Random] = 42, random_mix_offset: bool = False, ) -> "CutSet": """ @@ -1699,7 +1699,7 @@ def mix( Values lower than 1.0 mean that some cuts in the output will be unchanged. :param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR. If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results - on each iteration. + on each iteration. You can also directly pass a ``random.Random`` instance here. :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. @@ -3460,7 +3460,7 @@ class LazyCutMixer(Dillable): Values lower than 1.0 mean that some cuts in the output will be unchanged. :param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR. If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results - on each iteration. + on each iteration. You can also directly pass a ``random.Random`` instance here. :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. @@ -3478,7 +3478,7 @@ def __init__( snr: Optional[Union[Decibels, Sequence[Decibels]]] = 20, preserve_id: Optional[str] = None, mix_prob: float = 1.0, - seed: Union[int, Literal["trng", "randomized"]] = 42, + seed: Union[int, Literal["trng", "randomized"], random.Random] = 42, random_mix_offset: bool = False, stateful: bool = True, ) -> None: @@ -3506,7 +3506,10 @@ def __init__( def __iter__(self): from lhotse.dataset.dataloading import resolve_seed - rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated) + if isinstance(self.seed, random.Random): + rng = self.seed + else: + rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated) if self.stateful: self.num_times_iterated += 1 diff --git a/lhotse/dataset/cut_transforms/mix.py b/lhotse/dataset/cut_transforms/mix.py index 44c4c6264..2e244281f 100644 --- a/lhotse/dataset/cut_transforms/mix.py +++ b/lhotse/dataset/cut_transforms/mix.py @@ -1,7 +1,9 @@ +import random import warnings -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union from lhotse import CutSet +from lhotse.dataset.dataloading import resolve_seed from lhotse.utils import Decibels @@ -18,7 +20,7 @@ def __init__( p: float = 0.5, pad_to_longest: bool = True, preserve_id: bool = False, - seed: int = 42, + seed: Union[int, Literal["trng", "randomized"], random.Random] = 42, random_mix_offset: bool = False, ) -> None: """ @@ -34,6 +36,9 @@ def __init__( to match the duration of the longest Cut in a batch. :param preserve_id: When ``True``, preserves the IDs the cuts had before augmentation. Otherwise, new random IDs are generated for the augmented cuts (default). + :param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR. + If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results + on each iteration. You can also directly pass a ``random.Random`` instance here. :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. @@ -48,6 +53,7 @@ def __init__( self.pad_to_longest = pad_to_longest self.preserve_id = preserve_id self.seed = seed + self.rng = None self.random_mix_offset = random_mix_offset def __call__(self, cuts: CutSet) -> CutSet: @@ -56,6 +62,8 @@ def __call__(self, cuts: CutSet) -> CutSet: if len(self.cuts) == 0: return cuts + self._lazy_rng_init() + maybe_max_duration = ( max(c.duration for c in cuts) if self.pad_to_longest else None ) @@ -65,6 +73,14 @@ def __call__(self, cuts: CutSet) -> CutSet: snr=self.snr, mix_prob=self.p, preserve_id="left" if self.preserve_id else None, - seed=self.seed, + seed=self.rng, random_mix_offset=self.random_mix_offset, ).to_eager() + + def _lazy_rng_init(self): + if self.rng is not None: + return + if isinstance(self.seed, random.Random): + self.rng = self.seed + else: + self.rng = random.Random(resolve_seed(self.seed)) diff --git a/test/dataset/test_cut_transforms.py b/test/dataset/test_cut_transforms.py index 9e7a90955..fd95b6135 100644 --- a/test/dataset/test_cut_transforms.py +++ b/test/dataset/test_cut_transforms.py @@ -118,6 +118,17 @@ def test_cutmix(preserve_id: bool): ) +def test_cut_mix_is_stateful(): + speech_cuts = DummyManifest(CutSet, begin_id=0, end_id=10) + noise_cuts = DummyManifest(CutSet, begin_id=100, end_id=102) + + # called twice on the same input, expecting different results + tfnm = CutMix(noise_cuts, snr=None, p=1.0, seed=0, preserve_id=True) + out1 = tfnm(speech_cuts) + out2 = tfnm(speech_cuts) + assert list(out1) != list(out2) + + def test_cutmix_random_mix_offset(): speech_cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json").resample(16000) noise_cuts = CutSet.from_json("test/fixtures/libri/cuts.json") From 4ee09940b054c4403ec5bbfa496e7b6ce18db2e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 3 Apr 2024 14:05:23 -0400 Subject: [PATCH 2/2] nitpicks --- test/dataset/test_cut_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/dataset/test_cut_transforms.py b/test/dataset/test_cut_transforms.py index fd95b6135..1c24763a4 100644 --- a/test/dataset/test_cut_transforms.py +++ b/test/dataset/test_cut_transforms.py @@ -123,9 +123,9 @@ def test_cut_mix_is_stateful(): noise_cuts = DummyManifest(CutSet, begin_id=100, end_id=102) # called twice on the same input, expecting different results - tfnm = CutMix(noise_cuts, snr=None, p=1.0, seed=0, preserve_id=True) - out1 = tfnm(speech_cuts) - out2 = tfnm(speech_cuts) + tnfm = CutMix(noise_cuts, snr=None, p=1.0, seed=0, preserve_id=True) + out1 = tnfm(speech_cuts) + out2 = tnfm(speech_cuts) assert list(out1) != list(out2)