Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce deterministic RNG behavior in repeatedly flaky tests #1143

Merged
merged 7 commits into from
Sep 13, 2023
36 changes: 36 additions & 0 deletions lhotse/testing/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import random

import numpy as np
import pytest
import torch


@pytest.fixture
def deterministic_rng():
"""
Pytest fixture that ensures deterministic RNG behavior.
After the test finishes, it restores the previous RNG state.

Example usage::

>>> def my_test(deterministic_rng):
... x = torch.randn(10, 5) # always has the same values

.. note: Learn more about pytest fixtures setup/teardown here:
https://docs.pytest.org/en/latest/how-to/fixtures.html#teardown-cleanup-aka-fixture-finalization
"""
SEED = 0

torch_state = torch.get_rng_state()
np_state = np.random.get_state()
py_state = random.getstate()

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

yield SEED

random.setstate(py_state)
np.random.set_state(np_state)
torch.set_rng_state(torch_state)
3 changes: 2 additions & 1 deletion test/augmentation/test_torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
volume,
)
from lhotse.augmentation.utils import FastRandomRIRGenerator
from lhotse.testing.random import deterministic_rng

SAMPLING_RATE = 16000

Expand All @@ -34,7 +35,7 @@ def mono_audio():


@pytest.fixture
def multi_channel_audio():
def multi_channel_audio(deterministic_rng):
x = (
torch.sin(2 * math.pi * torch.linspace(0, 1, 16000))
.unsqueeze(0)
Expand Down
5 changes: 3 additions & 2 deletions test/cut/test_cut_truncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lhotse.features import Features
from lhotse.supervision import SupervisionSegment, SupervisionSet
from lhotse.testing.dummies import DummyManifest, dummy_cut, dummy_recording
from lhotse.testing.random import deterministic_rng


@pytest.fixture
Expand Down Expand Up @@ -255,7 +256,7 @@ def test_truncate_cut_set_offset_end(cut_set):
assert isclose(cut2.duration, 5.0)


def test_truncate_cut_set_offset_random(cut_set):
def test_truncate_cut_set_offset_random(deterministic_rng, cut_set):
truncated_cut_set = cut_set.truncate(max_duration=5, offset_type="random")
cut1, cut2 = truncated_cut_set
assert 0.0 <= cut1.start <= 5.0
Expand All @@ -270,7 +271,7 @@ def test_truncate_cut_set_offset_random(cut_set):


@pytest.mark.parametrize("use_rng", [False, True])
def test_truncate_cut_set_offset_random_rng(use_rng):
def test_truncate_cut_set_offset_random_rng(deterministic_rng, use_rng):
cuts1 = DummyManifest(CutSet, begin_id=0, end_id=30)
cuts2 = DummyManifest(CutSet, begin_id=0, end_id=30)

Expand Down
19 changes: 12 additions & 7 deletions test/dataset/test_signal_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lhotse.dataset import GlobalMVN, RandomizedSmoothing, SpecAugment
from lhotse.dataset.collation import collate_features
from lhotse.dataset.signal_transforms import DereverbWPE
from lhotse.testing.random import deterministic_rng
from lhotse.utils import is_module_available


Expand Down Expand Up @@ -58,7 +59,9 @@ def test_specaugment_2d_input_raises_error():

@pytest.mark.parametrize("num_feature_masks", [0, 1, 2])
@pytest.mark.parametrize("num_frame_masks", [1, 2, 3])
def test_specaugment_3d_input_works(num_feature_masks, num_frame_masks):
def test_specaugment_3d_input_works(
deterministic_rng, num_feature_masks, num_frame_masks
):
cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
feats, feat_lens = collate_features(cuts)
tfnm = SpecAugment(
Expand Down Expand Up @@ -92,6 +95,7 @@ def test_specaugment_state_dict():


def test_specaugment_load_state_dict():
torch.manual_seed(0)
# all values non-default
config = dict(
time_warp_factor=85,
Expand All @@ -110,7 +114,8 @@ def test_specaugment_load_state_dict():


@pytest.mark.parametrize("sample_sigma", [True, False])
def test_randomized_smoothing(sample_sigma):
def test_randomized_smoothing(deterministic_rng, sample_sigma):
torch.manual_seed(0)
audio = torch.zeros(64, 4000, dtype=torch.float32)
tfnm = RandomizedSmoothing(sigma=0.1, sample_sigma=sample_sigma, p=0.8)
audio_aug = tfnm(audio)
Expand All @@ -125,7 +130,7 @@ def test_randomized_smoothing(sample_sigma):
assert len(set(audio_aug.sum(dim=1).tolist())) > 1


def test_randomized_smoothing_p1():
def test_randomized_smoothing_p1(deterministic_rng):
audio = torch.zeros(64, 4000, dtype=torch.float32)
tfnm = RandomizedSmoothing(sigma=0.1, p=1.0)
audio_aug = tfnm(audio)
Expand All @@ -137,7 +142,7 @@ def test_randomized_smoothing_p1():
assert (audio_aug[0] != audio_aug[1]).any()


def test_randomized_smoothing_p0():
def test_randomized_smoothing_p0(deterministic_rng):
audio = torch.zeros(64, 4000, dtype=torch.float32)
tfnm = RandomizedSmoothing(sigma=0.1, p=0.0)
audio_aug = tfnm(audio)
Expand All @@ -149,7 +154,7 @@ def test_randomized_smoothing_p0():
assert (audio_aug[0] == audio_aug[1]).all()


def test_randomized_smoothing_schedule():
def test_randomized_smoothing_schedule(deterministic_rng):
audio = torch.zeros(16, 16000, dtype=torch.float32)
tfnm = RandomizedSmoothing(sigma=[(0, 0.01), (100, 0.5)], p=0.8)
audio_aug = tfnm(audio)
Expand All @@ -172,7 +177,7 @@ def test_randomized_smoothing_schedule():
@pytest.mark.skipif(
not is_module_available("nara_wpe"), reason="Requires nara_wpe to be installed."
)
def test_wpe_single_channel():
def test_wpe_single_channel(deterministic_rng):
B, T = 16, 32000
audio = torch.randn(B, T, dtype=torch.float32)
tfnm = DereverbWPE()
Expand All @@ -186,7 +191,7 @@ def test_wpe_single_channel():
@pytest.mark.skipif(
not is_module_available("nara_wpe"), reason="Requires nara_wpe to be installed."
)
def test_wpe_multi_channel():
def test_wpe_multi_channel(deterministic_rng):
B, D, T = 16, 2, 32000
audio = torch.randn(B, D, T, dtype=torch.float32)
tfnm = DereverbWPE()
Expand Down
Loading