From 3c99697ea3245d42034611d9593cf5bc94f8900e Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Fri, 23 Feb 2024 12:19:27 +0000 Subject: [PATCH] test some exception in pyhaaqi --- clarity/evaluator/ha/earmodel.py | 2 +- clarity/evaluator/ha/pyhaaqi.py | 20 +++++++++ tests/evaluator/ha/test_pyhaaqi.py | 72 +++++++++++++++++++----------- 3 files changed, 68 insertions(+), 26 deletions(-) diff --git a/clarity/evaluator/ha/earmodel.py b/clarity/evaluator/ha/earmodel.py index 14d4991e1..aa7d7fa38 100644 --- a/clarity/evaluator/ha/earmodel.py +++ b/clarity/evaluator/ha/earmodel.py @@ -225,7 +225,6 @@ def process_reference( each band converted to dB SL """ num_samples = len(signal) - self.start_signal, self.end_signal = self.find_noiseless_boundaries(signal) signal = signal[self.start_signal : self.end_signal + 1] @@ -946,6 +945,7 @@ def group_delay_compensate( # Add delay correction to each frequency band processed = np.zeros_like(input_signal) npts = input_signal.shape[1] + for n in range(self.num_bands): ref = input_signal[n] processed[n] = np.concatenate( diff --git a/clarity/evaluator/ha/pyhaaqi.py b/clarity/evaluator/ha/pyhaaqi.py index ed202c90c..96957d4f2 100644 --- a/clarity/evaluator/ha/pyhaaqi.py +++ b/clarity/evaluator/ha/pyhaaqi.py @@ -250,6 +250,10 @@ def set_reference( raise ValueError("Audiogram must be set before calling this method.") if sample_rate != self.EAR_SAMPLE_RATE: + logger.warning( + "Sample rate of the reference signal is different from the " + "ear model sample rate. Resampling." + ) reference = resample(reference, sample_rate, self.EAR_SAMPLE_RATE) # Compute Ear model @@ -310,6 +314,10 @@ def score( >>> score, nonlinear, linear, raw = ha.score(enhanced, sr) """ if sample_rate != self.EAR_SAMPLE_RATE: + logger.warning( + "Sample rate of the enhanced signal is different from the " + "ear model sample rate. Resampling." + ) enhanced = resample(enhanced, sample_rate, self.EAR_SAMPLE_RATE) ( @@ -833,6 +841,10 @@ def bm_covary( ref_mean_square * proc_mean_squared ) else: + logger.warning( + "Function bm_covary: Reference mean square is too small, " + "outputs set to 0." + ) signal_cross_covariance[k, 0] = 0.0 # Save the reference MS level @@ -869,6 +881,10 @@ def bm_covary( ref_mean_square * proc_mean_squared ) else: + logger.warning( + "Function bm_covary: Reference mean square is too small, " + "outputs set to 0." + ) signal_cross_covariance[k, n] = 0.0 reference_mean_square[k, n] = ref_mean_square @@ -905,6 +921,10 @@ def bm_covary( ref_mean_square * proc_mean_squared ) else: + logger.warning( + "Function bm_covary: Reference mean square is too small, " + "outputs set to 0." + ) signal_cross_covariance[k, nseg - 1] = 0.0 # Save the reference and processed MS level diff --git a/tests/evaluator/ha/test_pyhaaqi.py b/tests/evaluator/ha/test_pyhaaqi.py index 116ea23b5..85ca96b7f 100644 --- a/tests/evaluator/ha/test_pyhaaqi.py +++ b/tests/evaluator/ha/test_pyhaaqi.py @@ -211,6 +211,22 @@ def test_score(haaqi_instance, audiogram): ) +def test_score_different_sample_rate(haaqi_instance, audiogram, caplog): + """Test the score method of the HAAQI_V1 class with different sample rates.""" + # Initialize HAAQI_V1 instance + np.random.seed(42) + haqqi_instance = haaqi_instance() + haqqi_instance.set_audiogram(audiogram) + + # Generate reference and enhanced signals (example) + reference_signal = np.random.randn(1200) + + # Set the reference signal + haqqi_instance.set_reference(reference_signal, 48000) + + assert "Resampling" in caplog.text + + def test_linear_model(haaqi_instance, audiogram): """Test the linear_model method of the HAAQI_V1 class.""" # Initialize HAAQI_V1 instance @@ -553,11 +569,12 @@ def test_bm_covary_ok(haaqi_instance, audiogram): ) -def test_bm_covary_error(haaqi_instance, audiogram): +@pytest.mark.parametrize("segment_covariance", [1, 2]) +def test_bm_covary_error(haaqi_instance, audiogram, segment_covariance): """Test bm covary fails when segment size too small""" haaqi_instance = haaqi_instance( num_bands=4, - segment_covariance=2, + segment_covariance=segment_covariance, ) haaqi_instance.set_audiogram(audiogram) @@ -577,6 +594,31 @@ def test_bm_covary_error(haaqi_instance, audiogram): ) +def test_bm_covary_ref_meansquare_small(haaqi_instance, audiogram, caplog): + """Test bm covary fails when segment size too small""" + haaqi_instance = haaqi_instance( + num_bands=4, + segment_covariance=4, + ) + haaqi_instance.set_audiogram(audiogram) + + np.random.seed(0) + sig_len = 600 + reference = np.random.random(size=(4, sig_len)) * 1e-15 + processed = np.random.random(size=(4, sig_len)) + + ( + _signal_cross_cov, + _ref_mean_square, + _proc_mean_square, + ) = haaqi_instance.bm_covary( + reference_basilar_membrane=reference, + processed_basilar_membrane=reference + 0.4 * processed, + ) + + assert "Reference mean square is too small" in caplog.text + + def test_ave_covary2(haaqi_instance, audiogram): """Test ave covary2 method of the HAAQI_V1 class.""" haaqi_instance = haaqi_instance( @@ -621,7 +663,7 @@ def test_ave_covary2(haaqi_instance, audiogram): ) -def test_ave_covary2_zero(haaqi_instance, audiogram): +def test_ave_covary2_zero(haaqi_instance, audiogram, caplog): """Test ave covary2 method of the HAAQI_V1 class.""" haaqi_instance = haaqi_instance( num_bands=4, @@ -635,34 +677,14 @@ def test_ave_covary2_zero(haaqi_instance, audiogram): signal_cross_cov = np.random.random(size=(4, sig_len)) ref_mean_square = np.random.random(size=(4, sig_len)) - ave_covariance, ihc_sync_covariance = haaqi_instance.ave_covary2( + haaqi_instance.ave_covary2( signal_cross_covariance=signal_cross_cov, reference_signal_mean_square=ref_mean_square, lp_filter_order=np.array([1, 3, 5, 5, 5, 5]), freq_cutoff=1000 * np.array([1.5, 2.0, 2.5, 3.0, 3.5, 4.0]), ) - assert len(ihc_sync_covariance) == 6 - - assert ave_covariance == pytest.approx( - 0.0, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance - ) - assert np.sum(ihc_sync_covariance) == pytest.approx( - 0.0, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance - ) - - assert ihc_sync_covariance == pytest.approx( - [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ], - rel=pytest.rel_tolerance, - abs=pytest.abs_tolerance, - ) + assert "Ave signal below threshold, outputs set to 0." in caplog.text def test_str_representation(haaqi_instance):