diff --git a/tests/test_implementation.py b/tests/test_implementation.py index afd69a7..93348bc 100644 --- a/tests/test_implementation.py +++ b/tests/test_implementation.py @@ -19,6 +19,9 @@ from __future__ import unicode_literals from __future__ import absolute_import +import unittest + +import pytest import six import numpy as np from numpy.testing import assert_allclose @@ -28,40 +31,36 @@ from .lspopt_ref import lspopt_ref -class TestLSPOptSuite(object): - - def test_different_N(self): - """Test against reference implementation for different N.""" - def test_fcn(n): - h1, w1 = lspopt(n) - h2, w2 = lspopt_ref(n) - assert_allclose(h1, h2) - assert_allclose(w1, w2) +@pytest.mark.parametrize("n", six.moves.range(64, 1024)) +def test_different_n(n): + """Test against reference implementation for different N.""" + h1, w1 = lspopt(n) + h2, w2 = lspopt_ref(n) + assert_allclose(h1, h2) + assert_allclose(w1, w2) - for n in six.moves.range(64, 1024): - yield test_fcn, n - def test_different_c(self): - """Test against reference implementation for different N.""" - def test_fcn(c): - h1, w1 = lspopt(n=1024, c_parameter=c) - h2, w2 = lspopt_ref(n=1024, c_parameter=c) - assert_allclose(h1, h2) - assert_allclose(w1, w2) +@pytest.mark.parametrize("c", np.arange(1.1, 30.0, 0.1)) +def test_different_c(c): + """Test against reference implementation for different c.""" + h1, w1 = lspopt(n=1024, c_parameter=c) + h2, w2 = lspopt_ref(n=1024, c_parameter=c) + assert_allclose(h1, h2) + assert_allclose(w1, w2) - for c in np.arange(1.1, 30.0, 0.1): - yield test_fcn, c - def test_spectrogram_method(self): +def test_spectrogram_method(): + """Test the spectrogram method's functionality.""" + fs = 10e3 + N = 1e5 + amp = 2 * np.sqrt(2) + noise_power = 0.001 * fs / 2 + time = np.arange(N) / fs + freq = np.linspace(1e3, 2e3, N) + x = amp * chirp(time, 1e3, 2.0, 6e3, method='quadratic') + \ + np.random.normal(scale=np.sqrt(noise_power), size=time.shape) - fs = 10e3 - N = 1e5 - amp = 2 * np.sqrt(2) - noise_power = 0.001 * fs / 2 - time = np.arange(N) / fs - freq = np.linspace(1e3, 2e3, N) - x = amp * chirp(time, 1e3, 2.0, 6e3, method='quadratic') + \ - np.random.normal(scale=np.sqrt(noise_power), size=time.shape) + f, t, Sxx = spectrogram_lspopt(x, fs, c_parameter=20.0) + f_sp, t_sp, Sxx_sp = spectrogram(x, fs) - f, t, Sxx = spectrogram_lspopt(x, fs, c_parameter=20.0) - f_sp, t_sp, Sxx_sp = spectrogram(x, fs) + assert True