@pytest.fixture
def deterministic_rng():
    """
    Pytest fixture that ensures deterministic RNG behavior.

    Before the test body runs, Python's ``random`` module, NumPy's global
    RNG and PyTorch are all seeded with the same fixed seed. After the test
    finishes, the RNG states captured before seeding are restored.

    The fixture yields the seed value that was used.

    Example usage::

        >>> def my_test(deterministic_rng):
        ...     x = torch.randn(10, 5)  # always has the same values

    .. note:: Learn more about pytest fixtures setup/teardown here:
        https://docs.pytest.org/en/latest/how-to/fixtures.html#teardown-cleanup-aka-fixture-finalization
    """
    SEED = 0

    # Snapshot every global RNG we are about to reseed.
    py_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    # ``torch.manual_seed`` reseeds the CUDA generators as well, whereas
    # ``torch.get_rng_state`` only covers the CPU generator. Snapshot the CUDA
    # states too, but only when CUDA is already initialized, so that this
    # fixture never forces CUDA initialization on its own.
    cuda_states = (
        torch.cuda.get_rng_state_all() if torch.cuda.is_initialized() else None
    )

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    yield SEED

    # Teardown: put back the states captured above.
    random.setstate(py_state)
    np.random.set_state(np_state)
    torch.set_rng_state(torch_state)
    if cuda_states is not None:
        torch.cuda.set_rng_state_all(cuda_states)
deterministic_rng @pytest.fixture @@ -255,7 +256,7 @@ def test_truncate_cut_set_offset_end(cut_set): assert isclose(cut2.duration, 5.0) -def test_truncate_cut_set_offset_random(cut_set): +def test_truncate_cut_set_offset_random(deterministic_rng, cut_set): truncated_cut_set = cut_set.truncate(max_duration=5, offset_type="random") cut1, cut2 = truncated_cut_set assert 0.0 <= cut1.start <= 5.0 @@ -270,7 +271,7 @@ def test_truncate_cut_set_offset_random(cut_set): @pytest.mark.parametrize("use_rng", [False, True]) -def test_truncate_cut_set_offset_random_rng(use_rng): +def test_truncate_cut_set_offset_random_rng(deterministic_rng, use_rng): cuts1 = DummyManifest(CutSet, begin_id=0, end_id=30) cuts2 = DummyManifest(CutSet, begin_id=0, end_id=30) diff --git a/test/dataset/test_signal_transforms.py b/test/dataset/test_signal_transforms.py index 320dba8b0..df9b46795 100644 --- a/test/dataset/test_signal_transforms.py +++ b/test/dataset/test_signal_transforms.py @@ -7,6 +7,7 @@ from lhotse.dataset import GlobalMVN, RandomizedSmoothing, SpecAugment from lhotse.dataset.collation import collate_features from lhotse.dataset.signal_transforms import DereverbWPE +from lhotse.testing.random import deterministic_rng from lhotse.utils import is_module_available @@ -58,7 +59,9 @@ def test_specaugment_2d_input_raises_error(): @pytest.mark.parametrize("num_feature_masks", [0, 1, 2]) @pytest.mark.parametrize("num_frame_masks", [1, 2, 3]) -def test_specaugment_3d_input_works(num_feature_masks, num_frame_masks): +def test_specaugment_3d_input_works( + deterministic_rng, num_feature_masks, num_frame_masks +): cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json") feats, feat_lens = collate_features(cuts) tfnm = SpecAugment( @@ -92,6 +95,7 @@ def test_specaugment_state_dict(): def test_specaugment_load_state_dict(): + torch.manual_seed(0) # all values non-default config = dict( time_warp_factor=85, @@ -110,7 +114,8 @@ def test_specaugment_load_state_dict(): 
@pytest.mark.parametrize("sample_sigma", [True, False]) -def test_randomized_smoothing(sample_sigma): +def test_randomized_smoothing(deterministic_rng, sample_sigma): + torch.manual_seed(0) audio = torch.zeros(64, 4000, dtype=torch.float32) tfnm = RandomizedSmoothing(sigma=0.1, sample_sigma=sample_sigma, p=0.8) audio_aug = tfnm(audio) @@ -125,7 +130,7 @@ def test_randomized_smoothing(sample_sigma): assert len(set(audio_aug.sum(dim=1).tolist())) > 1 -def test_randomized_smoothing_p1(): +def test_randomized_smoothing_p1(deterministic_rng): audio = torch.zeros(64, 4000, dtype=torch.float32) tfnm = RandomizedSmoothing(sigma=0.1, p=1.0) audio_aug = tfnm(audio) @@ -137,7 +142,7 @@ def test_randomized_smoothing_p1(): assert (audio_aug[0] != audio_aug[1]).any() -def test_randomized_smoothing_p0(): +def test_randomized_smoothing_p0(deterministic_rng): audio = torch.zeros(64, 4000, dtype=torch.float32) tfnm = RandomizedSmoothing(sigma=0.1, p=0.0) audio_aug = tfnm(audio) @@ -149,7 +154,7 @@ def test_randomized_smoothing_p0(): assert (audio_aug[0] == audio_aug[1]).all() -def test_randomized_smoothing_schedule(): +def test_randomized_smoothing_schedule(deterministic_rng): audio = torch.zeros(16, 16000, dtype=torch.float32) tfnm = RandomizedSmoothing(sigma=[(0, 0.01), (100, 0.5)], p=0.8) audio_aug = tfnm(audio) @@ -172,7 +177,7 @@ def test_randomized_smoothing_schedule(): @pytest.mark.skipif( not is_module_available("nara_wpe"), reason="Requires nara_wpe to be installed." ) -def test_wpe_single_channel(): +def test_wpe_single_channel(deterministic_rng): B, T = 16, 32000 audio = torch.randn(B, T, dtype=torch.float32) tfnm = DereverbWPE() @@ -186,7 +191,7 @@ def test_wpe_single_channel(): @pytest.mark.skipif( not is_module_available("nara_wpe"), reason="Requires nara_wpe to be installed." ) -def test_wpe_multi_channel(): +def test_wpe_multi_channel(deterministic_rng): B, D, T = 16, 2, 32000 audio = torch.randn(B, D, T, dtype=torch.float32) tfnm = DereverbWPE()