Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix randomness in CutMix transform #1316

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,7 +1670,7 @@ def mix(
snr: Optional[Union[Decibels, Sequence[Decibels]]] = 20,
preserve_id: Optional[str] = None,
mix_prob: float = 1.0,
seed: Union[int, Literal["trng", "randomized"]] = 42,
seed: Union[int, Literal["trng", "randomized"], random.Random] = 42,
random_mix_offset: bool = False,
) -> "CutSet":
"""
Expand Down Expand Up @@ -1699,7 +1699,7 @@ def mix(
Values lower than 1.0 mean that some cuts in the output will be unchanged.
:param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR.
If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results
on each iteration.
on each iteration. You can also directly pass a ``random.Random`` instance here.
:param random_mix_offset: an optional bool.
When ``True`` and the duration of the to be mixed in cut in longer than the original cut,
select a random sub-region from the to be mixed in cut.
Expand Down Expand Up @@ -3460,7 +3460,7 @@ class LazyCutMixer(Dillable):
Values lower than 1.0 mean that some cuts in the output will be unchanged.
:param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR.
If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results
on each iteration.
on each iteration. You can also directly pass a ``random.Random`` instance here.
:param random_mix_offset: an optional bool.
When ``True`` and the duration of the to be mixed in cut in longer than the original cut,
select a random sub-region from the to be mixed in cut.
Expand All @@ -3478,7 +3478,7 @@ def __init__(
snr: Optional[Union[Decibels, Sequence[Decibels]]] = 20,
preserve_id: Optional[str] = None,
mix_prob: float = 1.0,
seed: Union[int, Literal["trng", "randomized"]] = 42,
seed: Union[int, Literal["trng", "randomized"], random.Random] = 42,
random_mix_offset: bool = False,
stateful: bool = True,
) -> None:
Expand Down Expand Up @@ -3506,7 +3506,10 @@ def __init__(
def __iter__(self):
from lhotse.dataset.dataloading import resolve_seed

rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated)
if isinstance(self.seed, random.Random):
rng = self.seed
else:
rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated)
if self.stateful:
self.num_times_iterated += 1

Expand Down
22 changes: 19 additions & 3 deletions lhotse/dataset/cut_transforms/mix.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import random
import warnings
from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union

from lhotse import CutSet
from lhotse.dataset.dataloading import resolve_seed
from lhotse.utils import Decibels


Expand All @@ -18,7 +20,7 @@ def __init__(
p: float = 0.5,
pad_to_longest: bool = True,
preserve_id: bool = False,
seed: int = 42,
seed: Union[int, Literal["trng", "randomized"], random.Random] = 42,
random_mix_offset: bool = False,
) -> None:
"""
Expand All @@ -34,6 +36,9 @@ def __init__(
to match the duration of the longest Cut in a batch.
:param preserve_id: When ``True``, preserves the IDs the cuts had before augmentation.
Otherwise, new random IDs are generated for the augmented cuts (default).
:param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR.
If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results
on each iteration. You can also directly pass a ``random.Random`` instance here.
:param random_mix_offset: an optional bool.
When ``True`` and the duration of the to be mixed in cut in longer than the original cut,
select a random sub-region from the to be mixed in cut.
Expand All @@ -48,6 +53,7 @@ def __init__(
self.pad_to_longest = pad_to_longest
self.preserve_id = preserve_id
self.seed = seed
self.rng = None
self.random_mix_offset = random_mix_offset

def __call__(self, cuts: CutSet) -> CutSet:
Expand All @@ -56,6 +62,8 @@ def __call__(self, cuts: CutSet) -> CutSet:
if len(self.cuts) == 0:
return cuts

self._lazy_rng_init()

maybe_max_duration = (
max(c.duration for c in cuts) if self.pad_to_longest else None
)
Expand All @@ -65,6 +73,14 @@ def __call__(self, cuts: CutSet) -> CutSet:
snr=self.snr,
mix_prob=self.p,
preserve_id="left" if self.preserve_id else None,
seed=self.seed,
seed=self.rng,
random_mix_offset=self.random_mix_offset,
).to_eager()

def _lazy_rng_init(self):
if self.rng is not None:
return
if isinstance(self.seed, random.Random):
self.rng = self.seed
else:
self.rng = random.Random(resolve_seed(self.seed))
11 changes: 11 additions & 0 deletions test/dataset/test_cut_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ def test_cutmix(preserve_id: bool):
)


def test_cut_mix_is_stateful():
speech_cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
noise_cuts = DummyManifest(CutSet, begin_id=100, end_id=102)

# called twice on the same input, expecting different results
tnfm = CutMix(noise_cuts, snr=None, p=1.0, seed=0, preserve_id=True)
out1 = tnfm(speech_cuts)
out2 = tnfm(speech_cuts)
assert list(out1) != list(out2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this test. I ran this using v1.17 which is stateless, but I still got different outputs (i.e., out1 and out2 are different). This is counterintuitive because CutMix passes the same seed in every __call__ to cuts.mix. can you please help me understand this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It used CutSet.sample which was not seeded and used the global RNG. If I start any new library from scratch I’ll keep it in mind to have a methodical approach to RNG handling from the beginning.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, you're right. I forgot about that. Thanks again.



def test_cutmix_random_mix_offset():
speech_cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json").resample(16000)
noise_cuts = CutSet.from_json("test/fixtures/libri/cuts.json")
Expand Down
Loading