Skip to content

Commit

Permalink
Fix non-deterministic tests (#1261)
Browse files Browse the repository at this point in the history
  • Loading branch information
pzelasko committed Jan 13, 2024
1 parent 774ac43 commit 0089643
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions test/features/test_kaldi_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,105 +15,106 @@
_get_strided_batch,
_get_strided_batch_streaming,
)
from lhotse.testing.random import deterministic_rng


def test_wav2win():
def test_wav2win(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2Win()
y, _ = t(x)
assert y.shape == torch.Size([1, 100, 400])
assert y.dtype == torch.float32


def test_wav2fft():
def test_wav2fft(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2FFT()
y = t(x)
assert y.shape == torch.Size([1, 100, 257])
assert y.dtype == torch.complex64


def test_wav2spec():
def test_wav2spec(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2Spec()
y = t(x)
assert y.shape == torch.Size([1, 100, 257])
assert y.dtype == torch.float32


def test_wav2logspec():
def test_wav2logspec(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2LogSpec()
y = t(x)
assert y.shape == torch.Size([1, 100, 257])
assert y.dtype == torch.float32


def test_wav2logfilterbank():
def test_wav2logfilterbank(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2LogFilterBank()
y = t(x)
assert y.shape == torch.Size([1, 100, 80])
assert y.dtype == torch.float32


def test_wav2mfcc():
def test_wav2mfcc(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2MFCC()
y = t(x)
assert y.shape == torch.Size([1, 100, 13])
assert y.dtype == torch.float32


def test_wav2win_is_torchscriptable():
def test_wav2win_is_torchscriptable(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = torch.jit.script(Wav2Win())
y, _ = t(x)
assert y.shape == torch.Size([1, 100, 400])
assert y.dtype == torch.float32


def test_wav2fft_is_torchscriptable():
def test_wav2fft_is_torchscriptable(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = torch.jit.script(Wav2FFT())
y = t(x)
assert y.shape == torch.Size([1, 100, 257])
assert y.dtype == torch.complex64


def test_wav2spec_is_torchscriptable():
def test_wav2spec_is_torchscriptable(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = torch.jit.script(Wav2Spec())
y = t(x)
assert y.shape == torch.Size([1, 100, 257])
assert y.dtype == torch.float32


def test_wav2logspec_is_torchscriptable():
def test_wav2logspec_is_torchscriptable(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = torch.jit.script(Wav2LogSpec())
y = t(x)
assert y.shape == torch.Size([1, 100, 257])
assert y.dtype == torch.float32


def test_wav2logfilterbank_is_torchscriptable():
def test_wav2logfilterbank_is_torchscriptable(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = torch.jit.script(Wav2LogFilterBank())
y = t(x)
assert y.shape == torch.Size([1, 100, 80])
assert y.dtype == torch.float32


def test_wav2mfcc_is_torchscriptable():
def test_wav2mfcc_is_torchscriptable(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = torch.jit.script(Wav2MFCC())
y = t(x)
assert y.shape == torch.Size([1, 100, 13])
assert y.dtype == torch.float32


def test_strided_waveform_batch_streaming_snip_edges_false():
def test_strided_waveform_batch_streaming_snip_edges_false(deterministic_rng):
x = torch.arange(16000).unsqueeze(0)
window_length = 400
window_shift = 160
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_strided_waveform_batch_streaming_snip_edges_false():
torch.testing.assert_allclose(y_online, y)


def test_strided_waveform_batch_streaming_snip_edges_true():
def test_strided_waveform_batch_streaming_snip_edges_true(deterministic_rng):
x = torch.arange(16000).unsqueeze(0)
window_length = 400
window_shift = 160
Expand Down Expand Up @@ -199,7 +200,7 @@ def test_strided_waveform_batch_streaming_snip_edges_true():
torch.testing.assert_allclose(y_online, y)


def test_wav2win_streaming():
def test_wav2win_streaming(deterministic_rng):
x = torch.randn(1, 16000, dtype=torch.float32)
t = Wav2Win()
window_length = 400
Expand Down Expand Up @@ -243,7 +244,7 @@ def test_wav2win_streaming():
(Wav2MFCC, 13),
],
)
def test_wav2logfilterbank_streaming(layer_type, feat_dim):
def test_wav2logfilterbank_streaming(deterministic_rng, layer_type, feat_dim):
x = torch.randn(1, 16000, dtype=torch.float32)
t = layer_type()
window_length = 400
Expand Down

0 comments on commit 0089643

Please sign in to comment.