From 0ac7da44034203e1f08951ca4d22840562d2fc3b Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 24 May 2023 16:47:59 +0100 Subject: [PATCH 01/49] correlate wrapping to use maxlag Signed-off-by: Gerardo Roa Dabike --- clarity/utils/signal_processing.py | 56 ++++++++++++++++++++++++++ tests/utils/test_signal_processing.py | 58 +++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/clarity/utils/signal_processing.py b/clarity/utils/signal_processing.py index 0e9078ee6..f2c0fb82b 100644 --- a/clarity/utils/signal_processing.py +++ b/clarity/utils/signal_processing.py @@ -1,4 +1,5 @@ """Signal processing utilities.""" +# pylint: disable=import-error from __future__ import annotations import numpy as np @@ -86,3 +87,58 @@ def resample( ) raise ValueError(f"Unknown resampling method: {method}") + + +def correlate( + x: np.ndarray, + y: np.ndarray, + mode="full", + method="auto", + lags: int | float | None = None, +) -> np.ndarray: + """ + Wrap of ``scipy.signal.correlate`` that includes a mode + for maxlag. + + This computes the same result as + numpy.correlate(x, y, mode='full')[len(a)-maxlag-1:len(a)+maxlag] + + Args: + x (np.ndarray): First signal + y (np.ndarray): Second signal + mode (str): Mode to pass to ``scipy.signal.correlate`` + method (str): + 'maxlag': Implement cross correlation with a maximum number of lags. + x and y must have the same length. + based on https://stackoverflow.com/questions/30677241/ + how-to-limit-cross-correlation-window-width-in-numpy + "auto": Run scipy.signal.correlate with method='auto' + 'direct': Run scipy.signal.correlate with method='direct' + 'fft': Run scipy.signal.correlate with method='fft' + lags (int): Maximum number of lags for `method` "maxlag". + Returns: + np.ndarray: cross correlation of x and y + """ + if method == "maxlag": + if lags is None: + raise ValueError("maxlag must be specified for method='maxlag'") + lags = int(lags) + + if x.shape[0] != y.shape[0]: + raise ValueError("x and y must have the same length") + + py = np.pad(y.conj(), 2 * lags, mode="constant") + # pylint: disable=unsubscriptable-object + T = np.lib.stride_tricks.as_strided( + py[2 * lags :], + shape=(2 * lags + 1, len(y) + 2 * lags), + strides=(-py.strides[0], py.strides[0]), + ) + px = np.pad(x, lags, mode="constant") + return T.dot(px) + + if method in ["auto", "direct", "fft"]: + # Run scipy signal correlate with the specified method and mode + return scipy.signal.correlate(x, y, mode=mode, method=method) + + raise ValueError(f"Unknown method: {method}") diff --git a/tests/utils/test_signal_processing.py b/tests/utils/test_signal_processing.py index 62a91946b..6f4506293 100644 --- a/tests/utils/test_signal_processing.py +++ b/tests/utils/test_signal_processing.py @@ -1,9 +1,11 @@ """Test for utils.signal_processing module""" +# pylint: disable=import-error import numpy as np import pytest from clarity.utils.signal_processing import ( compute_rms, + correlate, denormalize_signals, normalize_signal, resample, @@ -206,3 +208,59 @@ def test_resample_with_3d_array_error(): resample( signal=input_signal, sample_rate=16000, new_sample_rate=8000, method="soxr" ) + + +def test_correlate_maxlag(): + """Test the function correlate with maxlag""" + x = np.array([1, 2, 3, 4, 5]) + y = np.array([2, 4, 6, 8, 10]) + maxlag = 1 + expected_result = np.array([80, 110, 80]) + + result = correlate(x, y, lags=maxlag, method="maxlag") + print(result) + assert np.sum(result) == pytest.approx( + np.sum(expected_result), rel=pytest.rel_tolerance, 
abs=pytest.abs_tolerance + ) + + +def test_correlate_auto(): + """Test the function correlate with auto method""" + x = np.array([1, 2, 3, 4, 5]) + y = np.array([2, 4, 6]) + expected_result = np.array([6, 16, 28, 40, 52, 28, 10]) + result = correlate(x, y, method="auto") + + assert np.sum(result) == pytest.approx( + np.sum(expected_result), rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + ) + + +def test_correlate_unknown_method(): + """Test the function correlate with unknown method""" + x = np.array([1, 2, 3, 4, 5]) + y = np.array([2, 4, 6]) + unknown_method = "invalid_method" + + with pytest.raises(ValueError, match="Unknown method: invalid_method"): + correlate(x, y, method=unknown_method) + + +def test_correlate_missing_maxlag(): + """Test the function correlate with missing maxlag""" + x = np.array([1, 2, 3, 4, 5]) + y = np.array([2, 4, 6, 8, 10]) + + with pytest.raises( + ValueError, match="maxlag must be specified for method='maxlag'" + ): + correlate(x, y, method="maxlag") + + +def test_correlate_maxlag_different_length(): + """Test the function correlate with maxlag and different length""" + x = np.array([1, 2, 3, 4, 5]) + y = np.array([2, 4, 6]) + + with pytest.raises(ValueError, match="x and y must have the same length"): + correlate(x, y, lags=1, method="maxlag") From ccd8661f1e088f4481b164a8febc2d9df92f5f1f Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 24 May 2023 16:57:36 +0100 Subject: [PATCH 02/49] Replace correlate in HAAQI functions Signed-off-by: Gerardo Roa Dabike --- clarity/evaluator/haspi/eb.py | 64 ++++++++++++++++++------------ clarity/utils/signal_processing.py | 3 +- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/clarity/evaluator/haspi/eb.py b/clarity/evaluator/haspi/eb.py index 32692db92..a4b0761ab 100644 --- a/clarity/evaluator/haspi/eb.py +++ b/clarity/evaluator/haspi/eb.py @@ -11,7 +11,6 @@ butter, cheby2, convolve, - correlate, firwin, group_delay, lfilter, @@ -20,6 +19,7 @@ from clarity.enhancer.nalr import NALR from clarity.utils.audiogram import Audiogram +from clarity.utils.signal_processing import correlate if TYPE_CHECKING: from numpy import ndarray @@ -1798,23 +1798,27 @@ def bm_covary( window = np.hanning(nwin).conj().transpose() # Raised cosine von Hann window # compute inverted Window autocorrelation - win_corr = correlate(window, window, "full") - start_sample = int(len(window) - 1 - maxlag) - end_sample = int(maxlag + len(window)) - if start_sample < 0: - raise ValueError("segment size too small") - win_corr = 1 / win_corr[start_sample:end_sample] + # win_corr = correlate(window, window, "full") + # start_sample = int(len(window) - 1 - maxlag) + # end_sample = int(maxlag + len(window)) + # if start_sample < 0: + # raise ValueError("segment size too small") + # win_corr = 1 / win_corr[start_sample:end_sample] + win_corr = 1.0 / correlate(window, window, method="maxlag", lags=int(maxlag)) win_sum2 = 1.0 / np.sum(window**2) # Window power, inverted # The first segment has a half window nhalf = int(nwin / 2) half_window = window[nhalf:nwin] - half_corr = correlate(half_window, half_window, "full") - start_sample = int(len(half_window) - 1 - maxlag) - end_sample = int(maxlag + len(half_window)) - if start_sample < 0: - raise ValueError("segment size too small") - half_corr = 1 / half_corr[start_sample:end_sample] + # half_corr = correlate(half_window, half_window, "full") + # start_sample = int(len(half_window) - 1 - maxlag) + # end_sample = int(maxlag + len(half_window)) + # if start_sample < 0: + # raise 
ValueError("segment size too small") + # half_corr = 1 / half_corr[start_sample:end_sample] + half_corr = 1.0 / correlate( + half_window, half_window, method="maxlag", lags=int(maxlag) + ) halfsum2 = 1.0 / np.sum(half_window**2) # MS sum normalization, first segment # Number of segments @@ -1844,10 +1848,13 @@ def bm_covary( ref_mean_square = np.sum(reference_seg**2) * halfsum2 proc_mean_squared = np.sum(processed_seg**2) * halfsum2 - correlation = correlate(reference_seg, processed_seg, "full") - correlation = correlation[ - int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) - ] + # correlation = correlate(reference_seg, processed_seg, "full") + # correlation = correlation[ + # int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) + # ] + correlation = correlate( + reference_seg, processed_seg, method="maxlag", lags=int(maxlag) + ) unbiased_cross_correlation = np.max(np.abs(correlation * half_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalize cross-covariance @@ -1873,10 +1880,13 @@ def bm_covary( # Normalize signal MS value by the window ref_mean_square = np.sum(reference_seg**2) * win_sum2 proc_mean_squared = np.sum(processed_seg**2) * win_sum2 - correlation = correlate(reference_seg, processed_seg, "full") - correlation = correlation[ - int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) - ] + # correlation = correlate(reference_seg, processed_seg, "full") + # correlation = correlation[ + # int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) + # ] + correlation = correlate( + reference_seg, processed_seg, method="maxlag", lags=int(maxlag) + ) unbiased_cross_correlation = np.max(np.abs(correlation * win_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalize cross-covariance @@ -1900,11 +1910,13 @@ def bm_covary( ref_mean_square = np.sum(reference_seg**2) * halfsum2 proc_mean_squared = np.sum(processed_seg**2) * halfsum2 - correlation = np.correlate(reference_seg, processed_seg, "full") - correlation = correlation[ - int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) - ] - + # correlation = correlate(reference_seg, processed_seg, "full") + # correlation = correlation[ + # int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) + # ] + correlation = correlate( + reference_seg, processed_seg, method="maxlag", lags=int(maxlag) + ) unbiased_cross_correlation = np.max(np.abs(correlation * half_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalized cross-covariance diff --git a/clarity/utils/signal_processing.py b/clarity/utils/signal_processing.py index f2c0fb82b..cc2051bf5 100644 --- a/clarity/utils/signal_processing.py +++ b/clarity/utils/signal_processing.py @@ -109,7 +109,8 @@ def correlate( mode (str): Mode to pass to ``scipy.signal.correlate`` method (str): 'maxlag': Implement cross correlation with a maximum number of lags. - x and y must have the same length. + x and y must have the same length. + `mode` is ignored. 
based on https://stackoverflow.com/questions/30677241/ how-to-limit-cross-correlation-window-width-in-numpy "auto": Run scipy.signal.correlate with method='auto' From de9879634609e632267326100d7bd17fcc0a9358 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 24 May 2023 17:01:29 +0100 Subject: [PATCH 03/49] config task 1 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/config.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index 88de32bf8..b2de180de 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ -6,10 +6,12 @@ path: music_valid_file: ${path.metadata_dir}/musdb18.valid.json listeners_train_file: ${path.metadata_dir}/listeners.train.json listeners_valid_file: ${path.metadata_dir}/listeners.valid.json - exp_folder: ./exp # folder to store enhanced signals and final results + exp_folder: ./exp_${separator.model} # folder to store enhanced signals and final results -sample_rate: 44100 +sample_rate: 44100 # sample rate of the input mixture +stem_sample_rate: 24000 # sample rate output stems +remix_sample_rate: 32000 # sample rate for output remixed signal nalr: nfir: 220 @@ -27,7 +29,6 @@ soft_clip: True separator: model: demucs # demucs or openunmix - sources: [drums, bass, other, vocals] device: ~ evaluate: From b510fbd719df411ba693fe9d6113956691a28a5a Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 24 May 2023 17:06:13 +0100 Subject: [PATCH 04/49] adding flac encoder Signed-off-by: Gerardo Roa Dabike --- clarity/utils/flac_encoder.py | 262 +++++++++++++++++++++++++++++++ tests/utils/test_flac_encoder.py | 76 +++++++++ 2 files changed, 338 insertions(+) create mode 100644 clarity/utils/flac_encoder.py create mode 100644 tests/utils/test_flac_encoder.py diff --git a/clarity/utils/flac_encoder.py b/clarity/utils/flac_encoder.py new file mode 100644 index 000000000..55718a98c --- /dev/null +++ b/clarity/utils/flac_encoder.py @@ -0,0 +1,262 @@ +""" +Class for encoding and decoding audio signals + using flac compression. +""" +from __future__ import annotations + +import logging +import tempfile + +# pylint: disable=import-error, protected-access +from pathlib import Path + +import numpy as np +import pyflac as pf +import soundfile as sf + +logger = logging.getLogger(__name__) + + +class WavEncoder(pf.encoder._Encoder): + """ + Class offers an adaptation of the pyflac.encoder.FileEncoder + to work directly with WAV signals as input. + + """ + + def __init__( + self, + signal: np.ndarray, + sample_rate: int, + output_file: str | Path | None = None, + compression_level: int = 5, + blocksize: int = 0, + streamable_subset: bool = True, + verify: bool = False, + ) -> None: + """ + Initialise the encoder. + + Args: + signal (np.ndarray): The raw audio data to be encoded. + sample_rate (int): The sample rate of the audio data. + output_file (str | Path | None): Path to the output FLAC file, + a temporary file will be created if unspecified. + compression_level (int): The compression level parameter that + varies from 0 (fastest) to 8 (slowest). The default setting + is 5, see https://en.wikipedia.org/wiki/FLAC for more details. + blocksize (int): The size of the block to be returned in the + callback. The default is 0 which allows libFLAC to determine + the best block size. + streamable_subset (bool): Whether to use the streamable subset for encoding. 
+            If true the encoder will check settings for compatibility. If false, the
+            settings may take advantage of the full range that the format allows.
+        verify (bool): If `True`, the encoder will verify its own
+            encoded output by feeding it through an internal decoder and
+            comparing the original signal against the decoded signal.
+            If a mismatch occurs, the `process` method will raise an
+            `EncoderProcessException`. Note that this will slow the
+            encoding process by the extra time required for decoding and comparison.
+        """
+        super().__init__()
+
+        self.__raw_audio = signal
+        self._sample_rate = sample_rate
+
+        if output_file:
+            self.__output_file = (
+                Path(output_file) if isinstance(output_file, str) else output_file
+            )
+        else:
+            with tempfile.NamedTemporaryFile(suffix=".flac") as ofile:
+                self.__output_file = Path(ofile.name)
+
+        self._blocksize = blocksize
+        self._compression_level = compression_level
+        self._streamable_subset = streamable_subset
+        self._verify = verify
+        self._initialised = False
+
+    def _init(self):
+        """
+        Initialise the encoder to write to a file.
+
+        Raises:
+            EncoderInitException: if initialisation fails.
+        """
+        c_output_filename = pf.encoder._ffi.new(
+            "char[]", str(self.__output_file).encode("utf-8")
+        )
+        rc = pf.encoder._lib.FLAC__stream_encoder_init_file(
+            self._encoder,
+            c_output_filename,
+            pf.encoder._lib._progress_callback,
+            self._encoder_handle,
+        )
+        pf.encoder._ffi.release(c_output_filename)
+        if rc != pf.encoder._lib.FLAC__STREAM_ENCODER_INIT_STATUS_OK:
+            raise pf.EncoderInitException(rc)
+
+        self._initialised = True
+
+    def process(self) -> bytes:
+        """
+        Process the raw audio data.
+
+        Returns:
+            (bytes): The FLAC encoded bytes.
+
+        Raises:
+            EncoderProcessException: if an error occurs when processing the samples
+        """
+        super().process(self.__raw_audio)
+        self.finish()
+        with open(self.__output_file, "rb") as f:
+            return f.read()
+
+
+class FileDecoder(pf.decoder.FileDecoder):
+    def process(self) -> tuple[np.ndarray, int]:
+        """
+        Overridden version of the process method from the pyflac decoder.
+        The original process returns stereo signals in float64 format.
+
+        In this version, the data is returned using the original number
+        of channels and in int16 format.
+
+        Returns:
+            (tuple): A tuple of the decoded numpy audio array, and the sample rate
+                of the audio data.
+
+        Raises:
+            DecoderProcessException: if any fatal read, write, or memory allocation
+                error occurred (meaning decoding must stop)
+        """
+        result = pf.decoder._lib.FLAC__stream_decoder_process_until_end_of_stream(
+            self._decoder
+        )
+        if self.state != pf.decoder.DecoderState.END_OF_STREAM and not result:
+            raise pf.DecoderProcessException(str(self.state))
+
+        self.finish()
+        self.__output.close()
+        return sf.read(str(self.__output_file), always_2d=False, dtype="int16")
+
+
+class FlacEncoder:
+    """
+    Class for encoding and decoding audio signals using FLAC.
+
+    It uses the pyflac library to encode and decode the audio data,
+    and offers convenient methods for encoding and decoding audio data.
+    """
+
+    def __init__(self, compression_level: int = 5) -> None:
+        """
+        Initialise the encoder.
+
+        Args:
+            compression_level (int): The compression level parameter that
+                varies from 0 (fastest) to 8 (slowest). The default setting
+                is 5, see https://en.wikipedia.org/wiki/FLAC for more details.
+ """ + self.compression_level = compression_level + + def encode( + self, + signal: np.ndarray, + sample_rate: int, + output_file: str | Path | None = None, + ) -> bytes: + """ + Method to encode the audio data using FLAC compressor. + + It creates a WavEncoder object and uses it to encode the audio data. + + Args: + signal (np.ndarray): The raw audio data to be compressed. + sample_rate (int): The sample rate of the audio data. + output_file (str | Path): Path to where to + save the output FLAC file. If not specified, a temporary file + will be created. + + Returns: + (bytes): The FLAC encoded audio signal. + + Raises: + ValueError: If the audio signal is not in `np.int16` format. + """ + if signal.dtype != np.int16: + logger.error( + f"FLAC encoder only supports 16-bit integer signals, " + f"but got {signal.dtype}" + ) + raise ValueError( + f"FLAC encoder only supports 16-bit integer signals, " + f"but got {signal.dtype}" + ) + + wav_encoder = WavEncoder( + signal=signal, + sample_rate=sample_rate, + compression_level=self.compression_level, + output_file=output_file, + ) + return wav_encoder.process() + + @staticmethod + def decode(input_filename: Path | str) -> tuple[np.ndarray, float]: + """ + Method to decode a flac file to wav audio data. + + It uses the pyflac library to decode the flac file. + + Args: + input_filename (pathlib.Path | str): Path to the input FLAC file. + + Returns: + (np.ndarray): The raw audio data. + + Raises: + FileNotFoundError: If the flac file to decode does not exist. + """ + input_filename = ( + Path(input_filename) if isinstance(input_filename, str) else input_filename + ) + + if not input_filename.exists(): + logger.error(f"File {input_filename} not found.") + raise FileNotFoundError(f"File {input_filename} not found.") + + decoder = FileDecoder(input_filename) + signal, sample_rate = decoder.process() + + return signal, float(sample_rate) + + +def read_flac_signal(filename: Path) -> tuple[np.ndarray, float]: + """Read a FLAC signal and return it as a numpy array + + Args: + filename (Path): The path to the FLAC file to read. + + Returns: + signal (np.ndarray): The decoded signal. + sample_rate (float): The sample rate of the signal. 
+ """ + # Create encoder object + flac_encoder = FlacEncoder() + + # Decode FLAC file + signal, sample_rate = flac_encoder.decode( + filename, + ) + signal = (signal / 32768.0).astype(np.float32) + + # Load scale factor + if filename.with_suffix(".txt").exists(): + with open(filename.with_suffix(".txt"), encoding="utf-8") as fp: + max_value = float(fp.read()) + # Scale signal + signal *= max_value + return signal, sample_rate diff --git a/tests/utils/test_flac_encoder.py b/tests/utils/test_flac_encoder.py new file mode 100644 index 000000000..85aa1e8f0 --- /dev/null +++ b/tests/utils/test_flac_encoder.py @@ -0,0 +1,76 @@ +"""Tests for the FlacEncoder class.""" +# pylint: disable=import-error +from pathlib import Path +from tempfile import NamedTemporaryFile + +import numpy as np +import pytest + +from clarity.utils.flac_encoder import FlacEncoder, read_flac_signal + + +def test_encode_decode(): + """Test that the FlacEncoder can encode and decode a signal.""" + # create a random signal + np.random.seed(0) + + sample_rate = 8000 + signal = np.random.uniform(-1, 1, int(0.5 * sample_rate)) + signal_int16 = signal * 32768.0 + signal_int16 = np.clip(signal_int16, -32768, 32767).astype(np.int16) + + # write the encoded bytes to a temporary file + with NamedTemporaryFile(suffix=".flac", delete=False) as tmpfile: + encoder = FlacEncoder() + # encode + _ = encoder.encode( + signal_int16, sample_rate=sample_rate, output_file=tmpfile.name + ) + # decode + decoded_signal, decoded_sr = encoder.decode(Path(tmpfile.name)) + + # check that the decoded signal matches the original signal + assert np.sum(signal_int16) == pytest.approx( + np.sum(decoded_signal), + rel=pytest.rel_tolerance, + abs=pytest.abs_tolerance, + ) + + # check that the sample rate of the decoded signal is correct + assert decoded_sr == sample_rate + + +def test_read_flac_signal(tmp_path): + # Create a test FLAC file + np.random.seed(2023) + + filename = tmp_path / "test.flac" + + sample_rate = 16000 + signal = np.random.rand(1600) + scale_factor = np.max(np.abs(signal)) + + signal_scaled = signal / scale_factor + signal_scaled = signal_scaled * 32768.0 + signal_scaled = np.clip(signal_scaled, -32768.0, 32767.0) + signal_scaled = signal_scaled.astype(np.dtype("int16")) + + flac_encoder = FlacEncoder() + flac_encoder.encode(signal_scaled, sample_rate, filename) + + # Create a test scale factor file + scale_filename = tmp_path / "test.txt" + with open(scale_filename, "w", encoding="utf-8") as fp: + fp.write(str(scale_factor)) + + # Call the function and check the output + signal_out, sample_rate_out = read_flac_signal(filename) + + # As a result of the quantization, the signal is not exactly the same + # after encoding and decoding, so I'm changing the tolerance + # for this test + # np.sum(signal_out) = 2190.4271092907347 + # np.sum(signal) = 2190.494140495932 + + assert np.sum(signal_out) == pytest.approx(np.sum(signal), rel=1e-4, abs=1e-4) + assert sample_rate_out == sample_rate From 2674ec7c9826fdab0d2bd7a9c1d035e78ed7dde0 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 24 May 2023 17:09:13 +0100 Subject: [PATCH 05/49] new function signal processing Signed-off-by: Gerardo Roa Dabike --- clarity/utils/signal_processing.py | 133 ++++++++++++++++---------- tests/utils/test_signal_processing.py | 37 +++++++ 2 files changed, 120 insertions(+), 50 deletions(-) diff --git a/clarity/utils/signal_processing.py b/clarity/utils/signal_processing.py index cc2051bf5..a9dd9f995 100644 --- a/clarity/utils/signal_processing.py +++ 
b/clarity/utils/signal_processing.py @@ -8,6 +8,25 @@ from numpy import ndarray +def clip_signal(signal: np.ndarray, soft_clip: bool = False) -> tuple[np.ndarray, int]: + """Clip the signal. + + Args: + signal (np.ndarray): Signal to be clipped and saved. + soft_clip (bool): Whether to use soft clipping. + + Returns: + signal (np.ndarray): Clipped signal. + n_clipped (int): Number of samples clipped. + """ + + if soft_clip: + signal = np.tanh(signal) + n_clipped = np.sum(np.abs(signal) > 1.0) + signal = np.clip(signal, -1.0, 1.0) + return signal, int(n_clipped) + + def compute_rms(signal: ndarray) -> float: """Compute RMS of signal @@ -21,6 +40,62 @@ def compute_rms(signal: ndarray) -> float: return np.sqrt(np.mean(np.square(signal))) +def correlate( + x: np.ndarray, + y: np.ndarray, + mode="full", + method="auto", + lags: int | float | None = None, +) -> np.ndarray: + """ + Wrap of ``scipy.signal.correlate`` that includes a mode + for maxlag. + + This computes the same result as + numpy.correlate(x, y, mode='full')[len(a)-maxlag-1:len(a)+maxlag] + + Args: + x (np.ndarray): First signal + y (np.ndarray): Second signal + mode (str): Mode to pass to ``scipy.signal.correlate`` + method (str): + 'maxlag': Implement cross correlation with a maximum number of lags. + x and y must have the same length. + `mode` is ignored. + based on https://stackoverflow.com/questions/30677241/ + how-to-limit-cross-correlation-window-width-in-numpy + "auto": Run scipy.signal.correlate with method='auto' + 'direct': Run scipy.signal.correlate with method='direct' + 'fft': Run scipy.signal.correlate with method='fft' + lags (int): Maximum number of lags for `method` "maxlag". + Returns: + np.ndarray: cross correlation of x and y + """ + if method == "maxlag": + if lags is None: + raise ValueError("maxlag must be specified for method='maxlag'") + lags = int(lags) + + if x.shape[0] != y.shape[0]: + raise ValueError("x and y must have the same length") + + py = np.pad(y.conj(), 2 * lags, mode="constant") + # pylint: disable=unsubscriptable-object + T = np.lib.stride_tricks.as_strided( + py[2 * lags :], + shape=(2 * lags + 1, len(y) + 2 * lags), + strides=(-py.strides[0], py.strides[0]), + ) + px = np.pad(x, lags, mode="constant") + return T.dot(px) + + if method in ["auto", "direct", "fft"]: + # Run scipy signal correlate with the specified method and mode + return scipy.signal.correlate(x, y, mode=mode, method=method) + + raise ValueError(f"Unknown method: {method}") + + def denormalize_signals(sources: ndarray, ref: ndarray) -> ndarray: """Scale signals back to the original scale. @@ -89,57 +164,15 @@ def resample( raise ValueError(f"Unknown resampling method: {method}") -def correlate( - x: np.ndarray, - y: np.ndarray, - mode="full", - method="auto", - lags: int | float | None = None, -) -> np.ndarray: - """ - Wrap of ``scipy.signal.correlate`` that includes a mode - for maxlag. - - This computes the same result as - numpy.correlate(x, y, mode='full')[len(a)-maxlag-1:len(a)+maxlag] +def to_16bit(signal: np.ndarray) -> np.ndarray: + """Convert the signal to 16 bit. Args: - x (np.ndarray): First signal - y (np.ndarray): Second signal - mode (str): Mode to pass to ``scipy.signal.correlate`` - method (str): - 'maxlag': Implement cross correlation with a maximum number of lags. - x and y must have the same length. - `mode` is ignored. 
- based on https://stackoverflow.com/questions/30677241/ - how-to-limit-cross-correlation-window-width-in-numpy - "auto": Run scipy.signal.correlate with method='auto' - 'direct': Run scipy.signal.correlate with method='direct' - 'fft': Run scipy.signal.correlate with method='fft' - lags (int): Maximum number of lags for `method` "maxlag". + signal (np.ndarray): Signal to be converted. + Returns: - np.ndarray: cross correlation of x and y + signal (np.ndarray): Converted signal. """ - if method == "maxlag": - if lags is None: - raise ValueError("maxlag must be specified for method='maxlag'") - lags = int(lags) - - if x.shape[0] != y.shape[0]: - raise ValueError("x and y must have the same length") - - py = np.pad(y.conj(), 2 * lags, mode="constant") - # pylint: disable=unsubscriptable-object - T = np.lib.stride_tricks.as_strided( - py[2 * lags :], - shape=(2 * lags + 1, len(y) + 2 * lags), - strides=(-py.strides[0], py.strides[0]), - ) - px = np.pad(x, lags, mode="constant") - return T.dot(px) - - if method in ["auto", "direct", "fft"]: - # Run scipy signal correlate with the specified method and mode - return scipy.signal.correlate(x, y, mode=mode, method=method) - - raise ValueError(f"Unknown method: {method}") + signal = signal * 32768.0 + signal = np.clip(signal, -32768.0, 32767.0) + return signal.astype(np.dtype("int16")) diff --git a/tests/utils/test_signal_processing.py b/tests/utils/test_signal_processing.py index 6f4506293..2e815c771 100644 --- a/tests/utils/test_signal_processing.py +++ b/tests/utils/test_signal_processing.py @@ -4,11 +4,13 @@ import pytest from clarity.utils.signal_processing import ( + clip_signal, compute_rms, correlate, denormalize_signals, normalize_signal, resample, + to_16bit, ) @@ -140,6 +142,41 @@ def test_compute_rms(): ) +@pytest.mark.parametrize( + "signal,soft_clip,expected_output", + [ + ( + np.array([0.5, 2.0, -1.5, 0.8]), + True, + (np.array([0.46211716, 0.96402758, -0.90514825, 0.66403677]), 0), + ), + (np.array([0.5, 2.0, -1.5, 0.8]), False, (np.array([0.5, 1.0, -1.0, 0.8]), 2)), + ], +) +def test_clip_signal(signal, soft_clip, expected_output): + """Test the clip_signal function""" + # Test with soft clip + output = clip_signal(signal, soft_clip=soft_clip) + assert np.allclose(output[0], expected_output[0]) + assert output[1] == expected_output[1] + + +@pytest.mark.parametrize( + "signal,expected_output", + [ + (np.array([0.5, 0.8, 0.2, 1.0]), np.array([16384, 26214, 6553, 32767])), + (np.array([-0.5, -0.8, -0.2, -1.0]), np.array([-16384, -26214, -6553, -32768])), + (np.array([0.5, -0.8, 0.2, -1.0]), np.array([16384, -26214, 6553, -32768])), + ], +) +def test_to_16bit(signal, expected_output): + """Test the to_16bit function""" + # Test with positive signal + output = to_16bit(signal) + print(output) + assert np.allclose(output, expected_output) + + @pytest.mark.parametrize( "input_sample_rate, input_shape, output_sample_rate, output_shape", [ From 509ad17d7c7e18d1d6bb7e8bfbef417a8a9e9265 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 24 May 2023 17:33:19 +0100 Subject: [PATCH 06/49] Update enhance task 1 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/config.yaml | 2 +- recipes/cad1/task1/baseline/enhance.py | 225 ++++++++++++++++-------- 2 files changed, 152 insertions(+), 75 deletions(-) diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index b2de180de..691a01703 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ 
-15,7 +15,7 @@ remix_sample_rate: 32000 # sample rate for output remixed signal nalr: nfir: 220 - fs: ${sample_rate} + sample_rate: ${sample_rate} apply_compressor: False compressor: diff --git a/recipes/cad1/task1/baseline/enhance.py b/recipes/cad1/task1/baseline/enhance.py index 1daf24a92..724584592 100644 --- a/recipes/cad1/task1/baseline/enhance.py +++ b/recipes/cad1/task1/baseline/enhance.py @@ -15,12 +15,19 @@ from omegaconf import DictConfig from scipy.io import wavfile from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB -from torchaudio.transforms import Fade, Resample +from torchaudio.transforms import Fade from clarity.enhancer.compressor import Compressor from clarity.enhancer.nalr import NALR from clarity.utils.audiogram import Audiogram, Listener -from clarity.utils.signal_processing import denormalize_signals, normalize_signal +from clarity.utils.flac_encoder import FlacEncoder +from clarity.utils.signal_processing import ( + clip_signal, + denormalize_signals, + normalize_signal, + resample, + to_16bit, +) from recipes.cad1.task1.baseline.evaluate import make_song_listener_list logger = logging.getLogger(__name__) @@ -28,7 +35,7 @@ def separate_sources( model: torch.nn.Module, - mix: torch.Tensor, + mix: torch.Tensor | ndarray, sample_rate: int, segment: float = 10.0, overlap: float = 0.1, @@ -119,11 +126,11 @@ def get_device(device: str) -> tuple: raise ValueError(f"Unsupported device type: {device}") -def map_to_dict(sources: np.ndarray, sources_list: list[str]) -> dict: +def map_to_dict(sources: ndarray, sources_list: list[str]) -> dict: """Map sources to a dictionary separating audio into left and right channels. Args: - sources (np.ndarray): Signal to be mapped to dictionary. + sources (ndarray): Signal to be mapped to dictionary. sources_list (list): List of strings used to index dictionary. Returns: @@ -142,13 +149,15 @@ def map_to_dict(sources: np.ndarray, sources_list: list[str]) -> dict: # pylint: disable=unused-argument def decompose_signal( - config: DictConfig, model: torch.nn.Module, - signal: np.ndarray, - sample_rate: int, + model_sample_rate: int, + signal: ndarray, + signal_sample_rate: int, device: torch.device, + sources_list: list[str], listener: Listener, -) -> dict[str, np.ndarray]: + normalise: bool = True, +) -> dict[str, ndarray]: """ Decompose signal into 8 stems. @@ -158,47 +167,46 @@ def decompose_signal( HDEMUCS model trained on the MUSDB18 dataset. Args: - config (DictConfig): Configuration object. model (torch.nn.Module): Torch model. - signal (np.ndarray): Signal to be decomposed. - sample_rate (int): Sample frequency. + model_sample_rate (int): Sample rate of the model. + signal (ndarray): Signal to be decomposed. + signal_sample_rate (int): Sample frequency. device (torch.device): Torch device to use for processing. - listener (Listener). + sources_list (list): List of strings used to index dictionary. + listener (Listener): Listener object. + normalise (bool): Whether to normalise the signal. Returns: Dictionary: Indexed by sources with the associated model as values. 
""" - if config.separator.model == "demucs": - signal, ref = normalize_signal(signal) - - model_sample_rate = ( - model.sample_rate if config.separator.model == "openunmix" else 44100 - ) + # Resample mixture signal to model sample rate + if signal_sample_rate != model_sample_rate: + signal = resample(signal, signal_sample_rate, model_sample_rate) - if sample_rate != model_sample_rate: - resampler = Resample(sample_rate, model_sample_rate) - signal = resampler(signal) + if normalise: + signal, ref = normalize_signal(signal) sources = separate_sources( - model, torch.from_numpy(signal), sample_rate, device=device + model, torch.from_numpy(signal), signal_sample_rate, device=device ) # only one element in the batch sources = sources[0] - if config.separator.model == "demucs": + + if normalise: sources = denormalize_signals(sources, ref) - signal_stems = map_to_dict(sources, config.separator.sources) + signal_stems = map_to_dict(sources, sources_list) return signal_stems def apply_baseline_ha( enhancer: NALR, compressor: Compressor, - signal: np.ndarray, + signal: ndarray, audiogram: Audiogram, apply_compressor: bool = False, -) -> np.ndarray: +) -> ndarray: """ Apply NAL-R prescription hearing aid to a signal. @@ -206,14 +214,12 @@ def apply_baseline_ha( enhancer: A NALR object that enhances the signal. compressor: A Compressor object that compresses the signal. signal: An ndarray representing the audio signal. - listener_audiogram: An ndarray representing the listener's audiogram. - cfs: An ndarray of center frequencies. + audiogram: An Audiogram object representing the listener's audiogram. apply_compressor: A boolean indicating whether to include the compressor. Returns: An ndarray representing the processed signal. """ - print("XXX", audiogram) nalr_fir, _ = enhancer.build(audiogram) proc_signal = enhancer.apply(nalr_fir, signal) if apply_compressor: @@ -234,8 +240,7 @@ def process_stems_for_listener( stems (dict) : Dictionary of stems enhancer (NALR) : NAL-R prescription hearing aid compressor (Compressor) : Compressor - listener: Listener object - cfs (np.ndarray) : Center frequencies + listener (Listener) : Listener object. apply_compressor (bool) : Whether to apply the compressor Returns: processed_sources (dict) : Dictionary of processed stems @@ -261,27 +266,84 @@ def process_stems_for_listener( return processed_stems -def clip_signal(signal: np.ndarray, soft_clip: bool = False) -> tuple[np.ndarray, int]: - """Clip and save the processed stems. +def remix_signal(stems: dict) -> ndarray: + """ + Function to remix signal. It takes the eight stems + and combines them into a stereo signal. Args: - signal (np.ndarray): Signal to be clipped and saved. - soft_clip (bool): Whether to use soft clipping. + stems (dict) : Dictionary of stems Returns: - signal (np.ndarray): Clipped signal. - n_clipped (int): Number of samples clipped. + (ndarray) : Remixed signal + + """ + n_samples = stems[list(stems.keys())[0]].shape[0] + out_left, out_right = np.zeros(n_samples), np.zeros(n_samples) + for stem_str, stem_signal in stems.items(): + if stem_str.startswith("l"): + out_left += stem_signal + else: + out_right += stem_signal + + return np.stack([out_left, out_right], axis=1) + + +def save_flac_signal( + signal: ndarray, + filename: Path, + signal_sample_rate: int, + output_sample_rate: int, + do_clip_signal: bool = False, + do_soft_clip: bool = False, + do_scale_signal: bool = False, +) -> None: """ + Function to save output signals. 
+
+    - The output signal will be resampled to ``output_sample_rate``.
+    - The output signal will be clipped to [-1, 1] if ``do_clip_signal`` is True,
+        and soft clipping will be used if ``do_soft_clip`` is True. Note that if
+        ``do_clip_signal`` is False, ``do_soft_clip`` will be ignored.
+    - The output signal will be scaled to [-1, 1] if ``do_scale_signal`` is True.
+        If the signal is scaled, the scale factor will be saved in a TXT file.
+        Note that if ``do_clip_signal`` is True, ``do_scale_signal`` will be ignored.
+    - The output signal will be saved as a FLAC file.
+
+    Args:
+        signal (np.ndarray) : Signal to save
+        filename (Path) : Path to save signal
+        signal_sample_rate (int) : Sample rate of the input signal
+        output_sample_rate (int) : Sample rate of the output signal
+        do_clip_signal (bool) : Whether to clip signal
+        do_soft_clip (bool) : Whether to apply soft clipping
+        do_scale_signal (bool) : Whether to scale signal
+    """
+    # Resample signal to expected output sample rate
+    if signal_sample_rate != output_sample_rate:
+        signal = resample(signal, signal_sample_rate, output_sample_rate)
+
+    if do_scale_signal:
+        # Scale stem signal
+        max_value = np.max(np.abs(signal))
+        signal = signal / max_value
+
+        # Save scale factor
+        with open(filename.with_suffix(".txt"), "w", encoding="utf-8") as file:
+            file.write(f"{max_value}")
+
+    elif do_clip_signal:
+        # Clip the signal
+        signal, n_clipped = clip_signal(signal, do_soft_clip)
+        if n_clipped > 0:
+            logger.warning(f"Writing {filename}: {n_clipped} samples clipped")
+
+    # Convert signal to 16-bit integer
+    signal = to_16bit(signal)
+
+    # Create flac encoder object to compress and save the signal
+    FlacEncoder().encode(signal, output_sample_rate, filename)
 
 
 @hydra.main(config_path="", config_name="config")
@@ -298,6 +360,9 @@ def enhance(config: DictConfig) -> None:
 
     - right channel vocal, drums, bass, and other stems
     """
+    if config.separator.model not in ["demucs", "openunmix"]:
+        raise ValueError(f"Separator model {config.separator.model} not supported.")
+
     enhanced_folder = Path("enhanced_signals")
     enhanced_folder.mkdir(parents=True, exist_ok=True)
@@ -321,8 +386,17 @@
 
     if config.separator.model == "demucs":
         separation_model = HDEMUCS_HIGH_MUSDB.get_model()
+        model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate
+        sources_order = separation_model.sources
+        normalise = True
-    else:
+    elif config.separator.model == "openunmix":
         separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0)
+        model_sample_rate = separation_model.sample_rate
+        sources_order = ["vocals", "drums", "bass", "other"]
+        normalise = False
+    else:
+        raise ValueError(f"Separator model {config.separator.model} not supported.")
+
     device, _ = get_device(config.separator.device)
     separation_model.to(device)
@@ -366,9 +440,10 @@
         else "train"
     )
 
-    # Read the mixture signal
-    # Convert to 32-bit floating point and transpose
-    # from [samples, channels] to [channels, samples]
+    # Baseline Steps
+    # 1.
Decompose the mixture signal into vocal, drums, bass, and other stems + # We validate if 2 consecutive signals are the same to avoid + # decomposing the same song multiple times if prev_song_name != song_name: # Decompose song only once prev_song_name = song_name @@ -383,17 +458,20 @@ def enhance(config: DictConfig) -> None: assert sample_rate == config.sample_rate stems: dict[str, ndarray] = decompose_signal( - config, separation_model, + model_sample_rate, mixture_signal, sample_rate, device, + sources_order, listener, + normalise, ) - # Baseline applies NALR prescription to each stem instead of using the - # listener's audiograms in the decomposition. This stem can be skipped - # if the listener's audiograms are used in the decomposition + # 2. Apply NAL-R prescription to each stem + # Baseline applies NALR prescription to each stem instead of using the + # listener's audiograms in the decomposition. This step can be skipped + # if the listener's audiograms are used in the decomposition processed_stems = process_stems_for_listener( stems, enhancer, @@ -402,15 +480,8 @@ def enhance(config: DictConfig) -> None: config.apply_compressor, ) - # save processed stems - n_samples = processed_stems[list(processed_stems.keys())[0]].shape[0] - output_left, output_right = np.zeros(n_samples), np.zeros(n_samples) + # 3. Save processed stems for stem_str, stem_signal in processed_stems.items(): - if stem_str.startswith("l"): - output_left += stem_signal - else: - output_right += stem_signal - filename = ( enhanced_folder / f"{listener.id}" @@ -418,26 +489,32 @@ def enhance(config: DictConfig) -> None: / f"{listener.id}_{song_name}_{stem_str}.wav" ) filename.parent.mkdir(parents=True, exist_ok=True) + save_flac_signal( + signal=stem_signal, + filename=filename, + signal_sample_rate=config.sample_rate, + output_sample_rate=config.stem_sample_rate, + do_scale_signal=True, + ) - # Clip and save stem signals - clipped_signal, n_clipped = clip_signal(stem_signal, config.soft_clip) - if n_clipped > 0: - logger.warning(f"Writing {filename}: {n_clipped} samples clipped") - wavfile.write(filename, config.sample_rate, to_16bit(clipped_signal)) + # 4. Remix Signal + enhanced = remix_signal(processed_stems) - enhanced = np.stack([output_left, output_right], axis=1) + # 5. 
Save enhanced (remixed) signal
     filename = (
         enhanced_folder
         / f"{listener.id}"
         / f"{song_name}"
-        / f"{listener.id}_{song_name}_remix.wav"
+        / f"{listener.id}_{song_name}_remix.flac"
+    )
+    save_flac_signal(
+        signal=enhanced,
+        filename=filename,
+        signal_sample_rate=config.sample_rate,
+        output_sample_rate=config.remix_sample_rate,
+        do_clip_signal=True,
+        do_soft_clip=config.soft_clip,
     )
-
-    # clip and save enhanced signal
-    clipped_signal, n_clipped = clip_signal(enhanced, config.soft_clip)
-    if n_clipped > 0:
-        logger.warning(f"Writing {filename}: {n_clipped} samples clipped")
-    wavfile.write(filename, config.sample_rate, to_16bit(clipped_signal))
 
 
 # pylint: disable = no-value-for-parameter

From 3e50c11ebe97f1e6776ab14772d523115c2211c0 Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Thu, 25 May 2023 14:19:08 +0100
Subject: [PATCH 07/49] Update test for enhance task 1

Signed-off-by: Gerardo Roa Dabike

---
 .../cad1/task1/baseline/test_enhance_task1.py | 34 ++++++++-----------
 1 file changed, 13 insertions(+), 21 deletions(-)

diff --git a/tests/recipes/cad1/task1/baseline/test_enhance_task1.py b/tests/recipes/cad1/task1/baseline/test_enhance_task1.py
index b8cf8e8db..7a5f4e592 100644
--- a/tests/recipes/cad1/task1/baseline/test_enhance_task1.py
+++ b/tests/recipes/cad1/task1/baseline/test_enhance_task1.py
@@ -1,10 +1,10 @@
 """Tests for the enhance module"""
+# pylint: disable=import-error
 from pathlib import Path
 
 import numpy as np
 import pytest
 import torch
-from omegaconf import DictConfig
 from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
 
 from clarity.enhancer.compressor import Compressor
@@ -44,16 +44,16 @@ def test_map_to_dict():
 
 
 @pytest.mark.parametrize(
-    "separation_model",
+    "separation_model,normalise",
     [
-        pytest.param("demucs"),
-        pytest.param("openunmix", marks=pytest.mark.slow),
+        pytest.param("demucs", True),
+        pytest.param("openunmix", True, marks=pytest.mark.slow),
     ],
 )
-def test_decompose_signal(separation_model):
+def test_decompose_signal(separation_model, normalise):
     """Takes a signal and decomposes it into VDBO sources using the HDEMUCS model"""
     np.random.seed(123456789)
     # Load Separation Model
     if separation_model == "demucs":
         model = HDEMUCS_HIGH_MUSDB.get_model().double()
     elif separation_model == "openunmix":
@@ -67,27 +68,19 @@
     duration = 0.5
     signal = np.random.uniform(size=(1, 2, int(sample_rate * duration)))
 
-    # config
-    config = DictConfig(
-        {
-            "sample_rate": sample_rate,
-            "separator": {
-                "model": "demucs",
-                "sources": ["drums", "bass", "other", "vocals"],
-            },
-        }
-    )
     # Call the decompose_signal function and check that the output has the expected keys
     cfs = np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000])
     audiogram = Audiogram(levels=np.ones(9), frequencies=cfs)
     listener = Listener(audiogram, audiogram)
     output = decompose_signal(
-        config,
-        model,
-        signal,
-        sample_rate,
-        device,
-        listener,
+        model=model,
+        model_sample_rate=sample_rate,
+        signal=signal,
+        signal_sample_rate=sample_rate,
+        device=device,
+        sources_list=["drums", "bass", "other", "vocals"],
+        listener=listener,
+        normalise=normalise,
     )
     expected_results = np.load(
         RESOURCES / f"test_enhance.test_decompose_signal_{separation_model}.npy",

From 7ca84cec716e49e1712a3674f5d4fe945374942f Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Thu, 25 May 2023 14:36:59 +0100
Subject: [PATCH 08/49] Update evaluate task 1

Signed-off-by: Gerardo Roa Dabike

---
clarity/evaluator/haaqi/haaqi.py | 14 +++--- recipes/cad1/task1/baseline/evaluate.py | 66 ++++++++++++++----------- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/clarity/evaluator/haaqi/haaqi.py b/clarity/evaluator/haaqi/haaqi.py index f4d19662c..a042bfdfa 100644 --- a/clarity/evaluator/haaqi/haaqi.py +++ b/clarity/evaluator/haaqi/haaqi.py @@ -49,8 +49,7 @@ def haaqi_v1( processed (np.ndarray): Output signal with noise, distortion, HA gain, and/or processing. processed_freq (int): Sampling rate in Hz for processed signal. - hearing_loss (np.ndarray): (1,6) vector of hearing loss at the 6 audiometric - frequencies [250, 500, 1000, 2000, 4000, 6000] Hz. + audiogram (Audiogram): Audiogram object with hearing loss levels. equalisation (int): hearing loss equalization mode for reference signal: 1 = no EQ has been provided, the function will add NAL-R 2 = NAL-R EQ has already been added to the reference signal @@ -178,8 +177,9 @@ def haaqi_v1( def compute_haaqi( processed_signal: ndarray, reference_signal: ndarray, + sample_rate_processed: float, + sample_rate_reference: float, audiogram: Audiogram, - sample_rate: float, equalisation: int = 1, level1: float = 65.0, ) -> float: @@ -191,8 +191,10 @@ def compute_haaqi( reference_signal (np.ndarray): Input reference speech signal with no noise or distortion. If a hearing loss is specified, NAL-R equalization is optional + sample_rate_processed (int): Sample rate of processed signal + sample_rate_reference (int): Sample rate of reference signal + audiogram (Audiogram): Audiogram object. - sample_rate (int): Sample rate in Hz. equalisation (int): hearing loss equalization mode for reference signal: 1 = no EQ has been provided, the function will add NAL-R 2 = NAL-R EQ has already been added to the reference signal @@ -211,9 +213,9 @@ def compute_haaqi( score, _, _, _ = haaqi_v1( reference=reference_signal, - reference_freq=sample_rate, + reference_freq=sample_rate_reference, processed=processed_signal, - processed_freq=sample_rate, + processed_freq=sample_rate_processed, audiogram=audiogram, equalisation=equalisation, level1=level1, diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 469513bc6..13f9003f0 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -1,7 +1,6 @@ """Evaluate the enhanced signals using the HAAQI metric.""" from __future__ import annotations -# pylint: disable=too-many-locals # pylint: disable=import-error import csv import hashlib @@ -18,7 +17,11 @@ from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener -from clarity.utils.signal_processing import compute_rms +from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.signal_processing import compute_rms, resample + +# pylint: disable=too-many-locals + logger = logging.getLogger(__name__) @@ -129,7 +132,7 @@ def _evaluate_song_listener( Args: song (str): The name of the song to evaluate. - listener (str): The name of the listener to evaluate. + listener (Listener): The listener to evaluate the song for. config (DictConfig): The configuration object. split_dir (str): The name of the split directory. enhanced_folder (Path): The path to the folder containing the enhanced signals. 
@@ -160,46 +163,53 @@ def _evaluate_song_listener( Path(config.path.music_dir) / split_dir / song / f"{instrument}.wav" ) reference_signal = (reference_signal / 32768.0).astype(np.float32) + reference_signal = resample( + reference_signal, sample_rate_reference_signal, config.stem_sample_rate + ) - # Load enhanced instrument signals - # Load left channel - sample_rate_left_enhanced_signal, left_enhanced_signal = wavfile.read( + # Read left instrument enhanced + left_enhanced_signal, sample_rate_left_enhanced_signal = read_flac_signal( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_left_{instrument}.wav" + / f"{listener.id}_{song}_left_{instrument}.flac" ) - left_enhanced_signal = (left_enhanced_signal / 32768.0).astype(np.float32) - # Load right channel - sample_rate_right_enhanced_signal, right_enhanced_signal = wavfile.read( + # Read right instrument enhanced + right_enhanced_signal, sample_rate_right_enhanced_signal = read_flac_signal( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_right_{instrument}.wav" + / f"{listener.id}_{song}_right_{instrument}.flac" ) - right_enhanced_signal = (right_enhanced_signal / 32768.0).astype(np.float32) - assert ( - sample_rate_reference_signal - == sample_rate_left_enhanced_signal - == sample_rate_right_enhanced_signal - == config.sample_rate - ) + if sample_rate_left_enhanced_signal != sample_rate_right_enhanced_signal: + raise ValueError( + "The sample rates of the left and right enhanced signals are not " + "the same" + ) + + if sample_rate_reference_signal != config.sample_rate: + raise ValueError( + f"The sample rate of the reference signal is not {config.sample_rate}" + ) + # Compute left and right scores per_instrument_score[f"left_{instrument}"] = compute_haaqi( - left_enhanced_signal, - reference_signal[:, 0], - listener.audiogram_left, - config.sample_rate, - 65 - 20 * np.log10(compute_rms(reference_signal[:, 0])), + processed_signal=left_enhanced_signal, + reference_signal=reference_signal[:, 0], + sample_rate_processed=int(sample_rate_left_enhanced_signal), + sample_rate_reference=config.stem_sample_rate, + audiogram=listener.audiogram_left, + level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 0])), ) per_instrument_score[f"right_{instrument}"] = compute_haaqi( - right_enhanced_signal, - reference_signal[:, 1], - listener.audiogram_right, - config.sample_rate, - 65 - 20 * np.log10(compute_rms(reference_signal[:, 1])), + processed_signal=right_enhanced_signal, + reference_signal=reference_signal[:, 1], + sample_rate_processed=int(sample_rate_right_enhanced_signal), + sample_rate_reference=config.stem_sample_rate, + audiogram=listener.audiogram_right, + level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 1])), ) # Compute the combined score From 5fc650032f060efb92a3700a792005044ec9376e Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 25 May 2023 14:46:50 +0100 Subject: [PATCH 09/49] Update test evaluate task 1 Signed-off-by: Gerardo Roa Dabike --- .../cad1/task1/baseline/test_evaluate.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/recipes/cad1/task1/baseline/test_evaluate.py b/tests/recipes/cad1/task1/baseline/test_evaluate.py index 5b8a06120..7375b6e2a 100644 --- a/tests/recipes/cad1/task1/baseline/test_evaluate.py +++ b/tests/recipes/cad1/task1/baseline/test_evaluate.py @@ -1,4 +1,6 @@ """Tests for the evaluation module""" +# pylint: disable=import-error + from pathlib import Path import numpy as np @@ -7,6 +9,7 @@ 
from scipy.io import wavfile from clarity.utils.audiogram import Audiogram, Listener +from clarity.utils.flac_encoder import FlacEncoder from recipes.cad1.task1.baseline.evaluate import ( ResultsFile, _evaluate_song_listener, @@ -87,9 +90,10 @@ def test_make_song_listener_list(): "punk_is_not_dead", "my_music_listener", { + "stem_sample_rate": 44100, + "sample_rate": 44100, "evaluate": {"set_random_seed": True}, "path": {"music_dir": None}, - "sample_rate": 44100, "nalr": {"sample_rate": 44100}, }, "test", @@ -101,14 +105,14 @@ def test_make_song_listener_list(): } }, { - "left_drums": 0.205517835, - "right_drums": 0.270553157, - "left_bass": 0.207187220, - "right_bass": 0.205454381, - "left_other": 0.237097711, - "right_other": 0.227505708, - "left_vocals": 0.227105999, - "right_vocals": 0.272616615, + "left_drums": 0.14229280292204488, + "right_drums": 0.15044867874762802, + "left_bass": 0.13337685099485902, + "right_bass": 0.14541734646032817, + "left_other": 0.16310385596493193, + "right_other": 0.1542791489799909, + "left_vocals": 0.12291878218281638, + "right_vocals": 0.13683790592287856, }, ) ], @@ -142,22 +146,22 @@ def test_evaluate_song_listener( instruments = ["drums", "bass", "other", "vocals"] # Create reference and enhanced wav samples + flac_encoder = FlacEncoder() for lr_instrument in list(expected_results.keys()): # enhanced signals are mono enh_file = ( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_{lr_instrument}.wav" + / f"{listener.id}_{song}_{lr_instrument}.flac" ) enh_file.parent.mkdir(exist_ok=True, parents=True) + with open(Path(enh_file).with_suffix(".txt"), "w", encoding="utf-8") as file: + file.write("1.0") # Using very short 100 ms signals to speed up the test - wavfile.write( - enh_file, - 44100, - np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768, - ) + enh_signal = np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768 + flac_encoder.encode(enh_signal.astype(np.int16), 44100, enh_file) for instrument in instruments: # reference signals are stereo @@ -182,7 +186,7 @@ def test_evaluate_song_listener( # Combined score assert isinstance(combined_score, float) assert combined_score == pytest.approx( - 0.231629828, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + 0.14358442152193474, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance ) # Per instrument score From 0afebf349756c7cddd72320bf76af347d44d0da8 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 25 May 2023 14:53:47 +0100 Subject: [PATCH 10/49] Update evaluate task 1 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/README.md | 51 ++++++++++++++++---------- recipes/cad1/task1/baseline/enhance.py | 2 +- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/recipes/cad1/task1/baseline/README.md b/recipes/cad1/task1/baseline/README.md index 7732be324..a8f082706 100644 --- a/recipes/cad1/task1/baseline/README.md +++ b/recipes/cad1/task1/baseline/README.md @@ -15,8 +15,8 @@ To download the data, please visit [here](https://forms.gle/UQkuCxqQVxZtGggPA). Alternatively, you can download the MUSDB18-HQ dataset from the official [SigSep website](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav). If you opt for this alternative, be sure to download the uncompressed wav version. Note that you will need both packages to run the baseline system. 
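+
+After downloading, it is worth sanity-checking that the audio is in the format the baseline
+expects (stereo WAV at 44.1 kHz). The following is a minimal sketch, assuming the data has been
+unpacked into a `cadenza_data` folder in the working directory (a hypothetical path; adjust it
+to your install):
+
+```python
+from pathlib import Path
+
+import soundfile as sf
+
+# Hypothetical root folder; point this at your unpacked data
+wav_files = sorted(Path("cadenza_data").rglob("*.wav"))
+
+signal, sample_rate = sf.read(wav_files[0])
+print(f"{wav_files[0].name}: shape={signal.shape}, sample rate={sample_rate}")
+
+assert sample_rate == 44100  # the baseline config assumes 44.1 kHz input
+assert signal.ndim == 2 and signal.shape[1] == 2  # stereo, [samples, channels]
+```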
-If you need additional music data for training your model, please restrict to the use of [MedleyDB](https://medleydb.weebly.com/) [4] [5],
-[BACH10](https://labsites.rochester.edu/air/resource.html) [6] and [FMA-small](https://github.com/mdeff/fma) [7].
+If you need additional music data for training your model, please restrict to the use of [MedleyDB](https://medleydb.weebly.com/) [[4](#4-references)] [[5](#4-references)],
+[BACH10](https://labsites.rochester.edu/air/resource.html) [[6](#4-references)] and [FMA-small](https://github.com/mdeff/fma) [[7](#4-references)].
 These are shared as `cadenza_cad1_task1_augmentation_medleydb.tar.gz`, `cadenza_cad1_task1_augmentation_bach10.tar.gz`
 and `cadenza_cad1_task1_augmentation_fma_small.tar.gz`.
 **Keeping the augmentation data restricted to these datasets will ensure that the evaluation is fair for all participants**.

@@ -56,7 +56,7 @@ cadenza_data

 ### 1.2 Additional optional data

-* **MedleyDB** contains both MedleyDB versions 1 [[4](#references)] and 2 [[5](#references)] datasets.
+* **MedleyDB** contains both MedleyDB versions 1 [[4](#4-references)] and 2 [[5](#4-references)] datasets.
 Tracks from the MedleyDB dataset are not included in the evaluation set.
 However, it is your responsibility to exclude any song that may already be contained in the training set.

cadenza_data
├───audio
│   └───MedleyDB
│       ├───Audio
└───Metadata

-* **BACH10** contains the BACH10 dataset [[6](#references)].
+* **BACH10** contains the BACH10 dataset [[6](#4-references)].
 Tracks from the BACH10 dataset are not included in MUSDB18-HQ
 and can all be used as training augmentation data.

cadenza_data
├───audio
│   └───Bach10
│       ├───...

-* **FMA Small** contains the FMA small subset of the FMA dataset [[7](references)].
+* **FMA Small** contains the FMA small subset of the FMA dataset [[7](#4-references)].
 Tracks from the FMA small dataset are not included in the MUSDB18-HQ.
 This dataset does not provide independent stems but only the full mix.

cadenza_data
├───audio
│   └───fma_small
│       ├───000
│       ├───001

@@ -123,18 +123,26 @@ Note that we use [hydra](https://hydra.cc/docs/intro/) for config handling.

 ### 2.1 Enhancement

-The baseline enhance simply takes the out-of-the-box [Hybrid Demucs](https://github.com/facebookresearch/demucs) [1]
+We offer two baseline systems:
+
+1. Using the out-of-the-box time-domain [Hybrid Demucs](https://github.com/facebookresearch/demucs) [[1](#4-references)]
 source separation model distributed on [TorchAudio](https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html)
-and applies a simple NAL-R [2] fitting amplification to each VDBO (`vocals`, `drums`, `bass` and `others`) stem.
+2. Using the out-of-the-box spectrogram-based [Open-Unmix](https://github.com/sigsep/open-unmix-pytorch)
+source separation model (version `umxhq`) distributed through [PyTorch Hub](https://pytorch.org/hub/)

-The remixing is performed by summing the amplified VDBO stems.
+Both systems use the same enhancement strategy: using the music separation model, the baseline estimates the
+VDBO (`vocals`, `drums`, `bass` and `others`) stems. Then, it applies a simple NAL-R [[2](#4-references)] fitting amplification to each of them.
+This results in eight mono signals (four from the left channel and four from the right channel). Finally, each signal is downsampled to 24000 Hz, converted to 16-bit precision and
+encoded using lossless FLAC compression. These eight signals are then used for the objective evaluation (HAAQI).
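+
+For orientation, the per-stem flow just described can be exercised directly from Python using the
+recipe's own helpers. The following is a minimal sketch, not the full recipe: the `mixture.wav`
+input path and the flat 30 dB HL audiogram are hypothetical stand-ins, the compressor uses its
+default settings, and `enhance.py` remains the authoritative implementation:
+
+```python
+from pathlib import Path
+
+import numpy as np
+import torch
+from scipy.io import wavfile
+from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
+
+from clarity.enhancer.compressor import Compressor
+from clarity.enhancer.nalr import NALR
+from clarity.utils.audiogram import Audiogram, Listener
+from recipes.cad1.task1.baseline.enhance import (
+    decompose_signal,
+    process_stems_for_listener,
+    save_flac_signal,
+)
+
+device = torch.device("cpu")
+model = HDEMUCS_HIGH_MUSDB.get_model().to(device)
+
+sample_rate, mixture = wavfile.read("mixture.wav")  # hypothetical input file
+mixture = (mixture / 32768.0).astype(np.float32).T  # [channels, samples]
+
+# Synthetic flat 30 dB HL listener, for illustration only
+frequencies = np.array([250, 500, 1000, 2000, 4000, 6000, 8000])
+audiogram = Audiogram(levels=30 * np.ones(7), frequencies=frequencies)
+listener = Listener(audiogram, audiogram)
+
+# 1. Decompose the mixture into left/right VDBO stems
+stems = decompose_signal(
+    model=model,
+    model_sample_rate=HDEMUCS_HIGH_MUSDB.sample_rate,
+    signal=mixture,
+    signal_sample_rate=sample_rate,
+    device=device,
+    sources_list=list(model.sources),
+    listener=listener,
+    normalise=True,
+)
+
+# 2. Amplify each stem with the listener's NAL-R fitting
+enhancer = NALR(nfir=220, sample_rate=sample_rate)
+compressor = Compressor()  # default settings, for illustration
+processed = process_stems_for_listener(
+    stems, enhancer, compressor, listener, apply_compressor=False
+)
+
+# 3. Save each stem as a 24 kHz (stem_sample_rate), 16-bit FLAC plus its scale factor
+for stem_name, stem_signal in processed.items():
+    save_flac_signal(
+        signal=stem_signal,
+        filename=Path(f"{stem_name}.flac"),
+        signal_sample_rate=sample_rate,
+        output_sample_rate=24000,
+        do_scale_signal=True,
+    )
+```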
-The baseline generates a left and right signal for each VDBO stem and a remixed signal, totalling 9 signals per song-listener.
+The baselines also provide a remixing strategy to generate a stereo signal for each listener. This is done by summing
+the amplified VDBO stems, where each output channel (left and right) is the sum of the corresponding
+four stems. This stereo remixed signal is then used for subjective evaluation (listening panel).

 To run the baseline enhancement system, first make sure that `paths.root` in `config.yaml` points to where you have
 installed the Cadenza data. This parameter defaults to the working directory.
-You can also define your own `path.exp_folder` to store enhanced
-signals and evaluated results.
+You can also define your own `path.exp_folder` to store the enhanced signals and evaluated results, and select which
+music separation model you want to employ.

 Then run:

@@ -158,9 +166,8 @@ The folder `enhanced_signals` will appear in the `exp` folder.

 ### 2.2 Evaluation

-The `evaluate.py` simply takes the signals stored in `enhanced_signals` and computes the HAAQI [[3](#references)] score
-for each of the eight left and right VDBO stems.
-The average of these eight scores is computed and returned for each signal.
+The `evaluate.py` script takes the eight VDBO signals stored in `enhanced_signals` and computes the
+HAAQI [[3](#4-references)] score. The final score for each sample is the average of the eight stem scores.

 To run the evaluation stage, make sure that `path.root` is set in the `config.yaml` file and then run

@@ -172,13 +179,19 @@ A csv file containing the eight HAAQI scores and the combined score will be gene

 To check the HAAQI code, see [here](../../../../clarity/evaluator/haaqi).

-Please note: you will not get identical HAAQI scores for the same signals if the random seed is not defined
-(in the given recipe, the random seed for each signal is set as the last eight digits of the song md5).
-As there are random noises generated within HAAQI, but the differences should be sufficiently small.
+Please note: you will not get identical HAAQI scores for the same signals if the random seed is not defined.
+This is due to the random noises generated within HAAQI, but the differences should be sufficiently small.
+For reproducibility, in the given recipe, the random seed for each signal is set as the last eight digits
+of the song md5.
+
+## 3. Results
+
+The overall HAAQI score for each baseline is:

-The score for the baseline is 0.3608 HAAQI overall.
+* Demucs: **0.2592**
+* Open-Unmix: **0.2273**

-## References
+## 4. References

 * [1] Défossez, A. "Hybrid Spectrogram and Waveform Source Separation". Proceedings of the ISMIR 2021 Workshop on Music Source Separation. [doi:10.48550/arXiv.2111.03600](https://arxiv.org/abs/2111.03600)
 * [2] Byrne, Denis, and Harvey Dillon. "The National Acoustic Laboratories'(NAL) new procedure for selecting the gain and frequency response of a hearing aid." Ear and hearing 7.4 (1986): 257-265.
[doi:10.1097/00003446-198608000-00007](https://doi.org/10.1097/00003446-198608000-00007) diff --git a/recipes/cad1/task1/baseline/enhance.py b/recipes/cad1/task1/baseline/enhance.py index 724584592..69d1c1714 100644 --- a/recipes/cad1/task1/baseline/enhance.py +++ b/recipes/cad1/task1/baseline/enhance.py @@ -486,7 +486,7 @@ def enhance(config: DictConfig) -> None: enhanced_folder / f"{listener.id}" / f"{song_name}" - / f"{listener.id}_{song_name}_{stem_str}.wav" + / f"{listener.id}_{song_name}_{stem_str}.flac" ) filename.parent.mkdir(parents=True, exist_ok=True) save_flac_signal( From 95beff1dcc844a9ef5d661a0fe166510639af29c Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 25 May 2023 14:54:28 +0100 Subject: [PATCH 11/49] readme Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/README.md | 55 +++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/recipes/cad1/README.md b/recipes/cad1/README.md index dca206dca..f40609ba6 100644 --- a/recipes/cad1/README.md +++ b/recipes/cad1/README.md @@ -21,37 +21,42 @@ The performance of each system on the validation set is reported below. ### Task 1 - Listening music via headphones -**The overall HAAQI score is 0.3608.** +The overall HAAQI score is: + +- Demucs: **0.2592** +- Open-Unmix: **0.2273** #### Average HAAQI score per song -| Song | HAAQI | -|:------------------------------------------------|:----------:| -| Actions - One Minute Smile | 0.3066 | -| Alexander Ross - Goodbye Bolero | 0.4257 | -| ANiMAL - Rockshow | 0.2389 | -| Clara Berry And Wooldog - Waltz For My Victims | 0.4202 | -| Fergessen - Nos Palpitants | 0.4554 | -| James May - On The Line | 0.3889 | -| Johnny Lokke - Promises & Lies | 0.3395 | -| Leaf - Summerghost | 0.3595 | -| Meaxic - Take A Step | 0.3470 | -| Patrick Talbot - A Reason To Leave | 0.4545 | -| Skelpolu - Human Mistakes | 0.3055 | -| Triviul - Angelsaint | 0.2883 | +| Song | Demucs | Open-UnMix | +|:-----------------------------------------------|:------:|:----------:| +| Actions - One Minute Smile | 0.2485 | 0.2257 | +| Alexander Ross - Goodbye Bolero | 0.3084 | 0.2574 | +| ANiMAL - Rockshow | 0.1843 | 0.1864 | +| Clara Berry And Wooldog - Waltz For My Victims | 0.3094 | 0.2615 | +| Fergessen - Nos Palpitants | 0.3542 | 0.2592 | +| James May - On The Line | 0.2778 | 0.2398 | +| Johnny Lokke - Promises & Lies | 0.2544 | 0.2261 | +| Leaf - Summerghost | 0.2513 | 0.2105 | +| Meaxic - Take A Step | 0.2455 | 0.2239 | +| Patrick Talbot - A Reason To Leave | 0.2673 | 0.2331 | +| Skelpolu - Human Mistakes | 0.2123 | 0.1951 | +| Traffic Experiment - Sirens | 0.2558 | 0.2339 | +| Triviul - Angelsaint | 0.2101 | 0.1955 | +| Young Griffo - Pennies | 0.2499 | 0.2297 | ### Task 2 - Listening music in a car with presence of noise -**The overall HAAQI score is 0.1248.** +**The overall HAAQI score is 0.1423.** #### Average HAAQI score per genre -| Genre | HAAQI | -|:---------------|:----------:| -| Classical | 0.1240 | -| Hip-Hop | 0.1271 | -| Instrumental | 0.1250 | -| International | 0.1267 | -| Orchestral | 0.1121 | -| Pop | 0.1339 | -| Rock | 0.1252 | +| Genre | HAAQI | +|:---------------|:------:| +| Classical | 0.1365 | +| Hip-Hop | 0.1462 | +| Instrumental | 0.1416 | +| International | 0.1432 | +| Orchestral | 0.1329 | +| Pop | 0.1498 | +| Rock | 0.1460 | From 83303487312d214d7cb66dea770f39967076ff18 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 25 May 2023 14:55:39 +0100 Subject: [PATCH 12/49] readme task 2 Signed-off-by: Gerardo Roa Dabike --- 
 recipes/cad1/task2/baseline/README.md | 22 ++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/recipes/cad1/task2/baseline/README.md b/recipes/cad1/task2/baseline/README.md
index 0dfe7aaaa..bcb910074 100644
--- a/recipes/cad1/task2/baseline/README.md
+++ b/recipes/cad1/task2/baseline/README.md
@@ -8,8 +8,8 @@ For more information please visit the [challenge website](https://cadenzachallen

 ### 1.1 Obtaining the CAD1 - Task2 data

-The music dataset for the First Cadenza Challenge - Task 2 is based on the small subset of the FMA [2] dataset
-(FMA-small) and the MTG-Jamendo dataset [4]. The dataset contains 1000 samples from seven musical genres,
+The music dataset for the First Cadenza Challenge - Task 2 is based on the small subset of the FMA [[2](#4-references)] dataset
+(FMA-small) and the MTG-Jamendo dataset [[4](#4-references)]. The dataset contains 1000 samples from seven musical genres,
 totalling 7000 songs with a distribution of 80% / 10% / 10% for `train`, `valid` and `test`.

 From FMA small:
@@ -82,17 +82,18 @@ If you have an Anaconda or Miniconda environment, you can install them as:

 * conda install -c conda-forge ffmpeg
 * conda install -c conda-forge libsndfile

-```bash
-
 ### 2.1 Enhancement

 The objective of the enhancement stage is to take a song and optimise it to a listener's hearing characteristics
-knowing metadata information about the car noise scenario (you won't have access to noise signal), head
+knowing metadata information about the car noise scenario (note that you won't have access to the noise signal), head
 rotation of the listener and the SNR of the enhanced music and the noise at the hearing aid microphones.

-In the baseline, we simply attenuate the song according to the average hearing loss and save it in 16-bit PCM WAV format.
+In the baseline, we attenuate the song according to the average hearing loss. The outputs are stereo signals
+that we save using a 32000 Hz sample rate and 16-bit precision, encoded using lossless FLAC compression.
 This attenuation prevents some clipping in the hearing aid output signal.

+The resulting signals are used for both the objective (HAAQI) and the subjective (listening panel) evaluation.
+
 To run the baseline enhancement system, first make sure that `paths.root` in `config.yaml` points to where you have
 installed the Cadenza data for task 2. This parameter defaults to one level above the recipe for the demo data.
 You can also define your own `path.exp_folder` to store enhanced and evaluated signal results.
@@ -120,8 +121,9 @@ The folder `enhanced_signals` will appear in the `exp` folder.

 ### 2.2 Evaluation

 The `evaluate.py` module takes the enhanced signals and adds the room impulses and the car noise using
-the expected SNR. It then passes that signal through a fixed hearing aid. The hearing aid output and
-the reference song are used to compute the HAAQI [2] score.
+the expected SNR. It then passes that signal through a fixed hearing aid. The hearing aid is composed of
+a NAL-R [[1](#4-references)] prescription and compression. The hearing aid output signal and
+the reference song are used to compute the HAAQI [[2](#4-references)] score.

 To run the evaluation stage, make sure that `path.root` is set in the `config.yaml` file and then run

@@ -138,9 +140,9 @@ Please note: you will not get identical HAAQI scores for the same signals if the
 (in the given recipe, the random seed for each signal is set as the last eight digits of the song md5).
 There are random noises generated within HAAQI, but the differences should be sufficiently small.

-The overall HAAQI score for baseline is 0.1248.
+The overall HAAQI score for the baseline is **0.1423**.

-## References
+## 4. References

 * [1] Byrne, Denis, and Harvey Dillon. "The National Acoustic Laboratories'(NAL) new procedure for selecting the gain and frequency response of a hearing aid." Ear and hearing 7.4 (1986): 257-265. [doi:10.1097/00003446-198608000-00007](https://doi.org/10.1097/00003446-198608000-00007)
 * [2] Kates J M, Arehart K H. "The Hearing-Aid Audio Quality Index (HAAQI)". IEEE/ACM transactions on audio, speech, and language processing, 24(2), 354–365. [doi:10.1109/TASLP.2015.2507858](https://doi.org/10.1109%2FTASLP.2015.2507858)

From 9ed18efdde6bd12a0b3baae72fe5ccff9fcd49a3 Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Thu, 25 May 2023 14:59:54 +0100
Subject: [PATCH 13/49] update utils

Signed-off-by: Gerardo Roa Dabike
---
 recipes/cad1/task2/baseline/baseline_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/recipes/cad1/task2/baseline/baseline_utils.py b/recipes/cad1/task2/baseline/baseline_utils.py
index fe67c02f3..097787ab2 100644
--- a/recipes/cad1/task2/baseline/baseline_utils.py
+++ b/recipes/cad1/task2/baseline/baseline_utils.py
@@ -1,6 +1,7 @@
 """Utility functions for the baseline model."""
 from __future__ import annotations

+# pylint: disable=import-error
 import json
 import logging
 import warnings
@@ -13,9 +14,6 @@

 from clarity.utils.audiogram import Listener

-# pylint: disable=import-error
-
-
 logger = logging.getLogger(__name__)

From 0bbc8f8eeae7471d338c2d57de0539f70ccc4cf2 Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Thu, 25 May 2023 15:01:50 +0100
Subject: [PATCH 14/49] update car scene acoustics

Signed-off-by: Gerardo Roa Dabike
---
 recipes/cad1/task2/baseline/car_scene_acoustics.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/recipes/cad1/task2/baseline/car_scene_acoustics.py b/recipes/cad1/task2/baseline/car_scene_acoustics.py
index 9fe8cc3e1..8bf5510ac 100644
--- a/recipes/cad1/task2/baseline/car_scene_acoustics.py
+++ b/recipes/cad1/task2/baseline/car_scene_acoustics.py
@@ -112,9 +112,7 @@ def apply_hearing_aid(self, signal: np.ndarray, audiogram: Audiogram) -> np.ndar

         Args:
             signal (np.ndarray): The audio signal to be enhanced.
-            audiogram (np.ndarray): An audiogram used to configure the NALR object.
-            center_frequencies (np.ndarray): An array of center frequencies
-                used to configure the NALR object.
+            audiogram (Audiogram): The audiogram of the listener.

         Returns:
             np.ndarray: The enhanced audio signal.
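Since `apply_hearing_aid` now takes just a signal and an `Audiogram`, a typical call processes each ear independently. The following is a minimal sketch of that pattern, assuming a configured `CarSceneAcoustics` instance and the `Listener` fields (`audiogram_left`, `audiogram_right`) used elsewhere in these patches; the helper name is illustrative.

```python
# Minimal sketch: run each ear of a stereo signal through the fixed hearing
# aid. `acoustics` is assumed to be a configured CarSceneAcoustics instance
# and `listener` a clarity.utils.audiogram.Listener.
import numpy as np

def amplify_both_ears(acoustics, signal: np.ndarray, listener) -> np.ndarray:
    """Apply the NAL-R based hearing aid to the left and right channels."""
    out_left = acoustics.apply_hearing_aid(signal[0, :], listener.audiogram_left)
    out_right = acoustics.apply_hearing_aid(signal[1, :], listener.audiogram_right)
    return np.stack([out_left, out_right], axis=0)
```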
From 883035bbe0adb1f0c0c6db07ffac7468a05d1398 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 25 May 2023 15:02:06 +0100 Subject: [PATCH 15/49] update config Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/cad1/task2/baseline/config.yaml b/recipes/cad1/task2/baseline/config.yaml index 36a5ceb2f..6953d32ea 100644 --- a/recipes/cad1/task2/baseline/config.yaml +++ b/recipes/cad1/task2/baseline/config.yaml @@ -11,7 +11,8 @@ path: hrtf_file: ${path.metadata_dir}/eBrird_BRIR.json exp_folder: ./exp # folder to store enhanced signals and final results -sample_rate: 44100 +sample_rate: 44100 # sample rate of the input signal +enhanced_sample_rate: 32000 # sample rate for the enhanced output signal nalr: nfir: 220 From 46032caf09c695d162320a57f4f55723ff48d66f Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 25 May 2023 16:37:25 +0100 Subject: [PATCH 16/49] evaluate Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/config.yaml | 2 +- recipes/cad1/task2/baseline/enhance.py | 31 +++++++++-------- recipes/cad1/task2/baseline/evaluate.py | 44 ++++++++++++++++--------- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/recipes/cad1/task2/baseline/config.yaml b/recipes/cad1/task2/baseline/config.yaml index 6953d32ea..d2c44eda7 100644 --- a/recipes/cad1/task2/baseline/config.yaml +++ b/recipes/cad1/task2/baseline/config.yaml @@ -16,7 +16,7 @@ enhanced_sample_rate: 32000 # sample rate for the enhanced output signal nalr: nfir: 220 - fs: ${sample_rate} + sample_rate: ${sample_rate} compressor: threshold: 0.7 diff --git a/recipes/cad1/task2/baseline/enhance.py b/recipes/cad1/task2/baseline/enhance.py index eda3ac47e..6c85a6b88 100644 --- a/recipes/cad1/task2/baseline/enhance.py +++ b/recipes/cad1/task2/baseline/enhance.py @@ -12,10 +12,11 @@ import numpy as np import pyloudnorm as pyln from omegaconf import DictConfig -from scipy.io import wavfile from tqdm import tqdm from clarity.utils.audiogram import Listener +from clarity.utils.flac_encoder import FlacEncoder +from clarity.utils.signal_processing import clip_signal, resample, to_16bit from recipes.cad1.task2.baseline.baseline_utils import ( make_scene_listener_list, read_mp3, @@ -30,7 +31,7 @@ def compute_average_hearing_loss(listener: Listener) -> float: Compute the average hearing loss of a listener. Args: - listener (dict): The audiogram of the listener. + listener (Listener): The listener. Returns: average_hearing_loss (float): The average hearing loss of the listener. @@ -56,7 +57,7 @@ def enhance_song( Args: waveform (np.ndarray): The waveform of the song. - listener_dict (dict): The audiograms of the listener. + listener (Listener): The listener. config (dict): Dictionary of configuration options for enhancing music. 
Returns: @@ -110,6 +111,7 @@ def enhance(config: DictConfig) -> None: config.evaluate.batch :: config.evaluate.batch_size ] + flac_encoder = FlacEncoder() for scene_id, listener_id in tqdm(scene_listener_pairs): current_scene = scenes[scene_id] listener = listener_dict[listener_id] @@ -121,23 +123,26 @@ def enhance(config: DictConfig) -> None: out_l, out_r = enhance_song( waveform=song_waveform, listener=listener, config=config ) - enhanced = np.stack([out_l, out_r], axis=1) - filename = f"{scene_id}_{listener.id}_{current_scene['song']}.wav" + # Save the enhanced song enhanced_folder_listener = enhanced_folder / f"{listener.id}" enhanced_folder_listener.mkdir(parents=True, exist_ok=True) + filename = ( + enhanced_folder_listener + / f"{scene_id}_{listener.id}_{current_scene['song']}.flac" + ) - # Clip and save - if config.soft_clip: - enhanced = np.tanh(enhanced) - n_clipped = np.sum(np.abs(enhanced) > 1.0) + # - Resample to 32 kHz sample rate + # - Clip signal + # - Convert to 16bit + # - Compress using flac + enhanced = resample(enhanced, config.sample_rate, config.enhanced_sample_rate) + clipped_signal, n_clipped = clip_signal(enhanced, config.soft_clip) if n_clipped > 0: logger.warning(f"Writing {filename}: {n_clipped} samples clipped") - np.clip(enhanced, -1.0, 1.0, out=enhanced) - signal_16 = (32768.0 * enhanced).astype(np.int16) - wavfile.write( - enhanced_folder_listener / filename, config.sample_rate, signal_16 + flac_encoder.encode( + to_16bit(clipped_signal), config.enhanced_sample_rate, filename ) diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index b4413ef9f..584ff98b6 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -11,11 +11,12 @@ import hydra import numpy as np from omegaconf import DictConfig -from scipy.io import wavfile from tqdm import tqdm from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener +from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.signal_processing import compute_rms, resample from recipes.cad1.task2.baseline.audio_manager import AudioManager from recipes.cad1.task2.baseline.baseline_utils import ( load_hrtf, @@ -194,16 +195,21 @@ def evaluate_scene( # Compute HAAQI scores aq_score_l = compute_haaqi( - processed_signal[0, :], - ref_signal[0, :], - listener.audiogram_left, - sample_rate, + processed_signal=processed_signal[0, :], + reference_signal=ref_signal[0, :], + sample_rate_processed=sample_rate, + sample_rate_reference=sample_rate, + audiogram=listener.audiogram_left, + level1=65 - 20 * np.log10(compute_rms(ref_signal[0, :])), ) + aq_score_r = compute_haaqi( - processed_signal[1, :], - ref_signal[1, :], - listener.audiogram_right, - sample_rate, + processed_signal=processed_signal[1, :], + reference_signal=ref_signal[1, :], + sample_rate_processed=sample_rate, + sample_rate_reference=sample_rate, + audiogram=listener.audiogram_left, + level1=65 - 20 * np.log10(compute_rms(ref_signal[1, :])), ) return aq_score_l, aq_score_r @@ -232,9 +238,10 @@ def run_calculate_audio_quality(config: DictConfig) -> None: results_file.write_header() # Initialize acoustic scene model + sample_rate_haaqi = 24000 car_scene_acoustic = CarSceneAcoustics( track_duration=30, - sample_rate=config.sample_rate, + sample_rate=sample_rate_haaqi, hrtf_dir=config.path.hrtf_dir, config_nalr=config.nalr, config_compressor=config.compressor, @@ -268,15 +275,20 @@ def run_calculate_audio_quality(config: DictConfig) -> 
None:
         )

         # Read the enhanced signal from FLAC
-        enhanced_sample_rate, enhanced_signal = wavfile.read(enhanced_song_path)
-        enhanced_signal = enhanced_signal / 32768.0
-        assert enhanced_sample_rate == config.sample_rate
+        enhanced_signal, enhanced_sample_rate = read_flac_signal(enhanced_song_path)
+        assert enhanced_sample_rate == config.enhanced_sample_rate

         # Evaluate scene
+        reference_signal_24k = resample(
+            reference_signal.T, config.sample_rate, sample_rate_haaqi
+        )
+        enhanced_signal_24k = resample(
+            enhanced_signal, enhanced_sample_rate, sample_rate_haaqi
+        )
         aq_score_l, aq_score_r = evaluate_scene(
-            reference_signal,
-            enhanced_signal.T,
-            config.sample_rate,
+            reference_signal_24k.T,
+            enhanced_signal_24k.T,
+            sample_rate_haaqi,
             scene_id,
             current_scene,
             listener,

From f6134dc37a919bd271959a657bdc4078ee8cbedd Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Thu, 25 May 2023 16:42:58 +0100
Subject: [PATCH 17/49] update evaluate task 2

Signed-off-by: Gerardo Roa Dabike
---
 recipes/cad1/task2/baseline/evaluate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py
index 584ff98b6..d5b0214ee 100644
--- a/recipes/cad1/task2/baseline/evaluate.py
+++ b/recipes/cad1/task2/baseline/evaluate.py
@@ -271,7 +271,7 @@ def run_calculate_audio_quality(config: DictConfig) -> None:
         enhanced_folder = Path("enhanced_signals") / config.evaluate.split
         enhanced_song_id = f"{scene_id}_{listener.id}_{current_scene['song']}"
         enhanced_song_path = (
-            enhanced_folder / f"{listener.id}" / f"{enhanced_song_id}.wav"
+            enhanced_folder / f"{listener.id}" / f"{enhanced_song_id}.flac"
         )

         # Read the enhanced signal from FLAC

From 9244cfcbae0e50fe5858942fa364c6f2049f14db Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Fri, 26 May 2023 10:39:13 +0100
Subject: [PATCH 18/49] add pyflac to toml file

Signed-off-by: Gerardo Roa Dabike
---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 5380d9f05..c7fb4a81a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
   "numpy>=1.21.6",
   "omegaconf>=2.1.1",
   "pandas>=1.3.5",
+  "pyflac",
   "pyloudnorm>=0.1.0",
   "pystoi",
   "pytorch-lightning",

From 2a9ca1007b9ffdd0b7221cef0b9e3792763a8284 Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Fri, 26 May 2023 10:58:33 +0100
Subject: [PATCH 19/49] update test haaqi

Signed-off-by: Gerardo Roa Dabike
---
 tests/evaluator/haaqi/test_haaqi.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/evaluator/haaqi/test_haaqi.py b/tests/evaluator/haaqi/test_haaqi.py
index c38ec5a60..b97fb2626 100644
--- a/tests/evaluator/haaqi/test_haaqi.py
+++ b/tests/evaluator/haaqi/test_haaqi.py
@@ -1,4 +1,5 @@
 """Tests for haaqi module"""
+# pylint: disable=import-error
 import numpy as np
 import pytest

@@ -57,8 +58,9 @@ def test_compute_haaqi(levels, freqs, expected_result):
     score = compute_haaqi(
         processed_signal=enh_signal,
         reference_signal=ref_signal,
+        sample_rate_processed=sample_rate,
+        sample_rate_reference=sample_rate,
         audiogram=audiogram,
-        sample_rate=sample_rate,
     )

     # Check that the score is a float between 0 and 1

From 852a47bcf6d7478575cb0daed9995d11fb57b62a Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Fri, 26 May 2023 11:12:00 +0100
Subject: [PATCH 20/49] Add error when segment too small

Signed-off-by: Gerardo Roa Dabike
---
 clarity/evaluator/haspi/eb.py | 4 ++++
 1 file changed, 4 insertions(+)

diff
--git a/clarity/evaluator/haspi/eb.py b/clarity/evaluator/haspi/eb.py index a4b0761ab..a9b36a771 100644 --- a/clarity/evaluator/haspi/eb.py +++ b/clarity/evaluator/haspi/eb.py @@ -1804,6 +1804,8 @@ def bm_covary( # if start_sample < 0: # raise ValueError("segment size too small") # win_corr = 1 / win_corr[start_sample:end_sample] + if int(len(window) - 1 - maxlag) < 0: + raise ValueError("segment size too small") win_corr = 1.0 / correlate(window, window, method="maxlag", lags=int(maxlag)) win_sum2 = 1.0 / np.sum(window**2) # Window power, inverted @@ -1816,6 +1818,8 @@ def bm_covary( # if start_sample < 0: # raise ValueError("segment size too small") # half_corr = 1 / half_corr[start_sample:end_sample] + if int(len(half_window) - 1 - maxlag) < 0: + raise ValueError("segment size too small") half_corr = 1.0 / correlate( half_window, half_window, method="maxlag", lags=int(maxlag) ) From 30b686d4a34f2f9e6e1d633b58cab18ef75a3ced Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Fri, 26 May 2023 11:12:48 +0100 Subject: [PATCH 21/49] update enhanced test Task 2 Signed-off-by: Gerardo Roa Dabike --- tests/recipes/cad1/task2/baseline/test_enhance_task2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/recipes/cad1/task2/baseline/test_enhance_task2.py b/tests/recipes/cad1/task2/baseline/test_enhance_task2.py index 27e31d830..aaa340c46 100644 --- a/tests/recipes/cad1/task2/baseline/test_enhance_task2.py +++ b/tests/recipes/cad1/task2/baseline/test_enhance_task2.py @@ -1,4 +1,6 @@ """Test the enhance module.""" +# pylint: disable=import-error + from pathlib import Path import numpy as np From 34cbe1dc8bb571c9d0340cfdc8471f4b2902eef7 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Fri, 26 May 2023 11:37:16 +0100 Subject: [PATCH 22/49] small improvement eb Signed-off-by: Gerardo Roa Dabike --- clarity/evaluator/haspi/eb.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/clarity/evaluator/haspi/eb.py b/clarity/evaluator/haspi/eb.py index a9b36a771..da7c96ab4 100644 --- a/clarity/evaluator/haspi/eb.py +++ b/clarity/evaluator/haspi/eb.py @@ -1288,13 +1288,13 @@ def env_smooth(envelopes: np.ndarray, segment_size: int, sample_rate: float) -> """ # Compute the window - n_samples = int( - np.around(segment_size * (0.001 * sample_rate)) - ) # Segment size in samples - test = n_samples - 2 * np.floor(n_samples / 2) # 0=even, 1=odd - if test > 0: - # Force window length to be even - n_samples = n_samples + 1 + # Segment size in samples + n_samples = int(np.around(segment_size * (0.001 * sample_rate))) + n_samples += n_samples % 2 + # test = n_samples - 2 * np.floor(n_samples / 2) # 0=even, 1=odd + # if test > 0: + # # Force window length to be even + # n_samples = n_samples + 1 window = np.hanning(n_samples) # Raised cosine von Hann window wsum = np.sum(window) # Sum for normalization From 726010d3c8dbf00e454ad689bbebd769b54e885e Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Fri, 26 May 2023 11:49:11 +0100 Subject: [PATCH 23/49] small improvement eb Signed-off-by: Gerardo Roa Dabike --- tests/utils/test_flac_encoder.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/utils/test_flac_encoder.py b/tests/utils/test_flac_encoder.py index 85aa1e8f0..96cd45442 100644 --- a/tests/utils/test_flac_encoder.py +++ b/tests/utils/test_flac_encoder.py @@ -74,3 +74,12 @@ def test_read_flac_signal(tmp_path): assert np.sum(signal_out) == pytest.approx(np.sum(signal), rel=1e-4, abs=1e-4) assert sample_rate_out == sample_rate + + +def 
test_flac_encoder_encode(): + encoder = FlacEncoder(compression_level=5) + # Generate random audio signal + signal = np.random.randint(-32768, 32767, 4410).astype(np.int16) + sample_rate = 44100 + encoded_data = encoder.encode(signal, sample_rate) + assert isinstance(encoded_data, bytes) From 086c421abbee88893a06c20082ad2fc8e9ab9fc2 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Fri, 26 May 2023 14:59:27 +0100 Subject: [PATCH 24/49] change resampling method in librosa Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/baseline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/cad1/task2/baseline/baseline_utils.py b/recipes/cad1/task2/baseline/baseline_utils.py index 097787ab2..07604cf9c 100644 --- a/recipes/cad1/task2/baseline/baseline_utils.py +++ b/recipes/cad1/task2/baseline/baseline_utils.py @@ -38,7 +38,7 @@ def read_mp3( str(file_path), sr=sample_rate, mono=False, - res_type="kaiser_best", + res_type="soxr_hq", dtype=np.float32, ) except Exception as error: From 8592fb2dbef8bd00c9ec8913d57330adf8a92e1b Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 1 Jun 2023 11:10:39 +0100 Subject: [PATCH 25/49] Correcting correlation Signed-off-by: Gerardo Roa Dabike --- clarity/evaluator/haspi/eb.py | 66 ++++++++++--------------- clarity/utils/signal_processing.py | 56 --------------------- recipes/cad1/task2/baseline/evaluate.py | 4 +- tests/utils/test_signal_processing.py | 57 --------------------- 4 files changed, 28 insertions(+), 155 deletions(-) diff --git a/clarity/evaluator/haspi/eb.py b/clarity/evaluator/haspi/eb.py index da7c96ab4..409f8bff5 100644 --- a/clarity/evaluator/haspi/eb.py +++ b/clarity/evaluator/haspi/eb.py @@ -11,6 +11,7 @@ butter, cheby2, convolve, + correlate, firwin, group_delay, lfilter, @@ -19,7 +20,6 @@ from clarity.enhancer.nalr import NALR from clarity.utils.audiogram import Audiogram -from clarity.utils.signal_processing import correlate if TYPE_CHECKING: from numpy import ndarray @@ -1798,31 +1798,23 @@ def bm_covary( window = np.hanning(nwin).conj().transpose() # Raised cosine von Hann window # compute inverted Window autocorrelation - # win_corr = correlate(window, window, "full") - # start_sample = int(len(window) - 1 - maxlag) - # end_sample = int(maxlag + len(window)) - # if start_sample < 0: - # raise ValueError("segment size too small") - # win_corr = 1 / win_corr[start_sample:end_sample] - if int(len(window) - 1 - maxlag) < 0: + win_corr = correlate(window, window, "full") + start_sample = int(len(window) - 1 - maxlag) + end_sample = int(maxlag + len(window)) + if start_sample < 0: raise ValueError("segment size too small") - win_corr = 1.0 / correlate(window, window, method="maxlag", lags=int(maxlag)) + win_corr = 1 / win_corr[start_sample:end_sample] win_sum2 = 1.0 / np.sum(window**2) # Window power, inverted # The first segment has a half window nhalf = int(nwin / 2) half_window = window[nhalf:nwin] - # half_corr = correlate(half_window, half_window, "full") - # start_sample = int(len(half_window) - 1 - maxlag) - # end_sample = int(maxlag + len(half_window)) - # if start_sample < 0: - # raise ValueError("segment size too small") - # half_corr = 1 / half_corr[start_sample:end_sample] - if int(len(half_window) - 1 - maxlag) < 0: + half_corr = correlate(half_window, half_window, "full") + start_sample = int(len(half_window) - 1 - maxlag) + end_sample = int(maxlag + len(half_window)) + if start_sample < 0: raise ValueError("segment size too small") - half_corr = 1.0 / correlate( - 
half_window, half_window, method="maxlag", lags=int(maxlag) - ) + half_corr = 1 / half_corr[start_sample:end_sample] halfsum2 = 1.0 / np.sum(half_window**2) # MS sum normalization, first segment # Number of segments @@ -1852,13 +1844,11 @@ def bm_covary( ref_mean_square = np.sum(reference_seg**2) * halfsum2 proc_mean_squared = np.sum(processed_seg**2) * halfsum2 - # correlation = correlate(reference_seg, processed_seg, "full") - # correlation = correlation[ - # int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) - # ] - correlation = correlate( - reference_seg, processed_seg, method="maxlag", lags=int(maxlag) - ) + correlation = correlate(reference_seg, processed_seg, "full") + correlation = correlation[ + int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) + ] + unbiased_cross_correlation = np.max(np.abs(correlation * half_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalize cross-covariance @@ -1884,13 +1874,11 @@ def bm_covary( # Normalize signal MS value by the window ref_mean_square = np.sum(reference_seg**2) * win_sum2 proc_mean_squared = np.sum(processed_seg**2) * win_sum2 - # correlation = correlate(reference_seg, processed_seg, "full") - # correlation = correlation[ - # int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) - # ] - correlation = correlate( - reference_seg, processed_seg, method="maxlag", lags=int(maxlag) - ) + correlation = correlate(reference_seg, processed_seg, "full") + correlation = correlation[ + int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) + ] + unbiased_cross_correlation = np.max(np.abs(correlation * win_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalize cross-covariance @@ -1914,13 +1902,11 @@ def bm_covary( ref_mean_square = np.sum(reference_seg**2) * halfsum2 proc_mean_squared = np.sum(processed_seg**2) * halfsum2 - # correlation = correlate(reference_seg, processed_seg, "full") - # correlation = correlation[ - # int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) - # ] - correlation = correlate( - reference_seg, processed_seg, method="maxlag", lags=int(maxlag) - ) + correlation = correlate(reference_seg, processed_seg, "full") + correlation = correlation[ + int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) + ] + unbiased_cross_correlation = np.max(np.abs(correlation * half_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalized cross-covariance diff --git a/clarity/utils/signal_processing.py b/clarity/utils/signal_processing.py index a9dd9f995..dd5f61a63 100644 --- a/clarity/utils/signal_processing.py +++ b/clarity/utils/signal_processing.py @@ -40,62 +40,6 @@ def compute_rms(signal: ndarray) -> float: return np.sqrt(np.mean(np.square(signal))) -def correlate( - x: np.ndarray, - y: np.ndarray, - mode="full", - method="auto", - lags: int | float | None = None, -) -> np.ndarray: - """ - Wrap of ``scipy.signal.correlate`` that includes a mode - for maxlag. - - This computes the same result as - numpy.correlate(x, y, mode='full')[len(a)-maxlag-1:len(a)+maxlag] - - Args: - x (np.ndarray): First signal - y (np.ndarray): Second signal - mode (str): Mode to pass to ``scipy.signal.correlate`` - method (str): - 'maxlag': Implement cross correlation with a maximum number of lags. - x and y must have the same length. - `mode` is ignored. 
- based on https://stackoverflow.com/questions/30677241/ - how-to-limit-cross-correlation-window-width-in-numpy - "auto": Run scipy.signal.correlate with method='auto' - 'direct': Run scipy.signal.correlate with method='direct' - 'fft': Run scipy.signal.correlate with method='fft' - lags (int): Maximum number of lags for `method` "maxlag". - Returns: - np.ndarray: cross correlation of x and y - """ - if method == "maxlag": - if lags is None: - raise ValueError("maxlag must be specified for method='maxlag'") - lags = int(lags) - - if x.shape[0] != y.shape[0]: - raise ValueError("x and y must have the same length") - - py = np.pad(y.conj(), 2 * lags, mode="constant") - # pylint: disable=unsubscriptable-object - T = np.lib.stride_tricks.as_strided( - py[2 * lags :], - shape=(2 * lags + 1, len(y) + 2 * lags), - strides=(-py.strides[0], py.strides[0]), - ) - px = np.pad(x, lags, mode="constant") - return T.dot(px) - - if method in ["auto", "direct", "fft"]: - # Run scipy signal correlate with the specified method and mode - return scipy.signal.correlate(x, y, mode=mode, method=method) - - raise ValueError(f"Unknown method: {method}") - - def denormalize_signals(sources: ndarray, ref: ndarray) -> ndarray: """Scale signals back to the original scale. diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index d5b0214ee..d42491c73 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -208,7 +208,7 @@ def evaluate_scene( reference_signal=ref_signal[1, :], sample_rate_processed=sample_rate, sample_rate_reference=sample_rate, - audiogram=listener.audiogram_left, + audiogram=listener.audiogram_right, level1=65 - 20 * np.log10(compute_rms(ref_signal[1, :])), ) return aq_score_l, aq_score_r @@ -249,7 +249,7 @@ def run_calculate_audio_quality(config: DictConfig) -> None: ) # Iterate over scenes - for scene_id, listener_id in tqdm(scene_listener_pairs): + for scene_id, listener_id in tqdm(scene_listener_pairs[:1]): current_scene = scenes[scene_id] # Retrieve audiograms diff --git a/tests/utils/test_signal_processing.py b/tests/utils/test_signal_processing.py index 2e815c771..c62b0cd04 100644 --- a/tests/utils/test_signal_processing.py +++ b/tests/utils/test_signal_processing.py @@ -6,7 +6,6 @@ from clarity.utils.signal_processing import ( clip_signal, compute_rms, - correlate, denormalize_signals, normalize_signal, resample, @@ -245,59 +244,3 @@ def test_resample_with_3d_array_error(): resample( signal=input_signal, sample_rate=16000, new_sample_rate=8000, method="soxr" ) - - -def test_correlate_maxlag(): - """Test the function correlate with maxlag""" - x = np.array([1, 2, 3, 4, 5]) - y = np.array([2, 4, 6, 8, 10]) - maxlag = 1 - expected_result = np.array([80, 110, 80]) - - result = correlate(x, y, lags=maxlag, method="maxlag") - print(result) - assert np.sum(result) == pytest.approx( - np.sum(expected_result), rel=pytest.rel_tolerance, abs=pytest.abs_tolerance - ) - - -def test_correlate_auto(): - """Test the function correlate with auto method""" - x = np.array([1, 2, 3, 4, 5]) - y = np.array([2, 4, 6]) - expected_result = np.array([6, 16, 28, 40, 52, 28, 10]) - result = correlate(x, y, method="auto") - - assert np.sum(result) == pytest.approx( - np.sum(expected_result), rel=pytest.rel_tolerance, abs=pytest.abs_tolerance - ) - - -def test_correlate_unknown_method(): - """Test the function correlate with unknown method""" - x = np.array([1, 2, 3, 4, 5]) - y = np.array([2, 4, 6]) - unknown_method = "invalid_method" 
- - with pytest.raises(ValueError, match="Unknown method: invalid_method"): - correlate(x, y, method=unknown_method) - - -def test_correlate_missing_maxlag(): - """Test the function correlate with missing maxlag""" - x = np.array([1, 2, 3, 4, 5]) - y = np.array([2, 4, 6, 8, 10]) - - with pytest.raises( - ValueError, match="maxlag must be specified for method='maxlag'" - ): - correlate(x, y, method="maxlag") - - -def test_correlate_maxlag_different_length(): - """Test the function correlate with maxlag and different length""" - x = np.array([1, 2, 3, 4, 5]) - y = np.array([2, 4, 6]) - - with pytest.raises(ValueError, match="x and y must have the same length"): - correlate(x, y, lags=1, method="maxlag") From 350891c7a314a0a437d21d97b80e58057b668382 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 1 Jun 2023 12:47:24 +0100 Subject: [PATCH 26/49] Add better loggin Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/evaluate.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 13f9003f0..52c60291f 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -79,7 +79,6 @@ def add_result( instruments_scores (dict): A dictionary of scores for each instrument channel in the result. """ - logger.info(f"The combined score is {score}") with open(self.file_name, "a", encoding="utf-8", newline="") as csv_file: csv_writer = csv.writer( @@ -144,8 +143,6 @@ def _evaluate_song_listener( """ - logger.info(f"Evaluating {song} for {listener.id}") - if config.evaluate.set_random_seed: set_song_seed(song) @@ -156,7 +153,7 @@ def _evaluate_song_listener( "other", "vocals", ]: - logger.info(f"...evaluating {instrument}") + logger.info(f" ...evaluating {instrument}") # Read instrument reference signal sample_rate_reference_signal, reference_signal = wavfile.read( @@ -244,8 +241,14 @@ def run_calculate_aq(config: DictConfig) -> None: song_listener_pair = song_listener_pair[ config.evaluate.batch :: config.evaluate.batch_size ] + num_song_list_pair = len(song_listener_pair) + for idx, song_listener in enumerate(song_listener_pair, 1): + song, listener_id = song_listener + logger.info( + f"[{idx:03d}/{num_song_list_pair:03d}] " + f"Processing {song} for {listener_id}..." 
+ ) - for song, listener_id in song_listener_pair: split_dir = "train" if songs[songs["Track Name"] == song]["Split"].tolist()[0] == "test": split_dir = "test" @@ -257,6 +260,10 @@ def run_calculate_aq(config: DictConfig) -> None: split_dir, enhanced_folder, ) + logger.info( + f"[{idx:03d}/{num_song_list_pair:03d}] " + f"The combined score is {combined_score}" + ) results_file.add_result( listener.id, song, From 2fe148502892be3b6ff2786d2b46c57329e41be0 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 1 Jun 2023 12:56:44 +0100 Subject: [PATCH 27/49] remove filter Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index d42491c73..72be98fe2 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -249,7 +249,7 @@ def run_calculate_audio_quality(config: DictConfig) -> None: ) # Iterate over scenes - for scene_id, listener_id in tqdm(scene_listener_pairs[:1]): + for scene_id, listener_id in tqdm(scene_listener_pairs): current_scene = scenes[scene_id] # Retrieve audiograms From 7d70659db52cbf41e3568c88cb90aa287c7a7a4c Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 1 Jun 2023 13:09:45 +0100 Subject: [PATCH 28/49] remove commented code Signed-off-by: Gerardo Roa Dabike --- clarity/evaluator/haspi/eb.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/clarity/evaluator/haspi/eb.py b/clarity/evaluator/haspi/eb.py index 409f8bff5..415afe2cc 100644 --- a/clarity/evaluator/haspi/eb.py +++ b/clarity/evaluator/haspi/eb.py @@ -1291,10 +1291,7 @@ def env_smooth(envelopes: np.ndarray, segment_size: int, sample_rate: float) -> # Segment size in samples n_samples = int(np.around(segment_size * (0.001 * sample_rate))) n_samples += n_samples % 2 - # test = n_samples - 2 * np.floor(n_samples / 2) # 0=even, 1=odd - # if test > 0: - # # Force window length to be even - # n_samples = n_samples + 1 + window = np.hanning(n_samples) # Raised cosine von Hann window wsum = np.sum(window) # Sum for normalization From 0e4bb425f6fbf18532f0c8fb041f9cb470d14827 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Thu, 1 Jun 2023 13:44:55 +0100 Subject: [PATCH 29/49] add test for errors in Flac_encoder Signed-off-by: Gerardo Roa Dabike --- tests/utils/test_flac_encoder.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/utils/test_flac_encoder.py b/tests/utils/test_flac_encoder.py index 96cd45442..896426fc1 100644 --- a/tests/utils/test_flac_encoder.py +++ b/tests/utils/test_flac_encoder.py @@ -83,3 +83,18 @@ def test_flac_encoder_encode(): sample_rate = 44100 encoded_data = encoder.encode(signal, sample_rate) assert isinstance(encoded_data, bytes) + + +def test_flac_encoder_encode_type_error(): + encoder = FlacEncoder(compression_level=5) + # Generate random audio signal + signal = np.random.randint(-32768, 32767, 4410).astype(np.float32) + sample_rate = 44100 + with pytest.raises(ValueError): + encoder.encode(signal, sample_rate) + + +def test_flac_encoder_decode_fileNotFound(): + encoder = FlacEncoder(compression_level=5) + with pytest.raises(FileNotFoundError): + encoder.decode("fake/path.flac") From 4fbef9919e441d5e4422ac75aac7766eb77ea6e2 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Mon, 4 Sep 2023 12:55:15 +0100 Subject: [PATCH 30/49] correct function name Signed-off-by: Gerardo Roa Dabike --- 
tests/utils/test_flac_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_flac_encoder.py b/tests/utils/test_flac_encoder.py index 896426fc1..3a3af60f5 100644 --- a/tests/utils/test_flac_encoder.py +++ b/tests/utils/test_flac_encoder.py @@ -94,7 +94,7 @@ def test_flac_encoder_encode_type_error(): encoder.encode(signal, sample_rate) -def test_flac_encoder_decode_fileNotFound(): +def test_flac_encoder_decode_file_not_found(): encoder = FlacEncoder(compression_level=5) with pytest.raises(FileNotFoundError): encoder.decode("fake/path.flac") From f4e6b94b27b01aa00e518f39526e4c485fcaecbf Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Mon, 4 Sep 2023 13:07:16 +0100 Subject: [PATCH 31/49] add test for task 1 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/config.yaml | 4 + recipes/cad1/task1/baseline/test.py | 241 ++++++++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 recipes/cad1/task1/baseline/test.py diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index 691a01703..0e9047f38 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ -6,8 +6,12 @@ path: music_valid_file: ${path.metadata_dir}/musdb18.valid.json listeners_train_file: ${path.metadata_dir}/listeners.train.json listeners_valid_file: ${path.metadata_dir}/listeners.valid.json + music_test_file: ${path.metadata_dir}/musdb18.test.json + music_segments_test_file: ${path.metadata_dir}/musdb18.segments.test.json + listeners_test_file: ${path.metadata_dir}/listeners.test.json exp_folder: ./exp_${separator.model} # folder to store enhanced signals and final results +team_id: T001 sample_rate: 44100 # sample rate of the input mixture stem_sample_rate: 24000 # sample rate output stems diff --git a/recipes/cad1/task1/baseline/test.py b/recipes/cad1/task1/baseline/test.py new file mode 100644 index 000000000..057da3bb4 --- /dev/null +++ b/recipes/cad1/task1/baseline/test.py @@ -0,0 +1,241 @@ +""" Run the baseline enhancement. """ +from __future__ import annotations + +# pylint: disable=import-error +# pylint: disable=too-many-function-args +import json +import logging +import shutil +from pathlib import Path + +import hydra +import numpy as np +import pandas as pd +import torch +from omegaconf import DictConfig +from scipy.io import wavfile +from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB + +from clarity.enhancer.compressor import Compressor +from clarity.enhancer.nalr import NALR +from clarity.utils.audiogram import Listener +from recipes.cad1.task1.baseline.enhance import ( + decompose_signal, + get_device, + process_stems_for_listener, + remix_signal, + save_flac_signal, +) +from recipes.cad1.task1.baseline.evaluate import make_song_listener_list + +# pylint: disable=too-many-locals + +logger = logging.getLogger(__name__) + + +def pack_submission( + team_id: str, + root_dir: str | Path, + base_dir: str | Path = ".", +) -> None: + """ + Pack the submission files into an archive file. + + Args: + team_id (str): Team ID. + root_dir (str | Path): Root directory of the archived file. + base_dir (str | Path): Base directory to archive. Defaults to ".". 
+    """
+    # Pack the submission files
+    logger.info(f"Packing submission files for team {team_id}...")
+    shutil.make_archive(
+        f"submission_{team_id}",
+        "zip",
+        root_dir=root_dir,
+        base_dir=base_dir,
+    )
+
+
+@hydra.main(config_path="", config_name="config")
+def enhance(config: DictConfig) -> None:
+    """
+    Run the music enhancement.
+    The system decomposes the music into vocal, drums, bass, and other stems.
+    Then, the NAL-R prescription procedure is applied to each stem.
+    Args:
+        config (dict): Dictionary of configuration options for enhancing music.
+
+    Returns 8 stems for each song:
+        - left channel vocal, drums, bass, and other stems
+        - right channel vocal, drums, bass, and other stems
+    """
+
+    if config.separator.model not in ["demucs", "openunmix"]:
+        raise ValueError(f"Separator model {config.separator.model} not supported.")
+
+    enhanced_folder = Path("enhanced_signals") / "evaluation"
+    enhanced_folder.mkdir(parents=True, exist_ok=True)
+
+    if config.separator.model == "demucs":
+        separation_model = HDEMUCS_HIGH_MUSDB.get_model()
+        model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate
+        sources_order = separation_model.sources
+        normalise = True
+    elif config.separator.model == "openunmix":
+        separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0)
+        model_sample_rate = separation_model.sample_rate
+        sources_order = ["vocals", "drums", "bass", "other"]
+        normalise = False
+    else:
+        raise ValueError(f"Separator model {config.separator.model} not supported.")
+
+    device, _ = get_device(config.separator.device)
+    separation_model.to(device)
+
+    # Processing the evaluation (test) set
+    # Load listener audiograms and songs
+    listener_dict = Listener.load_listener_dict(config.path.listeners_test_file)
+
+    with open(config.path.music_test_file, encoding="utf-8") as file:
+        song_data = json.load(file)
+    songs_details = pd.DataFrame.from_dict(song_data)
+
+    with open(config.path.music_segments_test_file, encoding="utf-8") as file:
+        songs_segments = json.load(file)
+
+    song_listener_pairs = make_song_listener_list(
+        songs_details["Track Name"], listener_dict
+    )
+    # Select a batch to process
+    song_listener_pairs = song_listener_pairs[
+        config.evaluate.batch :: config.evaluate.batch_size
+    ]
+
+    # Create hearing aid objects
+    enhancer = NALR(**config.nalr)
+    compressor = Compressor(**config.compressor)
+
+    # Decompose each song into left and right vocal, drums, bass, and other stems
+    # and process each stem for the listener
+    prev_song_name = None
+    num_song_list_pair = len(song_listener_pairs)
+    for idx, song_listener in enumerate(song_listener_pairs, 1):
+        song_name, listener_name = song_listener
+        logger.info(
+            f"[{idx:03d}/{num_song_list_pair:03d}] "
+            f"Processing {song_name} for {listener_name}..."
+        )
+        # Get the listener
+        listener = listener_dict[listener_name]
+
+        # Find the music split directory
+        split_directory = (
+            "test"
+            if songs_details.loc[
+                songs_details["Track Name"] == song_name, "Split"
+            ].iloc[0]
+            == "test"
+            else "train"
+        )
+
+        # Baseline Steps
+        # 1. Decompose the mixture signal into vocal, drums, bass, and other stems
+        # We check whether two consecutive songs are the same to avoid
+        # decomposing the same song multiple times
+        if prev_song_name != song_name:
+            # Decompose song only once
+            prev_song_name = song_name
+
+            sample_rate, mixture_signal = wavfile.read(
+                Path(config.path.music_dir)
+                / split_directory
+                / song_name
+                / "mixture.wav"
+            )
+            mixture_signal = (mixture_signal / 32768.0).astype(np.float32).T
+            assert sample_rate == config.sample_rate
+
+            # Decompose mixture signal into stems
+            stems = decompose_signal(
+                separation_model,
+                model_sample_rate,
+                mixture_signal,
+                sample_rate,
+                device,
+                sources_order,
+                listener,
+                normalise,
+            )
+
+        # 2. Apply NAL-R prescription to each stem
+        # Baseline applies NALR prescription to each stem instead of using the
+        # listener's audiograms in the decomposition. This step can be skipped
+        # if the listener's audiograms are used in the decomposition
+        processed_stems = process_stems_for_listener(
+            stems,
+            enhancer,
+            compressor,
+            listener,
+            config.apply_compressor,
+        )
+
+        # 3. Save processed stems
+        for stem_str, stem_signal in processed_stems.items():
+            filename = (
+                enhanced_folder
+                / f"{listener_name}"
+                / f"{song_name}"
+                / f"{listener_name}_{song_name}_{stem_str}.flac"
+            )
+            filename.parent.mkdir(parents=True, exist_ok=True)
+            start = songs_segments[song_name]["objective_evaluation"]["start"]
+            end = songs_segments[song_name]["objective_evaluation"]["end"]
+            save_flac_signal(
+                signal=stem_signal[
+                    int(start * config.sample_rate) : int(end * config.sample_rate)
+                ],
+                filename=filename,
+                signal_sample_rate=config.sample_rate,
+                output_sample_rate=config.stem_sample_rate,
+                do_scale_signal=True,
+            )
+
+        # 4. Remix signal
+        enhanced = remix_signal(processed_stems)
+
+        # 5. Save enhanced (remixed) signal
+        filename = (
+            enhanced_folder
+            / f"{listener.id}"
+            / f"{song_name}"
+            / f"{listener.id}_{song_name}_remix.flac"
+        )
+        start = songs_segments[song_name]["subjective_evaluation"]["start"]
+        end = songs_segments[song_name]["subjective_evaluation"]["end"]
+        save_flac_signal(
+            signal=enhanced[
+                int(start * config.sample_rate) : int(end * config.sample_rate)
+            ],
+            filename=filename,
+            signal_sample_rate=config.sample_rate,
+            output_sample_rate=config.remix_sample_rate,
+            do_clip_signal=True,
+            do_soft_clip=config.soft_clip,
+        )
+
+    pack_submission(
+        team_id=config.team_id,
+        root_dir=enhanced_folder.parent,
+        base_dir=enhanced_folder.name,
+    )
+
+    logger.info("Evaluation complete!")
+    logger.info(
+        f"Please submit the file submission_{config.team_id}.zip to the challenge "
+        "using the link provided. Thank you!"
+ ) + + +# pylint: disable = no-value-for-parameter +if __name__ == "__main__": + enhance() From a29dd6335c26d938e873ec793bc3bcc34a8570bf Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Mon, 4 Sep 2023 13:07:48 +0100 Subject: [PATCH 32/49] add test for task 1 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/cad1/task1/baseline/test.py b/recipes/cad1/task1/baseline/test.py index 057da3bb4..0fd4fd98d 100644 --- a/recipes/cad1/task1/baseline/test.py +++ b/recipes/cad1/task1/baseline/test.py @@ -183,9 +183,9 @@ def enhance(config: DictConfig) -> None: for stem_str, stem_signal in processed_stems.items(): filename = ( enhanced_folder - / f"{listener_name}" + / f"{listener.id}" / f"{song_name}" - / f"{listener_name}_{song_name}_{stem_str}.flac" + / f"{listener.id}_{song_name}_{stem_str}.flac" ) filename.parent.mkdir(parents=True, exist_ok=True) start = songs_segments[song_name]["objective_evaluation"]["start"] From a3c11e8b71c40783c93a7dae80cab7fbdd210170 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Mon, 4 Sep 2023 13:10:22 +0100 Subject: [PATCH 33/49] add test for task 2 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/baseline_utils.py | 3 + recipes/cad1/task2/baseline/config.yaml | 3 + recipes/cad1/task2/baseline/test.py | 67 +++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 recipes/cad1/task2/baseline/test.py diff --git a/recipes/cad1/task2/baseline/baseline_utils.py b/recipes/cad1/task2/baseline/baseline_utils.py index 07604cf9c..6b65406e7 100644 --- a/recipes/cad1/task2/baseline/baseline_utils.py +++ b/recipes/cad1/task2/baseline/baseline_utils.py @@ -99,6 +99,9 @@ def load_listeners_and_scenes( elif config.evaluate.split == "valid": listeners = Listener.load_listener_dict(config.path.listeners_valid_file) scenes = df_scenes[df_scenes["split"] == "valid"].to_dict("index") + elif config.evaluate.split == "test": + listeners = Listener.load_listener_dict(config.path.listeners_test_file) + scenes = df_scenes[df_scenes["split"] == "test"].to_dict("index") else: raise ValueError(f"Unknown split {config.evaluate.split}") diff --git a/recipes/cad1/task2/baseline/config.yaml b/recipes/cad1/task2/baseline/config.yaml index d2c44eda7..b6722310a 100644 --- a/recipes/cad1/task2/baseline/config.yaml +++ b/recipes/cad1/task2/baseline/config.yaml @@ -6,11 +6,14 @@ path: hrtf_dir: ${path.audio_dir}/eBrird listeners_train_file: ${path.metadata_dir}/listeners.train.json listeners_valid_file: ${path.metadata_dir}/listeners.valid.json + listeners_test_file: ${path.metadata_dir}/listeners.test.json scenes_file: ${path.metadata_dir}/scenes.json scenes_listeners_file: ${path.metadata_dir}/scenes_listeners.json hrtf_file: ${path.metadata_dir}/eBrird_BRIR.json exp_folder: ./exp # folder to store enhanced signals and final results +team_id: T001 + sample_rate: 44100 # sample rate of the input signal enhanced_sample_rate: 32000 # sample rate for the enhanced output signal diff --git a/recipes/cad1/task2/baseline/test.py b/recipes/cad1/task2/baseline/test.py new file mode 100644 index 000000000..8e5e26638 --- /dev/null +++ b/recipes/cad1/task2/baseline/test.py @@ -0,0 +1,67 @@ +""" Run the dummy enhancement. 
""" +# pylint: disable=too-many-locals +# pylint: disable=import-error +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +import hydra +from omegaconf import DictConfig + +from recipes.cad1.task2.baseline.enhance import enhance as enhance_set + +logger = logging.getLogger(__name__) + + +def pack_submission( + team_id: str, + root_dir: str | Path, + base_dir: str | Path = ".", +) -> None: + """ + Pack the submission files into an archive file. + + Args: + team_id (str): Team ID. + root_dir (str | Path): Root directory of the archived file. + base_dir (str | Path): Base directory to archive. Defaults to ".". + """ + # Pack the submission files + logger.info(f"Packing submission files for team {team_id}...") + shutil.make_archive( + f"submission_{team_id}", + "zip", + root_dir=root_dir, + base_dir=base_dir, + ) + + +@hydra.main(config_path="", config_name="config") +def enhance(config: DictConfig) -> None: + """ + Run the music enhancement. + The baseline system is a dummy processor that returns the input signal. + + Args: + config (dict): Dictionary of configuration options for enhancing music. + """ + enhance_set(config) + + pack_submission( + team_id=config.team_id, + root_dir=Path("enhanced_signals"), + base_dir=config.evaluate.split, + ) + + logger.info("Evaluation complete.!!") + logger.info( + f"Please, submit the file submission_{config.team_id}.zip to the challenge " + "using the link provided. Thank you.!!" + ) + + +# pylint: disable = no-value-for-parameter +if __name__ == "__main__": + enhance() From 5b6024d3664d39b7b99175a5ff56680bdca38b42 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 12:40:59 +0100 Subject: [PATCH 34/49] enhanced CAD1 v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/config.yaml | 10 +++---- recipes/cad1/task1/baseline/enhance.py | 36 ++++++------------------- 2 files changed, 11 insertions(+), 35 deletions(-) diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index 0e9047f38..f6739a856 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ -2,13 +2,9 @@ path: root: ../../cadenza_data_demo/cad1/task1 metadata_dir: ${path.root}/metadata music_dir: ${path.root}/audio/musdb18hq - music_train_file: ${path.metadata_dir}/musdb18.train.json - music_valid_file: ${path.metadata_dir}/musdb18.valid.json - listeners_train_file: ${path.metadata_dir}/listeners.train.json - listeners_valid_file: ${path.metadata_dir}/listeners.valid.json - music_test_file: ${path.metadata_dir}/musdb18.test.json + music_file: ${path.metadata_dir}/musdb18.valid.json + listeners_file: ${path.metadata_dir}/listeners.valid.json music_segments_test_file: ${path.metadata_dir}/musdb18.segments.test.json - listeners_test_file: ${path.metadata_dir}/listeners.test.json exp_folder: ./exp_${separator.model} # folder to store enhanced signals and final results team_id: T001 @@ -29,7 +25,7 @@ compressor: release: 1000 rms_buffer_size: 0.064 -soft_clip: True +soft_clip: False separator: model: demucs # demucs or openunmix diff --git a/recipes/cad1/task1/baseline/enhance.py b/recipes/cad1/task1/baseline/enhance.py index 69d1c1714..610ebd67d 100644 --- a/recipes/cad1/task1/baseline/enhance.py +++ b/recipes/cad1/task1/baseline/enhance.py @@ -366,24 +366,6 @@ def enhance(config: DictConfig) -> None: enhanced_folder = Path("enhanced_signals") enhanced_folder.mkdir(parents=True, exist_ok=True) - # Training stage - # - # The baseline 
is using an off-the-shelf model trained on the MUSDB18 dataset - # Training listeners and song are not necessary in this case. - # - # Training songs and audiograms can be read like this: - # - # with open(config.path.listeners_train_file, "r", encoding="utf-8") as file: - # listener_train_audiograms = json.load(file) - # - # with open(config.path.music_train_file, "r", encoding="utf-8") as file: - # song_data = json.load(file) - # songs_train = pd.DataFrame.from_dict(song_data) - # - # train_song_listener_pairs = make_song_listener_list( - # songs_train['Track Name'], listener_train_audiograms - # ) - if config.separator.model == "demucs": separation_model = HDEMUCS_HIGH_MUSDB.get_model() model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate @@ -402,17 +384,15 @@ def enhance(config: DictConfig) -> None: # Processing Validation Set # Load listener audiograms and songs - listener_dict = Listener.load_listener_dict(config.path.listeners_valid_file) + listener_dict = Listener.load_listener_dict(config.path.listeners_file) - with open(config.path.music_valid_file, encoding="utf-8") as file: + with open(config.path.music_file, encoding="utf-8") as file: song_data = json.load(file) - songs_valid = pd.DataFrame.from_dict(song_data) + songs_df = pd.DataFrame.from_dict(song_data) - valid_song_listener_pairs = make_song_listener_list( - songs_valid["Track Name"], listener_dict - ) + song_listener_pairs = make_song_listener_list(songs_df["Track Name"], listener_dict) # Select a batch to process - valid_song_listener_pairs = valid_song_listener_pairs[ + song_listener_pairs = song_listener_pairs[ config.evaluate.batch :: config.evaluate.batch_size ] @@ -422,8 +402,8 @@ def enhance(config: DictConfig) -> None: # Decompose each song into left and right vocal, drums, bass, and other stems # and process each stem for the listener prev_song_name = None - num_song_list_pair = len(valid_song_listener_pairs) - for idx, song_listener in enumerate(valid_song_listener_pairs, 1): + num_song_list_pair = len(song_listener_pairs) + for idx, song_listener in enumerate(song_listener_pairs, 1): song_name, listener_name = song_listener logger.info( f"[{idx:03d}/{num_song_list_pair:03d}] " @@ -435,7 +415,7 @@ def enhance(config: DictConfig) -> None: # Find the music split directory split_directory = ( "test" - if songs_valid.loc[songs_valid["Track Name"] == song_name, "Split"].iloc[0] + if songs_df.loc[songs_df["Track Name"] == song_name, "Split"].iloc[0] == "test" else "train" ) From f6a0c3f30973670cb18a2a7723ffc0b1ee2a0f57 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 12:46:28 +0100 Subject: [PATCH 35/49] evaluate CAD1 v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/config.yaml | 2 +- recipes/cad1/task1/baseline/evaluate.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index f6739a856..bfac09086 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ -25,7 +25,7 @@ compressor: release: 1000 rms_buffer_size: 0.064 -soft_clip: False +soft_clip: True separator: model: demucs # demucs or openunmix diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 52c60291f..a57995358 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -219,12 +219,12 @@ def _evaluate_song_listener( def run_calculate_aq(config: DictConfig) -> None: 
"""Evaluate the enhanced signals using the HAAQI-RMS metric.""" # Load test songs - with open(config.path.music_valid_file, encoding="utf-8") as fp: + with open(config.path.music_file, encoding="utf-8") as fp: songs = json.load(fp) - songs = pd.DataFrame.from_dict(songs) + songs_df = pd.DataFrame.from_dict(songs) # Load listener data - listener_dict = Listener.load_listener_dict(config.path.listeners_valid_file) + listener_dict = Listener.load_listener_dict(config.path.listeners_file) enhanced_folder = Path("enhanced_signals") logger.info(f"Evaluating from {enhanced_folder} directory") @@ -235,7 +235,7 @@ def run_calculate_aq(config: DictConfig) -> None: results_file.write_header() song_listener_pair = make_song_listener_list( - songs["Track Name"].tolist(), listener_dict, config.evaluate.small_test + songs_df["Track Name"].tolist(), listener_dict, config.evaluate.small_test ) song_listener_pair = song_listener_pair[ @@ -250,7 +250,7 @@ def run_calculate_aq(config: DictConfig) -> None: ) split_dir = "train" - if songs[songs["Track Name"] == song]["Split"].tolist()[0] == "test": + if songs_df[songs_df["Track Name"] == song]["Split"].tolist()[0] == "test": split_dir = "test" listener = listener_dict[listener_id] combined_score, per_instrument_score = _evaluate_song_listener( From e056f6d905fc9a47165be4eebbb1e32de9e12490 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 12:53:16 +0100 Subject: [PATCH 36/49] update haaqi.py Signed-off-by: Gerardo Roa Dabike --- clarity/evaluator/haaqi/haaqi.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/clarity/evaluator/haaqi/haaqi.py b/clarity/evaluator/haaqi/haaqi.py index a042bfdfa..b3ba3b035 100644 --- a/clarity/evaluator/haaqi/haaqi.py +++ b/clarity/evaluator/haaqi/haaqi.py @@ -49,7 +49,8 @@ def haaqi_v1( processed (np.ndarray): Output signal with noise, distortion, HA gain, and/or processing. processed_freq (int): Sampling rate in Hz for processed signal. - audiogram (Audiogram): Audiogram object with hearing loss levels. + hearing_loss (np.ndarray): (1,6) vector of hearing loss at the 6 audiometric + frequencies [250, 500, 1000, 2000, 4000, 6000] Hz. equalisation (int): hearing loss equalization mode for reference signal: 1 = no EQ has been provided, the function will add NAL-R 2 = NAL-R EQ has already been added to the reference signal @@ -177,8 +178,8 @@ def haaqi_v1( def compute_haaqi( processed_signal: ndarray, reference_signal: ndarray, - sample_rate_processed: float, - sample_rate_reference: float, + processed_sample_rate: float, + reference_sample_rate: float, audiogram: Audiogram, equalisation: int = 1, level1: float = 65.0, @@ -191,9 +192,8 @@ def compute_haaqi( reference_signal (np.ndarray): Input reference speech signal with no noise or distortion. If a hearing loss is specified, NAL-R equalization is optional - sample_rate_processed (int): Sample rate of processed signal - sample_rate_reference (int): Sample rate of reference signal - + processed_sample_rate (float): Sampling rate in Hz for processed signal. + reference_sample_rate (float): Sampling rate in Hz for reference signal. audiogram (Audiogram): Audiogram object. 
equalisation (int): hearing loss equalization mode for reference signal: 1 = no EQ has been provided, the function will add NAL-R @@ -213,9 +213,9 @@ def compute_haaqi( score, _, _, _ = haaqi_v1( reference=reference_signal, - reference_freq=sample_rate_reference, + reference_freq=reference_sample_rate, processed=processed_signal, - processed_freq=sample_rate_processed, + processed_freq=processed_sample_rate, audiogram=audiogram, equalisation=equalisation, level1=level1, From 631b16b11ec64bfed2e858fa47787525c9ca0397 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 12:54:19 +0100 Subject: [PATCH 37/49] update project.toml Signed-off-by: Gerardo Roa Dabike --- pyproject.toml | 201 +++++++++++++++++++------------------------------ 1 file changed, 78 insertions(+), 123 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7fb4a81a..108243201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,104 +1,92 @@ [build-system] -requires = [ - "setuptools >= 45", - "setuptools_scm[toml]>=6.2", - "wheel", - ] +requires = ["setuptools >= 45", "setuptools_scm[toml]>=6.2", "wheel"] build-backend = "setuptools.build_meta" [project] name = "pyclarity" description = "Tools for the Clarity Challenge" readme = "README.md" -license = {text = "MIT"} +license = { text = "MIT" } dynamic = ["version"] authors = [ - {name = "The PyClarity Team", email = "claritychallengecontact@gmail.com"}, + { name = "The PyClarity Team", email = "claritychallengecontact@gmail.com" }, ] classifiers = [ - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Natural Language :: English", -] -keywords = [ - "hearing", - "signal processing", - "clarity challenge" -] + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Natural Language :: English", +] +keywords = ["hearing", "signal processing", "clarity challenge"] requires-python = ">=3.8" dependencies = [ - "audioread>=2.1.9", - "gdown", - "hydra-core>=1.1.1", - "hydra-submitit-launcher>=1.1.6", - "importlib-metadata", - "librosa>=0.8.1", - "matplotlib", - "numba>=0.57.0rc", - "numpy>=1.21.6", - "omegaconf>=2.1.1", - "pandas>=1.3.5", - "pyflac", - "pyloudnorm>=0.1.0", - "pystoi", - "pytorch-lightning", - "resampy", - "scikit-learn>=1.0.2", - "scipy>=1.7.3", - "SoundFile>=0.10.3.post1", - "soxr", - "torch>=2", - "torchaudio", - "tqdm>=4.62.3", - "typing_extensions", + "audioread>=2.1.9", + "gdown", + "hydra-core>=1.1.1", + "hydra-submitit-launcher>=1.1.6", + "importlib-metadata", + "librosa>=0.8.1", + "matplotlib", + "numba>=0.57.0rc", + "numpy>=1.21.6", + "omegaconf>=2.1.1", + "pandas>=1.3.5", + "pyflac", + "pyloudnorm>=0.1.0", + "pystoi", + "pytorch-lightning", + "resampy", + "scikit-learn>=1.0.2", + "scipy>=1.7.3", + "SoundFile>=0.10.3.post1", + "soxr", + "torch>=2", + "torchaudio", + "tqdm>=4.62.3", + "typing_extensions", ] [project.optional-dependencies] tests = [ - "coverage", - "isort", - "flake8", - "flake8-print", - "Flake8-pyproject", - "mypy", - "pre-commit", - "py", - "py-cpuinfo", - "pytest", - "pytest-cov", - "pytest-mock", - "pytest-mpl", - "pytest-regtest", - "pytest-skip-slow", - "pytest-xdist", - "yamllint", -] -docs =[ - "sphinx", - "myst_parser", - 
"pydata_sphinx_theme", - "sphinx_markdown_tables", - "sphinx_rtd_theme", - "sphinxcontrib-mermaid", - "sphinxcontrib-napoleon", + "coverage", + "isort", + "flake8", + "flake8-print", + "Flake8-pyproject", + "mypy", + "pre-commit", + "py", + "py-cpuinfo", + "pytest", + "pytest-cov", + "pytest-mock", + "pytest-mpl", + "pytest-regtest", + "pytest-skip-slow", + "pytest-xdist", + "yamllint", +] +docs = [ + "sphinx", + "myst_parser", + "pydata_sphinx_theme", + "sphinx_markdown_tables", + "sphinx_rtd_theme", + "sphinxcontrib-mermaid", + "sphinxcontrib-napoleon", ] dev = [ - "black", - "pre-commit", - "pycodestyle", - "pylint", - "pylint-pytest", - "yamllint", -] -pypi =[ - "build", - "wheel", - "setuptools_scm[toml]" -] + "black", + "pre-commit", + "pycodestyle", + "pylint", + "pylint-pytest", + "yamllint", +] +pypi = ["build", "wheel", "setuptools_scm[toml]"] [project.urls] Source = "https://github.com/claritychallenge/clarity" @@ -118,17 +106,8 @@ exclude = ["tests*"] namespaces = false [tool.setuptools.package-data] -clarity = [ - "*.json", - "*.mat", - "*.yaml" -] -recipes = [ - "*.csv", - "*.json", - "*.mat", - "*.yaml" -] +clarity = ["*.json", "*.mat", "*.yaml"] +recipes = ["*.csv", "*.json", "*.mat", "*.yaml"] [tool.setuptools_scm] @@ -140,21 +119,12 @@ git_describe_command = "git describe --tags" [tool.pytest.ini_options] minversion = "7.0" addopts = "--cov clarity" -testpaths = [ - "tests", -] -filterwarnings = [ - "ignore::UserWarning" -] +testpaths = ["tests"] +filterwarnings = ["ignore::UserWarning"] [tool.coverage.run] source = ["clarity"] -omit = [ - "*conftest.py", - "*tests*", - "**/__init__*", - "clarity/_version.py", -] +omit = ["*conftest.py", "*tests*", "**/__init__*", "clarity/_version.py"] [tool.black] line-length = 88 @@ -171,23 +141,11 @@ exclude = ''' ''' [flake8] -ignore =[ - "E203", - "E501", - "W503", -] +ignore = ["E203", "E501", "W503"] # docstring-convention = "" max-line-length = 88 max-complexity = 18 -select = [ - "B", - "C", - "E", - "F", - "W", - "T4", - "B9" -] +select = ["B", "C", "E", "F", "W", "T4", "B9"] [tool.ruff] exclude = [ @@ -221,7 +179,4 @@ fixable = ["A", "B", "C", "D", "E", "F", "R", "S", "W", "U"] [tool.mypy] ignore_missing_imports = true -exclude = [ - "docs/*", - "build/*" -] +exclude = ["docs/*", "build/*"] \ No newline at end of file From 793623bb661d1b76367a8c6622ddc8959935a39b Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 12:58:56 +0100 Subject: [PATCH 38/49] tests Signed-off-by: Gerardo Roa Dabike --- tests/evaluator/haaqi/test_haaqi.py | 4 +- .../cad1/task1/baseline/test_evaluate.py | 37 +++++++++---------- tests/utils/test_flac_encoder.py | 24 ------------ 3 files changed, 19 insertions(+), 46 deletions(-) diff --git a/tests/evaluator/haaqi/test_haaqi.py b/tests/evaluator/haaqi/test_haaqi.py index b97fb2626..bf29363e0 100644 --- a/tests/evaluator/haaqi/test_haaqi.py +++ b/tests/evaluator/haaqi/test_haaqi.py @@ -58,8 +58,8 @@ def test_compute_haaqi(levels, freqs, expected_result): score = compute_haaqi( processed_signal=enh_signal, reference_signal=ref_signal, - sample_rate_processed=sample_rate, - sample_rate_reference=sample_rate, + processed_sample_rate=sample_rate, + reference_sample_rate=sample_rate, audiogram=audiogram, ) diff --git a/tests/recipes/cad1/task1/baseline/test_evaluate.py b/tests/recipes/cad1/task1/baseline/test_evaluate.py index 7375b6e2a..544094eb4 100644 --- a/tests/recipes/cad1/task1/baseline/test_evaluate.py +++ b/tests/recipes/cad1/task1/baseline/test_evaluate.py @@ -1,15 +1,13 @@ 
"""Tests for the evaluation module""" -# pylint: disable=import-error - from pathlib import Path +# pylint: disable=import-error import numpy as np import pytest from omegaconf import DictConfig from scipy.io import wavfile from clarity.utils.audiogram import Audiogram, Listener -from clarity.utils.flac_encoder import FlacEncoder from recipes.cad1.task1.baseline.evaluate import ( ResultsFile, _evaluate_song_listener, @@ -90,10 +88,9 @@ def test_make_song_listener_list(): "punk_is_not_dead", "my_music_listener", { - "stem_sample_rate": 44100, - "sample_rate": 44100, "evaluate": {"set_random_seed": True}, "path": {"music_dir": None}, + "sample_rate": 44100, "nalr": {"sample_rate": 44100}, }, "test", @@ -105,14 +102,14 @@ def test_make_song_listener_list(): } }, { - "left_drums": 0.14229280292204488, - "right_drums": 0.15044867874762802, - "left_bass": 0.13337685099485902, - "right_bass": 0.14541734646032817, - "left_other": 0.16310385596493193, - "right_other": 0.1542791489799909, - "left_vocals": 0.12291878218281638, - "right_vocals": 0.13683790592287856, + "left_drums": 0.14229422779265366, + "right_drums": 0.15044965630960655, + "left_bass": 0.1333774836344767, + "right_bass": 0.14541827476097585, + "left_other": 0.16310480582621734, + "right_other": 0.15427835764875864, + "left_vocals": 0.12291980372806624, + "right_vocals": 0.1368378217706031, }, ) ], @@ -146,22 +143,22 @@ def test_evaluate_song_listener( instruments = ["drums", "bass", "other", "vocals"] # Create reference and enhanced wav samples - flac_encoder = FlacEncoder() for lr_instrument in list(expected_results.keys()): # enhanced signals are mono enh_file = ( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_{lr_instrument}.flac" + / f"{listener.id}_{song}_{lr_instrument}.wav" ) enh_file.parent.mkdir(exist_ok=True, parents=True) - with open(Path(enh_file).with_suffix(".txt"), "w", encoding="utf-8") as file: - file.write("1.0") # Using very short 100 ms signals to speed up the test - enh_signal = np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768 - flac_encoder.encode(enh_signal.astype(np.int16), 44100, enh_file) + wavfile.write( + enh_file, + 44100, + np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768, + ) for instrument in instruments: # reference signals are stereo @@ -186,7 +183,7 @@ def test_evaluate_song_listener( # Combined score assert isinstance(combined_score, float) assert combined_score == pytest.approx( - 0.14358442152193474, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + 0.14358505393391977, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance ) # Per instrument score diff --git a/tests/utils/test_flac_encoder.py b/tests/utils/test_flac_encoder.py index 3a3af60f5..85aa1e8f0 100644 --- a/tests/utils/test_flac_encoder.py +++ b/tests/utils/test_flac_encoder.py @@ -74,27 +74,3 @@ def test_read_flac_signal(tmp_path): assert np.sum(signal_out) == pytest.approx(np.sum(signal), rel=1e-4, abs=1e-4) assert sample_rate_out == sample_rate - - -def test_flac_encoder_encode(): - encoder = FlacEncoder(compression_level=5) - # Generate random audio signal - signal = np.random.randint(-32768, 32767, 4410).astype(np.int16) - sample_rate = 44100 - encoded_data = encoder.encode(signal, sample_rate) - assert isinstance(encoded_data, bytes) - - -def test_flac_encoder_encode_type_error(): - encoder = FlacEncoder(compression_level=5) - # Generate random audio signal - signal = np.random.randint(-32768, 32767, 4410).astype(np.float32) - sample_rate = 44100 - with pytest.raises(ValueError): - 
encoder.encode(signal, sample_rate)
-
-
-def test_flac_encoder_decode_file_not_found():
-    encoder = FlacEncoder(compression_level=5)
-    with pytest.raises(FileNotFoundError):
-        encoder.decode("fake/path.flac")

From e33363072b08133f078e0127e2c57fbe1ea5b021 Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Tue, 26 Sep 2023 13:00:31 +0100
Subject: [PATCH 39/49] toml

Signed-off-by: Gerardo Roa Dabike
---
 pyproject.toml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 108243201..1985ad884 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,7 @@
 [build-system]
-requires = ["setuptools >= 45", "setuptools_scm[toml]>=6.2", "wheel"]
+requires = ["setuptools >= 45",
+    "setuptools_scm[toml]>=6.2",
+    "wheel"]
 build-backend = "setuptools.build_meta"
 
 [project]

From 532d4c5a56c8104248f878d717cfdb05d980760a Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Tue, 26 Sep 2023 13:00:44 +0100
Subject: [PATCH 40/49] toml

Signed-off-by: Gerardo Roa Dabike
---
 pyproject.toml | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1985ad884..108243201 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,5 @@
 [build-system]
-requires = ["setuptools >= 45",
-    "setuptools_scm[toml]>=6.2",
-    "wheel"]
+requires = ["setuptools >= 45", "setuptools_scm[toml]>=6.2", "wheel"]
 build-backend = "setuptools.build_meta"
 
 [project]

From cbc73bf9c56b1acbd01fec6be805b776510765f9 Mon Sep 17 00:00:00 2001
From: Gerardo Roa Dabike
Date: Tue, 26 Sep 2023 13:16:14 +0100
Subject: [PATCH 41/49] evaluate task 1

Signed-off-by: Gerardo Roa Dabike
---
 recipes/cad1/task1/baseline/evaluate.py | 93 +++++++++++--------------
 1 file changed, 40 insertions(+), 53 deletions(-)

diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py
index a57995358..7c41dd285 100644
--- a/recipes/cad1/task1/baseline/evaluate.py
+++ b/recipes/cad1/task1/baseline/evaluate.py
@@ -1,6 +1,7 @@
 """Evaluate the enhanced signals using the HAAQI metric."""
 from __future__ import annotations
 
+# pylint: disable=too-many-locals
 # pylint: disable=import-error
 import csv
 import hashlib
@@ -17,11 +18,7 @@
 
 from clarity.evaluator.haaqi import compute_haaqi
 from clarity.utils.audiogram import Listener
-from clarity.utils.flac_encoder import read_flac_signal
-from clarity.utils.signal_processing import compute_rms, resample
+from clarity.utils.signal_processing import compute_rms
 
-# pylint: disable=too-many-locals
-
 logger = logging.getLogger(__name__)
 
@@ -79,6 +76,7 @@ def add_result(
             instruments_scores (dict): A dictionary of scores for each instrument
                 channel in the result.
         """
+        logger.info(f"The combined score is {score}")
 
         with open(self.file_name, "a", encoding="utf-8", newline="") as csv_file:
             csv_writer = csv.writer(
@@ -131,7 +129,7 @@ def _evaluate_song_listener(
 
     Args:
         song (str): The name of the song to evaluate.
-        listener (Listener): The listener to evaluate the song for.
+        listener (Listener): The listener for whom the song is evaluated.
         config (DictConfig): The configuration object.
         split_dir (str): The name of the split directory.
         enhanced_folder (Path): The path to the folder containing the enhanced signals. 
@@ -143,6 +141,8 @@ def _evaluate_song_listener( """ + logger.info(f"Evaluating {song} for {listener.id}") + if config.evaluate.set_random_seed: set_song_seed(song) @@ -153,59 +153,56 @@ def _evaluate_song_listener( "other", "vocals", ]: - logger.info(f" ...evaluating {instrument}") + logger.info(f"...evaluating {instrument}") # Read instrument reference signal sample_rate_reference_signal, reference_signal = wavfile.read( Path(config.path.music_dir) / split_dir / song / f"{instrument}.wav" ) reference_signal = (reference_signal / 32768.0).astype(np.float32) - reference_signal = resample( - reference_signal, sample_rate_reference_signal, config.stem_sample_rate - ) - # Read left instrument enhanced - left_enhanced_signal, sample_rate_left_enhanced_signal = read_flac_signal( + # Load enhanced instrument signals + # Load left channel + sample_rate_left_enhanced_signal, left_enhanced_signal = wavfile.read( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_left_{instrument}.flac" + / f"{listener.id}_{song}_left_{instrument}.wav" ) + left_enhanced_signal = (left_enhanced_signal / 32768.0).astype(np.float32) - # Read right instrument enhanced - right_enhanced_signal, sample_rate_right_enhanced_signal = read_flac_signal( + # Load right channel + sample_rate_right_enhanced_signal, right_enhanced_signal = wavfile.read( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_right_{instrument}.flac" + / f"{listener.id}_{song}_right_{instrument}.wav" ) + right_enhanced_signal = (right_enhanced_signal / 32768.0).astype(np.float32) - if sample_rate_left_enhanced_signal != sample_rate_right_enhanced_signal: - raise ValueError( - "The sample rates of the left and right enhanced signals are not " - "the same" - ) - - if sample_rate_reference_signal != config.sample_rate: - raise ValueError( - f"The sample rate of the reference signal is not {config.sample_rate}" - ) + assert ( + sample_rate_reference_signal + == sample_rate_left_enhanced_signal + == sample_rate_right_enhanced_signal + == config.sample_rate + ) - # Compute left and right scores per_instrument_score[f"left_{instrument}"] = compute_haaqi( - processed_signal=left_enhanced_signal, - reference_signal=reference_signal[:, 0], - sample_rate_processed=int(sample_rate_left_enhanced_signal), - sample_rate_reference=config.stem_sample_rate, - audiogram=listener.audiogram_left, + left_enhanced_signal, + reference_signal[:, 0], + config.sample_rate, + config.sample_rate, + listener.audiogram_left, + equalisation=1, level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 0])), ) per_instrument_score[f"right_{instrument}"] = compute_haaqi( - processed_signal=right_enhanced_signal, - reference_signal=reference_signal[:, 1], - sample_rate_processed=int(sample_rate_right_enhanced_signal), - sample_rate_reference=config.stem_sample_rate, - audiogram=listener.audiogram_right, + right_enhanced_signal, + reference_signal[:, 1], + config.sample_rate, + config.sample_rate, + listener.audiogram_right, + equalisation=1, level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 1])), ) @@ -219,12 +216,12 @@ def _evaluate_song_listener( def run_calculate_aq(config: DictConfig) -> None: """Evaluate the enhanced signals using the HAAQI-RMS metric.""" # Load test songs - with open(config.path.music_file, encoding="utf-8") as fp: + with open(config.path.music_valid_file, encoding="utf-8") as fp: songs = json.load(fp) - songs_df = pd.DataFrame.from_dict(songs) + songs = pd.DataFrame.from_dict(songs) # Load listener data - 
listener_dict = Listener.load_listener_dict(config.path.listeners_file) + listener_dict = Listener.load_listener_dict(config.path.listeners_valid_file) enhanced_folder = Path("enhanced_signals") logger.info(f"Evaluating from {enhanced_folder} directory") @@ -235,22 +232,16 @@ def run_calculate_aq(config: DictConfig) -> None: results_file.write_header() song_listener_pair = make_song_listener_list( - songs_df["Track Name"].tolist(), listener_dict, config.evaluate.small_test + songs["Track Name"].tolist(), listener_dict, config.evaluate.small_test ) song_listener_pair = song_listener_pair[ config.evaluate.batch :: config.evaluate.batch_size ] - num_song_list_pair = len(song_listener_pair) - for idx, song_listener in enumerate(song_listener_pair, 1): - song, listener_id = song_listener - logger.info( - f"[{idx:03d}/{num_song_list_pair:03d}] " - f"Processing {song} for {listener_id}..." - ) + for song, listener_id in song_listener_pair: split_dir = "train" - if songs_df[songs_df["Track Name"] == song]["Split"].tolist()[0] == "test": + if songs[songs["Track Name"] == song]["Split"].tolist()[0] == "test": split_dir = "test" listener = listener_dict[listener_id] combined_score, per_instrument_score = _evaluate_song_listener( @@ -260,10 +251,6 @@ def run_calculate_aq(config: DictConfig) -> None: split_dir, enhanced_folder, ) - logger.info( - f"[{idx:03d}/{num_song_list_pair:03d}] " - f"The combined score is {combined_score}" - ) results_file.add_result( listener.id, song, From 29acbf75d6040f98f80fc9c5d6e8732c0f75e093 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 13:17:07 +0100 Subject: [PATCH 42/49] evaluate task 2 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/evaluate.py | 48 ++++++++++--------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index 72be98fe2..a434b3f30 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -11,12 +11,11 @@ import hydra import numpy as np from omegaconf import DictConfig +from scipy.io import wavfile from tqdm import tqdm from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener -from clarity.utils.flac_encoder import read_flac_signal -from clarity.utils.signal_processing import compute_rms, resample from recipes.cad1.task2.baseline.audio_manager import AudioManager from recipes.cad1.task2.baseline.baseline_utils import ( load_hrtf, @@ -195,21 +194,18 @@ def evaluate_scene( # Compute HAAQI scores aq_score_l = compute_haaqi( - processed_signal=processed_signal[0, :], - reference_signal=ref_signal[0, :], - sample_rate_processed=sample_rate, - sample_rate_reference=sample_rate, - audiogram=listener.audiogram_left, - level1=65 - 20 * np.log10(compute_rms(ref_signal[0, :])), + processed_signal[0, :], + ref_signal[0, :], + sample_rate, + sample_rate, + listener.audiogram_left, ) - aq_score_r = compute_haaqi( - processed_signal=processed_signal[1, :], - reference_signal=ref_signal[1, :], - sample_rate_processed=sample_rate, - sample_rate_reference=sample_rate, - audiogram=listener.audiogram_right, - level1=65 - 20 * np.log10(compute_rms(ref_signal[1, :])), + processed_signal[1, :], + ref_signal[1, :], + sample_rate, + sample_rate, + listener.audiogram_right, ) return aq_score_l, aq_score_r @@ -238,10 +234,9 @@ def run_calculate_audio_quality(config: DictConfig) -> None: results_file.write_header() # Initialize acoustic scene model - 
sample_rate_haaqi = 24000 car_scene_acoustic = CarSceneAcoustics( track_duration=30, - sample_rate=sample_rate_haaqi, + sample_rate=config.sample_rate, hrtf_dir=config.path.hrtf_dir, config_nalr=config.nalr, config_compressor=config.compressor, @@ -271,24 +266,19 @@ def run_calculate_audio_quality(config: DictConfig) -> None: enhanced_folder = Path("enhanced_signals") / config.evaluate.split enhanced_song_id = f"{scene_id}_{listener.id}_{current_scene['song']}" enhanced_song_path = ( - enhanced_folder / f"{listener.id}" / f"{enhanced_song_id}.flac" + enhanced_folder / f"{listener.id}" / f"{enhanced_song_id}.wav" ) # Read WAV enhanced signal using scipy.io.wavfile - enhanced_signal, enhanced_sample_rate = read_flac_signal(enhanced_song_path) - assert enhanced_sample_rate == config.enhanced_sample_rate + enhanced_sample_rate, enhanced_signal = wavfile.read(enhanced_song_path) + enhanced_signal = enhanced_signal / 32768.0 + assert enhanced_sample_rate == config.sample_rate # Evaluate scene - reference_signal_24k = resample( - reference_signal.T, config.sample_rate, sample_rate_haaqi - ) - enhanced_signal_24k = resample( - enhanced_signal, enhanced_sample_rate, sample_rate_haaqi - ) aq_score_l, aq_score_r = evaluate_scene( - reference_signal_24k.T, - enhanced_signal_24k.T, - sample_rate_haaqi, + reference_signal, + enhanced_signal.T, + config.sample_rate, scene_id, current_scene, listener, From ab855f2ba3ba8a587e4df1bd843929990aadae7e Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 13:19:11 +0100 Subject: [PATCH 43/49] toml Signed-off-by: Gerardo Roa Dabike --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 108243201..1f9fa9d24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -179,4 +179,4 @@ fixable = ["A", "B", "C", "D", "E", "F", "R", "S", "W", "U"] [tool.mypy] ignore_missing_imports = true -exclude = ["docs/*", "build/*"] \ No newline at end of file +exclude = ["docs/*", "build/*"] From e21c139b9b6fc70ae428e062f543eee9a715aef1 Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 13:38:22 +0100 Subject: [PATCH 44/49] evaluate cad1 v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/evaluate.py | 91 ++++++++++++++----------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 7c41dd285..4798c2f32 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -18,7 +18,8 @@ from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener -from clarity.utils.signal_processing import compute_rms +from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.signal_processing import compute_rms, resample logger = logging.getLogger(__name__) @@ -161,47 +162,61 @@ def _evaluate_song_listener( ) reference_signal = (reference_signal / 32768.0).astype(np.float32) - # Load enhanced instrument signals - # Load left channel - sample_rate_left_enhanced_signal, left_enhanced_signal = wavfile.read( + # Read left instrument enhanced + left_enhanced_signal, sample_rate_left_enhanced_signal = read_flac_signal( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_left_{instrument}.wav" + / f"{listener.id}_{song}_left_{instrument}.flac" ) - left_enhanced_signal = (left_enhanced_signal / 32768.0).astype(np.float32) - # Load right channel - sample_rate_right_enhanced_signal, 
right_enhanced_signal = wavfile.read(
+        # Read right instrument enhanced
+        right_enhanced_signal, sample_rate_right_enhanced_signal = read_flac_signal(
             enhanced_folder
             / f"{listener.id}"
             / f"{song}"
-            / f"{listener.id}_{song}_right_{instrument}.wav"
+            / f"{listener.id}_{song}_right_{instrument}.flac"
         )
-        right_enhanced_signal = (right_enhanced_signal / 32768.0).astype(np.float32)
 
-        assert (
-            sample_rate_reference_signal
-            == sample_rate_left_enhanced_signal
-            == sample_rate_right_enhanced_signal
-            == config.sample_rate
-        )
+        if sample_rate_left_enhanced_signal != sample_rate_right_enhanced_signal:
+            raise ValueError(
+                "The sample rates of the left and right enhanced signals are not "
+                "the same"
+            )
+
+        if sample_rate_left_enhanced_signal != config.stem_sample_rate:
+            raise ValueError(
+                "The sample rate of the enhanced signals is not "
+                f"{config.stem_sample_rate}"
+            )
+
+        if sample_rate_reference_signal != config.sample_rate:
+            raise ValueError(
+                f"The sample rate of the reference signal is not {config.sample_rate}"
+            )
 
         per_instrument_score[f"left_{instrument}"] = compute_haaqi(
-            left_enhanced_signal,
-            reference_signal[:, 0],
-            config.sample_rate,
-            config.sample_rate,
-            listener.audiogram_left,
+            processed_signal=left_enhanced_signal,
+            reference_signal=resample(
+                reference_signal[:, 0],
+                sample_rate_reference_signal,
+                config.stem_sample_rate,
+            ),
+            processed_sample_rate=config.stem_sample_rate,
+            reference_sample_rate=config.stem_sample_rate,
+            audiogram=listener.audiogram_left,
             equalisation=1,
             level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 0])),
         )
+
         per_instrument_score[f"right_{instrument}"] = compute_haaqi(
-            right_enhanced_signal,
-            reference_signal[:, 1],
-            config.sample_rate,
-            config.sample_rate,
-            listener.audiogram_right,
+            processed_signal=right_enhanced_signal,
+            reference_signal=resample(
+                reference_signal[:, 1],
+                sample_rate_reference_signal,
+                config.stem_sample_rate,
+            ),
+            processed_sample_rate=config.stem_sample_rate,
+            reference_sample_rate=config.stem_sample_rate,
+            audiogram=listener.audiogram_right,
             equalisation=1,
             level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 1])),
         )
 
@@ -216,12 +231,12 @@ def run_calculate_aq(config: DictConfig) -> None:
     """Evaluate the enhanced signals using the HAAQI-RMS metric."""
     # Load test songs
-    with open(config.path.music_valid_file, encoding="utf-8") as fp:
+    with open(config.path.music_file, encoding="utf-8") as fp:
         songs = json.load(fp)
-    songs = pd.DataFrame.from_dict(songs)
+    songs_df = pd.DataFrame.from_dict(songs)
 
     # Load listener data
-    listener_dict = Listener.load_listener_dict(config.path.listeners_valid_file)
+    listener_dict = Listener.load_listener_dict(config.path.listeners_file)
 
     enhanced_folder = Path("enhanced_signals")
     logger.info(f"Evaluating from {enhanced_folder} directory")
@@ -241,7 +256,7 @@ def run_calculate_aq(config: DictConfig) -> None:
     results_file.write_header()
 
     song_listener_pair = make_song_listener_list(
-        songs["Track Name"].tolist(), listener_dict, config.evaluate.small_test
+        songs_df["Track Name"].tolist(), listener_dict, config.evaluate.small_test
    )
 
     song_listener_pair = song_listener_pair[
         config.evaluate.batch :: config.evaluate.batch_size
     ]
 
     for song, listener_id in song_listener_pair:
         split_dir = "train"
-        if songs[songs["Track Name"] == song]["Split"].tolist()[0] == "test":
+        if songs_df[songs_df["Track Name"] == song]["Split"].tolist()[0] == "test":
             split_dir = "test"
         listener = listener_dict[listener_id]
         combined_score, per_instrument_score = _evaluate_song_listener(
-            song,
-            listener,
-            config,
-            split_dir,
-            enhanced_folder,
+            song=song,
+            
listener=listener, + config=config, + split_dir=split_dir, + enhanced_folder=enhanced_folder, ) results_file.add_result( - listener.id, - song, + listener_id=listener.id, + song=song, score=combined_score, instruments_scores=per_instrument_score, ) From b3f171f6d3b5636728a2527e6ecf61b9fed50e8f Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 13:55:05 +0100 Subject: [PATCH 45/49] evaluate cad1 v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/evaluate.py | 112 ++++++------------------ 1 file changed, 29 insertions(+), 83 deletions(-) diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 4798c2f32..2b439a950 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-locals # pylint: disable=import-error -import csv import hashlib import itertools import json @@ -19,87 +18,12 @@ from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.results_support import ResultsFile from clarity.utils.signal_processing import compute_rms, resample logger = logging.getLogger(__name__) -class ResultsFile: - """A utility class for writing results to a CSV file. - - Attributes: - file_name (str): The name of the file to write results to. - """ - - def __init__(self, file_name: str): - """Initialize the ResultsFile instance. - - Args: - file_name (str): The name of the file to write results to. - """ - self.file_name = file_name - - def write_header(self): - """Write the header row to the CSV file.""" - with open(self.file_name, "w", encoding="utf-8", newline="") as csv_file: - csv_writer = csv.writer( - csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL - ) - csv_writer.writerow( - [ - "song", - "listener", - "score", - "left_bass", - "right_bass", - "left_drums", - "right_drums", - "left_other", - "right_other", - "left_vocals", - "right_vocals", - ] - ) - - def add_result( - self, - listener_id: str, - song: str, - score: float, - instruments_scores: dict[str, float], - ): - """Add a result to the CSV file. - - Args: - listener_id (str): The name of the listener who submitted the result. - song (str): The name of the song that the result is for. - score (float): The combined score for the result. - instruments_scores (dict): A dictionary of scores for each instrument - channel in the result. 
- """ - logger.info(f"The combined score is {score}") - - with open(self.file_name, "a", encoding="utf-8", newline="") as csv_file: - csv_writer = csv.writer( - csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL - ) - csv_writer.writerow( - [ - song, - listener_id, - str(score), - str(instruments_scores["left_bass"]), - str(instruments_scores["right_bass"]), - str(instruments_scores["left_drums"]), - str(instruments_scores["right_drums"]), - str(instruments_scores["left_other"]), - str(instruments_scores["right_other"]), - str(instruments_scores["left_vocals"]), - str(instruments_scores["right_vocals"]), - ] - ) - - def set_song_seed(song: str) -> None: """Set a seed that is unique for the given song""" song_encoded = hashlib.md5(song.encode("utf-8")).hexdigest() @@ -241,10 +165,30 @@ def run_calculate_aq(config: DictConfig) -> None: enhanced_folder = Path("enhanced_signals") logger.info(f"Evaluating from {enhanced_folder} directory") + scores_headers = [ + "song", + "listener", + "score", + "left_bass", + "right_bass", + "left_drums", + "right_drums", + "left_other", + "right_other", + "left_vocals", + "right_vocals", + ] + + results_file_name = "scores.csv" + if config.evaluate.batch_size > 1: + results_file_name = ( + f"scores_{int(config.evaluate.batch) + 1}-{config.evaluate.batch_size}.csv" + ) + results_file = ResultsFile( - f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv" + file_name=results_file_name, + header_columns=scores_headers, ) - results_file.write_header() song_listener_pair = make_song_listener_list( songs_df["Track Name"].tolist(), listener_dict, config.evaluate.small_test @@ -267,10 +211,12 @@ def run_calculate_aq(config: DictConfig) -> None: enhanced_folder=enhanced_folder, ) results_file.add_result( - listener_id=listener.id, - song=song, - score=combined_score, - instruments_scores=per_instrument_score, + { + "song": song, + "listener": listener.id, + "score": combined_score, + **per_instrument_score, + } ) From 392b1b54c59789ae4010cc1a37190b59fca9935f Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 13:55:27 +0100 Subject: [PATCH 46/49] evaluate cad1 v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 2b439a950..1543a9b96 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -182,7 +182,7 @@ def run_calculate_aq(config: DictConfig) -> None: results_file_name = "scores.csv" if config.evaluate.batch_size > 1: results_file_name = ( - f"scores_{int(config.evaluate.batch) + 1}-{config.evaluate.batch_size}.csv" + f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv" ) results_file = ResultsFile( From 0b05c531b89d7de569aa58534534660d23860d4b Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 17:22:31 +0100 Subject: [PATCH 47/49] test evaluate cad1 v0.4 Signed-off-by: Gerardo Roa Dabike --- .../cad1/task1/baseline/test_evaluate.py | 65 ++++++------------- 1 file changed, 19 insertions(+), 46 deletions(-) diff --git a/tests/recipes/cad1/task1/baseline/test_evaluate.py b/tests/recipes/cad1/task1/baseline/test_evaluate.py index 544094eb4..8a3977d69 100644 --- a/tests/recipes/cad1/task1/baseline/test_evaluate.py +++ b/tests/recipes/cad1/task1/baseline/test_evaluate.py @@ -8,42 +8,14 @@ from scipy.io import wavfile from clarity.utils.audiogram import 
Audiogram, Listener +from clarity.utils.flac_encoder import FlacEncoder from recipes.cad1.task1.baseline.evaluate import ( - ResultsFile, _evaluate_song_listener, make_song_listener_list, set_song_seed, ) -def test_results_file(tmp_path): - """Test the class ResultsFile""" - results_file = tmp_path / "results.csv" - result_file = ResultsFile(results_file.as_posix()) - result_file.write_header() - result_file.add_result( - listener_id="My listener", - song="My favorite song", - score=0.9, - instruments_scores={ - "left_bass": 0.8, - "right_bass": 0.8, - "left_drums": 0.9, - "right_drums": 0.9, - "left_other": 0.8, - "right_other": 0.8, - "left_vocals": 0.95, - "right_vocals": 0.95, - }, - ) - with open(results_file, encoding="utf-8") as file: - contents = file.read() - assert ( - "My favorite song,My listener,0.9,0.8,0.8,0.9,0.9,0.8,0.8,0.95,0.95" - in contents - ) - - @pytest.mark.parametrize( "song,expected_result", [("my favorite song", 83), ("another song", 3)], @@ -88,10 +60,11 @@ def test_make_song_listener_list(): "punk_is_not_dead", "my_music_listener", { + "stem_sample_rate": 44100, + "sample_rate": 44100, "evaluate": {"set_random_seed": True}, "path": {"music_dir": None}, - "sample_rate": 44100, - "nalr": {"sample_rate": 44100}, + "nalr": {"fs": 44100}, }, "test", { @@ -102,14 +75,14 @@ def test_make_song_listener_list(): } }, { - "left_drums": 0.14229422779265366, - "right_drums": 0.15044965630960655, - "left_bass": 0.1333774836344767, - "right_bass": 0.14541827476097585, - "left_other": 0.16310480582621734, - "right_other": 0.15427835764875864, - "left_vocals": 0.12291980372806624, - "right_vocals": 0.1368378217706031, + "left_drums": 0.14229280292204488, + "right_drums": 0.15044867874762802, + "left_bass": 0.13337685099485902, + "right_bass": 0.14541734646032817, + "left_other": 0.16310385596493193, + "right_other": 0.1542791489799909, + "left_vocals": 0.12291878218281638, + "right_vocals": 0.13683790592287856, }, ) ], @@ -143,22 +116,22 @@ def test_evaluate_song_listener( instruments = ["drums", "bass", "other", "vocals"] # Create reference and enhanced wav samples + flac_encoder = FlacEncoder() for lr_instrument in list(expected_results.keys()): # enhanced signals are mono enh_file = ( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_{lr_instrument}.wav" + / f"{listener.id}_{song}_{lr_instrument}.flac" ) enh_file.parent.mkdir(exist_ok=True, parents=True) + with open(Path(enh_file).with_suffix(".txt"), "w", encoding="utf-8") as file: + file.write("1.0") # Using very short 100 ms signals to speed up the test - wavfile.write( - enh_file, - 44100, - np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768, - ) + enh_signal = np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768 + flac_encoder.encode(enh_signal.astype(np.int16), config.sample_rate, enh_file) for instrument in instruments: # reference signals are stereo @@ -183,7 +156,7 @@ def test_evaluate_song_listener( # Combined score assert isinstance(combined_score, float) assert combined_score == pytest.approx( - 0.14358505393391977, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance + 0.14358442152193474, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance ) # Per instrument score From 26de49b5f12a9854a35065959c4f150fa96445ea Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Tue, 26 Sep 2023 17:39:15 +0100 Subject: [PATCH 48/49] cad1 task2 v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task2/baseline/baseline_utils.py | 12 +- recipes/cad1/task2/baseline/config.yaml | 4 +- 
recipes/cad1/task2/baseline/evaluate.py | 128 +++++------------- .../task2/baseline/test_baseline_utils.py | 3 +- 4 files changed, 41 insertions(+), 106 deletions(-) diff --git a/recipes/cad1/task2/baseline/baseline_utils.py b/recipes/cad1/task2/baseline/baseline_utils.py index 6b65406e7..5a458fc31 100644 --- a/recipes/cad1/task2/baseline/baseline_utils.py +++ b/recipes/cad1/task2/baseline/baseline_utils.py @@ -93,15 +93,9 @@ def load_listeners_and_scenes( df_scenes = pd.read_json(fp, orient="index") # Load audiograms and scene data for the corresponding split - if config.evaluate.split == "train": - listeners = Listener.load_listener_dict(config.path.listeners_train_file) - scenes = df_scenes[df_scenes["split"] == "train"].to_dict("index") - elif config.evaluate.split == "valid": - listeners = Listener.load_listener_dict(config.path.listeners_valid_file) - scenes = df_scenes[df_scenes["split"] == "valid"].to_dict("index") - elif config.evaluate.split == "test": - listeners = Listener.load_listener_dict(config.path.listeners_test_file) - scenes = df_scenes[df_scenes["split"] == "test"].to_dict("index") + listeners = Listener.load_listener_dict(config.path.listeners_file) + if config.evaluate.split in ["train", "valid", "test"]: + scenes = df_scenes[df_scenes["split"] == config.evaluate.split].to_dict("index") else: raise ValueError(f"Unknown split {config.evaluate.split}") diff --git a/recipes/cad1/task2/baseline/config.yaml b/recipes/cad1/task2/baseline/config.yaml index b6722310a..ff84fe766 100644 --- a/recipes/cad1/task2/baseline/config.yaml +++ b/recipes/cad1/task2/baseline/config.yaml @@ -4,9 +4,7 @@ path: metadata_dir: ${path.root}/metadata music_dir: ${path.audio_dir}/music hrtf_dir: ${path.audio_dir}/eBrird - listeners_train_file: ${path.metadata_dir}/listeners.train.json - listeners_valid_file: ${path.metadata_dir}/listeners.valid.json - listeners_test_file: ${path.metadata_dir}/listeners.test.json + listeners_file: ${path.metadata_dir}/listeners.valid.json scenes_file: ${path.metadata_dir}/scenes.json scenes_listeners_file: ${path.metadata_dir}/scenes_listeners.json hrtf_file: ${path.metadata_dir}/eBrird_BRIR.json diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index a434b3f30..b3a202b6c 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -3,7 +3,6 @@ # pylint: disable=import-error from __future__ import annotations -import csv import hashlib import logging from pathlib import Path @@ -11,11 +10,12 @@ import hydra import numpy as np from omegaconf import DictConfig -from scipy.io import wavfile from tqdm import tqdm from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener +from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.results_support import ResultsFile from recipes.cad1.task2.baseline.audio_manager import AudioManager from recipes.cad1.task2.baseline.baseline_utils import ( load_hrtf, @@ -28,81 +28,6 @@ logger = logging.getLogger(__name__) -class ResultsFile: - """A utility class for writing results to a CSV file. - - Attributes: - file_name (str): The name of the file to write results to. - """ - - def __init__(self, file_name): - """Initialize the ResultsFile instance. - - Args: - file_name (str): The name of the file to write results to. 
-        """
-        self.file_name = file_name
-
-    def write_header(self):
-        """Write the header row to the CSV file."""
-        with open(self.file_name, "w", encoding="utf-8") as csv_f:
-            csv_writer = csv.writer(
-                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
-            )
-            csv_writer.writerow(
-                [
-                    "scene",
-                    "song",
-                    "genre",
-                    "listener",
-                    "score",
-                    "haaqi_left",
-                    "haaqi_right",
-                ]
-            )
-
-    # pylint: disable=too-many-arguments
-    def add_result(
-        self,
-        scene: str,
-        song: str,
-        genre: str,
-        listener: str,
-        score: float,
-        haaqi_left: float,
-        haaqi_right: float,
-    ):
-        """Add a result to the CSV file.
-
-        Args:
-            scene (str): The name of the scene that the result is for.
-            song (str): The name of the song that the result is for.
-            genre (str): The genre of the song that the result is for.
-            listener (str): The name of the listener who submitted the result.
-            score (float): The combined score for the result.
-            haaqi_left (float): The HAAQI score for the left channel.
-            haaqi_right (float): The HAAQI score for the right channel.
-        """
-
-        logger.info(f"The combined score for scene {scene}: {score:.4f}")
-
-        with open(self.file_name, "a", encoding="utf-8") as csv_f:
-            csv_writer = csv.writer(
-                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
-            )
-            csv_writer.writerow(
-                [
-                    scene,
-                    song,
-                    genre,
-                    listener,
-                    str(score),
-                    str(haaqi_left),
-                    str(haaqi_right),
-                ]
-            )
-
-
 def set_scene_seed(scene: str):
     """Set a seed that is unique for the given song
     based on the last 8 characters of the 'md5'
@@ -228,10 +153,26 @@ def run_calculate_audio_quality(config: DictConfig) -> None:
     enhanced_folder = Path("enhanced_signals")
     logger.info(f"Evaluating from {enhanced_folder} directory")
 
+    scores_headers = [
+        "scene",
+        "song",
+        "genre",
+        "listener",
+        "score",
+        "haaqi_left",
+        "haaqi_right",
+    ]
+
+    results_file_name = "scores.csv"
+    if config.evaluate.batch_size > 1:
+        results_file_name = (
+            f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv"
+        )
+
     results_file = ResultsFile(
-        f"scores_{config.evaluate.batch}-{config.evaluate.batch_size}.csv"
+        file_name=results_file_name,
+        header_columns=scores_headers,
     )
-    results_file.write_header()
 
     # Initialize acoustic scene model
     car_scene_acoustic = CarSceneAcoustics(
@@ -266,15 +205,14 @@ def run_calculate_audio_quality(config: DictConfig) -> None:
 
         # Load enhanced signal
         enhanced_folder = Path("enhanced_signals") / config.evaluate.split
-        enhanced_song_id = f"{scene_id}_{listener.id}_{current_scene['song']}"
-        enhanced_song_path = (
-            enhanced_folder / f"{listener.id}" / f"{enhanced_song_id}.wav"
-        )
 
+        # Read the enhanced signal from its FLAC file
+        enhanced_signal, enhanced_sample_rate = read_flac_signal(
+            enhanced_folder
+            / f"{listener.id}"
+            / f"{scene_id}_{listener.id}_{current_scene['song']}.flac"
+        )
 
-        # Read WAV enhanced signal using scipy.io.wavfile
-        enhanced_sample_rate, enhanced_signal = wavfile.read(enhanced_song_path)
-        enhanced_signal = enhanced_signal / 32768.0
-        assert enhanced_sample_rate == config.sample_rate
+        assert enhanced_sample_rate == config.enhanced_sample_rate
 
         # Evaluate scene
         aq_score_l, aq_score_r = evaluate_scene(
             reference_signal,
             enhanced_signal.T,
             config.sample_rate,
             scene_id,
             current_scene,
             listener,
@@ -290,13 +228,15 @@ def run_calculate_audio_quality(config: DictConfig) -> None:
 
         # Compute combined score and save
        score = np.mean([aq_score_r, aq_score_l])
        results_file.add_result(
-            scene_id,
-            current_scene["song"],
-            current_scene["song_path"].split("/")[-2],
-            listener.id,
-            score=float(score),
-            haaqi_left=aq_score_l,
-            haaqi_right=aq_score_r,
+            {
+                "scene": 
scene_id, + "song": current_scene["song"], + "genre": current_scene["song_path"].split("/")[-2], + "listener": listener.id, + "score": float(score), + "haaqi_left": aq_score_l, + "haaqi_right": aq_score_r, + } ) diff --git a/tests/recipes/cad1/task2/baseline/test_baseline_utils.py b/tests/recipes/cad1/task2/baseline/test_baseline_utils.py index 2ae029d68..7f0d527b6 100644 --- a/tests/recipes/cad1/task2/baseline/test_baseline_utils.py +++ b/tests/recipes/cad1/task2/baseline/test_baseline_utils.py @@ -1,6 +1,7 @@ """Test for baseline_utils.py""" from pathlib import Path +# pylint: disable=import-error import librosa import numpy as np import pytest @@ -35,7 +36,7 @@ def test_load_listeners_and_scenes(): { "path": { "scenes_file": (RESOURCES / "scenes.json").as_posix(), - "listeners_train_file": (RESOURCES / "listeners.json").as_posix(), + "listeners_file": (RESOURCES / "listeners.json").as_posix(), "scenes_listeners_file": ( RESOURCES / "scenes_listeners.json" ).as_posix(), From 8ec9c7b05fad428a47684d01d8e82fff73b743de Mon Sep 17 00:00:00 2001 From: Gerardo Roa Dabike Date: Wed, 27 Sep 2023 08:52:14 +0100 Subject: [PATCH 49/49] cad1 test v0.4 Signed-off-by: Gerardo Roa Dabike --- recipes/cad1/task1/baseline/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/cad1/task1/baseline/test.py b/recipes/cad1/task1/baseline/test.py index 0fd4fd98d..47592ad17 100644 --- a/recipes/cad1/task1/baseline/test.py +++ b/recipes/cad1/task1/baseline/test.py @@ -94,9 +94,9 @@ def enhance(config: DictConfig) -> None: # Processing Validation Set # Load listener audiograms and songs - listener_dict = Listener.load_listener_dict(config.path.listeners_test_file) + listener_dict = Listener.load_listener_dict(config.path.listeners_file) - with open(config.path.music_test_file, encoding="utf-8") as file: + with open(config.path.music_file, encoding="utf-8") as file: song_data = json.load(file) songs_details = pd.DataFrame.from_dict(song_data)
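
Note on the results-file changes in patches 45 and 48: both recipes drop their local CSV-writing ResultsFile classes in favour of the shared clarity.utils.results_support.ResultsFile, which takes the column names at construction (no separate write_header() call appears after the change) and receives each row as a dict keyed by those columns. A minimal sketch of the interface as these diffs exercise it; the file name and row values below are illustrative only:

    from clarity.utils.results_support import ResultsFile

    # Column names mirror the scores_headers lists added in the evaluators.
    scores_headers = ["song", "listener", "score"]

    # Constructing the file replaces the old explicit write_header() call.
    results_file = ResultsFile(
        file_name="scores.csv",
        header_columns=scores_headers,
    )

    # Rows are dicts keyed by the header columns, replacing the old
    # positional add_result(listener_id, song, score, ...) signatures.
    results_file.add_result(
        {"song": "punk_is_not_dead", "listener": "L001", "score": 0.42}
    )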
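Both evaluators shard their work with an extended slice, song_listener_pair[config.evaluate.batch :: config.evaluate.batch_size], so that each of batch_size parallel jobs takes every batch_size-th pair and writes its own scores_<batch+1>-<batch_size>.csv. A small sketch of the slicing, where the pair list is a hypothetical stand-in for make_song_listener_list() output:

    # Hypothetical song/listener pairs.
    pairs = [("song_a", "L01"), ("song_b", "L02"), ("song_c", "L03"), ("song_d", "L04")]
    batch_size = 2  # number of parallel evaluation jobs

    for batch in range(batch_size):
        shard = pairs[batch::batch_size]  # every batch_size-th pair, offset by batch
        print(f"scores_{batch + 1}-{batch_size}.csv", shard)

    # scores_1-2.csv [('song_a', 'L01'), ('song_c', 'L03')]
    # scores_2-2.csv [('song_b', 'L02'), ('song_d', 'L04')]

The shards are disjoint and together cover every pair, which is why the batch index appears in the per-shard CSV file name.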
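The level1 value the task 1 evaluator passes to compute_haaqi, 65 - 20 * np.log10(compute_rms(reference)), rescales the reference stem's RMS to the 65 dB presentation level that HAAQI assumes: a stem with RMS 1.0 maps to 65, and a stem ten times quieter maps to 85. A sketch of the arithmetic using the same compute_rms helper imported in the diffs; the random reference is illustrative only:

    import numpy as np

    from clarity.utils.signal_processing import compute_rms

    # Stand-in for one channel of a reference stem scaled to [-1, 1].
    reference = np.random.uniform(-0.1, 0.1, 44100).astype(np.float32)

    # A quieter stem yields a larger level1: an RMS of 0.1 would give
    # 65 - 20 * log10(0.1) = 85 dB.
    level1 = 65 - 20 * np.log10(compute_rms(reference))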