Skip to content

Commit

Permalink
test some exception in pyhaaqi
Browse files Browse the repository at this point in the history
  • Loading branch information
groadabike committed Feb 23, 2024
1 parent f2f6461 commit 3c99697
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 26 deletions.
2 changes: 1 addition & 1 deletion clarity/evaluator/ha/earmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def process_reference(
each band converted to dB SL
"""
num_samples = len(signal)

self.start_signal, self.end_signal = self.find_noiseless_boundaries(signal)
signal = signal[self.start_signal : self.end_signal + 1]

Expand Down Expand Up @@ -946,6 +945,7 @@ def group_delay_compensate(
# Add delay correction to each frequency band
processed = np.zeros_like(input_signal)
npts = input_signal.shape[1]

for n in range(self.num_bands):
ref = input_signal[n]
processed[n] = np.concatenate(
Expand Down
20 changes: 20 additions & 0 deletions clarity/evaluator/ha/pyhaaqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ def set_reference(
raise ValueError("Audiogram must be set before calling this method.")

if sample_rate != self.EAR_SAMPLE_RATE:
logger.warning(
"Sample rate of the reference signal is different from the "
"ear model sample rate. Resampling."
)
reference = resample(reference, sample_rate, self.EAR_SAMPLE_RATE)

# Compute Ear model
Expand Down Expand Up @@ -310,6 +314,10 @@ def score(
>>> score, nonlinear, linear, raw = ha.score(enhanced, sr)
"""
if sample_rate != self.EAR_SAMPLE_RATE:
logger.warning(

Check warning on line 317 in clarity/evaluator/ha/pyhaaqi.py

View check run for this annotation

Codecov / codecov/patch

clarity/evaluator/ha/pyhaaqi.py#L317

Added line #L317 was not covered by tests
"Sample rate of the enhanced signal is different from the "
"ear model sample rate. Resampling."
)
enhanced = resample(enhanced, sample_rate, self.EAR_SAMPLE_RATE)

Check warning on line 321 in clarity/evaluator/ha/pyhaaqi.py

View check run for this annotation

Codecov / codecov/patch

clarity/evaluator/ha/pyhaaqi.py#L321

Added line #L321 was not covered by tests

(
Expand Down Expand Up @@ -833,6 +841,10 @@ def bm_covary(
ref_mean_square * proc_mean_squared
)
else:
logger.warning(
"Function bm_covary: Reference mean square is too small, "
"outputs set to 0."
)
signal_cross_covariance[k, 0] = 0.0

# Save the reference MS level
Expand Down Expand Up @@ -869,6 +881,10 @@ def bm_covary(
ref_mean_square * proc_mean_squared
)
else:
logger.warning(
"Function bm_covary: Reference mean square is too small, "
"outputs set to 0."
)
signal_cross_covariance[k, n] = 0.0

reference_mean_square[k, n] = ref_mean_square
Expand Down Expand Up @@ -905,6 +921,10 @@ def bm_covary(
ref_mean_square * proc_mean_squared
)
else:
logger.warning(
"Function bm_covary: Reference mean square is too small, "
"outputs set to 0."
)
signal_cross_covariance[k, nseg - 1] = 0.0

# Save the reference and processed MS level
Expand Down
72 changes: 47 additions & 25 deletions tests/evaluator/ha/test_pyhaaqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,22 @@ def test_score(haaqi_instance, audiogram):
)


def test_score_different_sample_rate(haaqi_instance, audiogram, caplog):
"""Test the score method of the HAAQI_V1 class with different sample rates."""
# Initialize HAAQI_V1 instance
np.random.seed(42)
haqqi_instance = haaqi_instance()
haqqi_instance.set_audiogram(audiogram)

# Generate reference and enhanced signals (example)
reference_signal = np.random.randn(1200)

# Set the reference signal
haqqi_instance.set_reference(reference_signal, 48000)

assert "Resampling" in caplog.text


def test_linear_model(haaqi_instance, audiogram):
"""Test the linear_model method of the HAAQI_V1 class."""
# Initialize HAAQI_V1 instance
Expand Down Expand Up @@ -553,11 +569,12 @@ def test_bm_covary_ok(haaqi_instance, audiogram):
)


def test_bm_covary_error(haaqi_instance, audiogram):
@pytest.mark.parametrize("segment_covariance", [1, 2])
def test_bm_covary_error(haaqi_instance, audiogram, segment_covariance):
"""Test bm covary fails when segment size too small"""
haaqi_instance = haaqi_instance(
num_bands=4,
segment_covariance=2,
segment_covariance=segment_covariance,
)
haaqi_instance.set_audiogram(audiogram)

Expand All @@ -577,6 +594,31 @@ def test_bm_covary_error(haaqi_instance, audiogram):
)


def test_bm_covary_ref_meansquare_small(haaqi_instance, audiogram, caplog):
"""Test bm covary fails when segment size too small"""
haaqi_instance = haaqi_instance(
num_bands=4,
segment_covariance=4,
)
haaqi_instance.set_audiogram(audiogram)

np.random.seed(0)
sig_len = 600
reference = np.random.random(size=(4, sig_len)) * 1e-15
processed = np.random.random(size=(4, sig_len))

(
_signal_cross_cov,
_ref_mean_square,
_proc_mean_square,
) = haaqi_instance.bm_covary(
reference_basilar_membrane=reference,
processed_basilar_membrane=reference + 0.4 * processed,
)

assert "Reference mean square is too small" in caplog.text


def test_ave_covary2(haaqi_instance, audiogram):
"""Test ave covary2 method of the HAAQI_V1 class."""
haaqi_instance = haaqi_instance(
Expand Down Expand Up @@ -621,7 +663,7 @@ def test_ave_covary2(haaqi_instance, audiogram):
)


def test_ave_covary2_zero(haaqi_instance, audiogram):
def test_ave_covary2_zero(haaqi_instance, audiogram, caplog):
"""Test ave covary2 method of the HAAQI_V1 class."""
haaqi_instance = haaqi_instance(
num_bands=4,
Expand All @@ -635,34 +677,14 @@ def test_ave_covary2_zero(haaqi_instance, audiogram):
signal_cross_cov = np.random.random(size=(4, sig_len))
ref_mean_square = np.random.random(size=(4, sig_len))

ave_covariance, ihc_sync_covariance = haaqi_instance.ave_covary2(
haaqi_instance.ave_covary2(
signal_cross_covariance=signal_cross_cov,
reference_signal_mean_square=ref_mean_square,
lp_filter_order=np.array([1, 3, 5, 5, 5, 5]),
freq_cutoff=1000 * np.array([1.5, 2.0, 2.5, 3.0, 3.5, 4.0]),
)

assert len(ihc_sync_covariance) == 6

assert ave_covariance == pytest.approx(
0.0, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
)
assert np.sum(ihc_sync_covariance) == pytest.approx(
0.0, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
)

assert ihc_sync_covariance == pytest.approx(
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
rel=pytest.rel_tolerance,
abs=pytest.abs_tolerance,
)
assert "Ave signal below threshold, outputs set to 0." in caplog.text


def test_str_representation(haaqi_instance):
Expand Down

0 comments on commit 3c99697

Please sign in to comment.