diff --git a/clarity/evaluator/haspi/eb.py b/clarity/evaluator/haspi/eb.py index 32692db92..415afe2cc 100644 --- a/clarity/evaluator/haspi/eb.py +++ b/clarity/evaluator/haspi/eb.py @@ -1288,13 +1288,10 @@ def env_smooth(envelopes: np.ndarray, segment_size: int, sample_rate: float) -> """ # Compute the window - n_samples = int( - np.around(segment_size * (0.001 * sample_rate)) - ) # Segment size in samples - test = n_samples - 2 * np.floor(n_samples / 2) # 0=even, 1=odd - if test > 0: - # Force window length to be even - n_samples = n_samples + 1 + # Segment size in samples + n_samples = int(np.around(segment_size * (0.001 * sample_rate))) + n_samples += n_samples % 2 + window = np.hanning(n_samples) # Raised cosine von Hann window wsum = np.sum(window) # Sum for normalization @@ -1848,6 +1845,7 @@ def bm_covary( correlation = correlation[ int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) ] + unbiased_cross_correlation = np.max(np.abs(correlation * half_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalize cross-covariance @@ -1877,6 +1875,7 @@ def bm_covary( correlation = correlation[ int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) ] + unbiased_cross_correlation = np.max(np.abs(correlation * win_corr)) if (ref_mean_square > small) and (proc_mean_squared > small): # Normalize cross-covariance @@ -1900,7 +1899,7 @@ def bm_covary( ref_mean_square = np.sum(reference_seg**2) * halfsum2 proc_mean_squared = np.sum(processed_seg**2) * halfsum2 - correlation = np.correlate(reference_seg, processed_seg, "full") + correlation = correlate(reference_seg, processed_seg, "full") correlation = correlation[ int(len(reference_seg) - 1 - maxlag) : int(maxlag + len(reference_seg)) ] diff --git a/clarity/utils/signal_processing.py b/clarity/utils/signal_processing.py index fb7fd391f..0589ce267 100644 --- a/clarity/utils/signal_processing.py +++ b/clarity/utils/signal_processing.py @@ -1,4 +1,5 @@ """Signal processing utilities.""" +# pylint: disable=import-error from __future__ import annotations # pylint: disable=import-error diff --git a/recipes/cad1/README.md b/recipes/cad1/README.md index dca206dca..f40609ba6 100644 --- a/recipes/cad1/README.md +++ b/recipes/cad1/README.md @@ -21,37 +21,42 @@ The performance of each system on the validation set is reported below. 
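The `env_smooth` change at the top of this diff replaces the explicit even/odd test with `n_samples += n_samples % 2`. As a quick sanity check (illustration only, not part of the patch), the snippet below confirms that the new one-liner yields the same even window length as the old `floor`-based test; the segment sizes and sample rates are arbitrary examples.

```python
import numpy as np


def even_window_length(segment_size_ms: float, sample_rate: float) -> int:
    """New formulation: round to samples, then bump odd lengths up by one."""
    n_samples = int(np.around(segment_size_ms * (0.001 * sample_rate)))
    n_samples += n_samples % 2  # adds 0 if already even, 1 if odd
    return n_samples


def even_window_length_old(segment_size_ms: float, sample_rate: float) -> int:
    """Old formulation: explicit even/odd test via floor division."""
    n_samples = int(np.around(segment_size_ms * (0.001 * sample_rate)))
    test = n_samples - 2 * np.floor(n_samples / 2)  # 0=even, 1=odd
    if test > 0:
        n_samples = n_samples + 1  # force window length to be even
    return n_samples


# Both versions agree and always return an even number of samples
for seg_ms, fs in [(8, 24000), (16, 22050), (12.5, 44100)]:
    new, old = even_window_length(seg_ms, fs), even_window_length_old(seg_ms, fs)
    assert new == old and new % 2 == 0
```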
### Task 1 - Listening music via headphones -**The overall HAAQI score is 0.3608.** +The overall HAAQI score is: + +- Demucs: **0.2592** +- Open-Unmix: **0.2273** #### Average HAAQI score per song -| Song | HAAQI | -|:------------------------------------------------|:----------:| -| Actions - One Minute Smile | 0.3066 | -| Alexander Ross - Goodbye Bolero | 0.4257 | -| ANiMAL - Rockshow | 0.2389 | -| Clara Berry And Wooldog - Waltz For My Victims | 0.4202 | -| Fergessen - Nos Palpitants | 0.4554 | -| James May - On The Line | 0.3889 | -| Johnny Lokke - Promises & Lies | 0.3395 | -| Leaf - Summerghost | 0.3595 | -| Meaxic - Take A Step | 0.3470 | -| Patrick Talbot - A Reason To Leave | 0.4545 | -| Skelpolu - Human Mistakes | 0.3055 | -| Triviul - Angelsaint | 0.2883 | +| Song | Demucs | Open-UnMix | +|:-----------------------------------------------|:------:|:----------:| +| Actions - One Minute Smile | 0.2485 | 0.2257 | +| Alexander Ross - Goodbye Bolero | 0.3084 | 0.2574 | +| ANiMAL - Rockshow | 0.1843 | 0.1864 | +| Clara Berry And Wooldog - Waltz For My Victims | 0.3094 | 0.2615 | +| Fergessen - Nos Palpitants | 0.3542 | 0.2592 | +| James May - On The Line | 0.2778 | 0.2398 | +| Johnny Lokke - Promises & Lies | 0.2544 | 0.2261 | +| Leaf - Summerghost | 0.2513 | 0.2105 | +| Meaxic - Take A Step | 0.2455 | 0.2239 | +| Patrick Talbot - A Reason To Leave | 0.2673 | 0.2331 | +| Skelpolu - Human Mistakes | 0.2123 | 0.1951 | +| Traffic Experiment - Sirens | 0.2558 | 0.2339 | +| Triviul - Angelsaint | 0.2101 | 0.1955 | +| Young Griffo - Pennies | 0.2499 | 0.2297 | ### Task 2 - Listening music in a car with presence of noise -**The overall HAAQI score is 0.1248.** +**The overall HAAQI score is 0.1423.** #### Average HAAQI score per genre -| Genre | HAAQI | -|:---------------|:----------:| -| Classical | 0.1240 | -| Hip-Hop | 0.1271 | -| Instrumental | 0.1250 | -| International | 0.1267 | -| Orchestral | 0.1121 | -| Pop | 0.1339 | -| Rock | 0.1252 | +| Genre | HAAQI | +|:---------------|:------:| +| Classical | 0.1365 | +| Hip-Hop | 0.1462 | +| Instrumental | 0.1416 | +| International | 0.1432 | +| Orchestral | 0.1329 | +| Pop | 0.1498 | +| Rock | 0.1460 | diff --git a/recipes/cad1/task1/baseline/README.md b/recipes/cad1/task1/baseline/README.md index 7732be324..a8f082706 100644 --- a/recipes/cad1/task1/baseline/README.md +++ b/recipes/cad1/task1/baseline/README.md @@ -15,8 +15,8 @@ To download the data, please visit [here](https://forms.gle/UQkuCxqQVxZtGggPA). Alternatively, you can download the MUSDB18-HQ dataset from the official [SigSep website](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav). If you opt for this alternative, be sure to download the uncompressed wav version. Note that you will need both packages to run the baseline system. -If you need additional music data for training your model, please restrict to the use of [MedleyDB](https://medleydb.weebly.com/) [4] [5], -[BACH10](https://labsites.rochester.edu/air/resource.html) [6] and [FMA-small](https://github.com/mdeff/fma) [7]. +If you need additional music data for training your model, please restrict to the use of [MedleyDB](https://medleydb.weebly.com/) [[4](#4-references)] [[5](#4-references)], +[BACH10](https://labsites.rochester.edu/air/resource.html) [[6](#4-references)] and [FMA-small](https://github.com/mdeff/fma) [[7](#4-references)]. 
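As a cross-check on the Task 1 headline numbers (illustration only; the exact aggregation used by the organisers is an assumption here), averaging the per-song Demucs scores from the table above reproduces the reported 0.2592. The Open-Unmix column averages to roughly 0.2270, close to the reported 0.2273, so the official figure may weight song-listener pairs slightly differently.

```python
import numpy as np

# Per-song average HAAQI scores for the Demucs baseline, copied from the table above
demucs_per_song = [
    0.2485, 0.3084, 0.1843, 0.3094, 0.3542, 0.2778, 0.2544,
    0.2513, 0.2455, 0.2673, 0.2123, 0.2558, 0.2101, 0.2499,
]

print(f"Overall Demucs HAAQI: {np.mean(demucs_per_song):.4f}")  # -> 0.2592
```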
Theses are shared as `cadenza_cad1_task1_augmentation_medleydb.tar.gz`, `cadenza_cad1_task1_augmentation_bach10.tar.gz` and `cadenza_cad1_task1_augmentation_fma_small.tar.gz`.
**Keeping the augmentation data restricted to these datasets will ensure that the evaluation is fair for all participants**.
@@ -56,7 +56,7 @@ cadenza_data

 ### 1.2 Additional optional data

-* **MedleyDB** contains both MedleyDB versions 1 [[4](#references)] and 2 [[5](#references)] datasets.
+* **MedleyDB** contains both MedleyDB versions 1 [[4](#4-references)] and 2 [[5](#4-references)] datasets.
 Tracks from the MedleyDB dataset are not included in the evaluation set.
 However, is your responsibility to exclude any song that may be already contained in the training set.

@@ -70,7 +70,7 @@ cadenza_data
 └───Metadata
 ```

-* **BACH10** contains the BACH10 dataset [[6](#references)].
+* **BACH10** contains the BACH10 dataset [[6](#4-references)].
 Tracks from the BACH10 dataset are not included in MUSDB18-HQ and can all be used as training augmentation data.

@@ -84,7 +84,7 @@ cadenza_data
 ├───...
 ```

-* **FMA Small** contains the FMA small subset of the FMA dataset [[7](references)].
+* **FMA Small** contains the FMA small subset of the FMA dataset [[7](#4-references)].
 Tracks from the FMA small dataset are not included in the MUSDB18-HQ.
 This dataset does not provide independent stems but only the full mix.

@@ -123,18 +123,26 @@ Note that we use [hydra](https://hydra.cc/docs/intro/) for config handling.

 ### 2.1 Enhancement

-The baseline enhance simply takes the out-of-the-box [Hybrid Demucs](https://github.com/facebookresearch/demucs) [1]
+We offer two baseline systems:
+
+1. Using the out-of-the-box time-domain [Hybrid Demucs](https://github.com/facebookresearch/demucs) [[1](#4-references)]
source separation model distributed on [TorchAudio](https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html)
-and applies a simple NAL-R [2] fitting amplification to each VDBO (`vocals`, `drums`, `bass` and `others`) stem.
+2. Using the out-of-the-box spectrogram-based [Open-Unmix](https://github.com/sigsep/open-unmix-pytorch)
+source separation model (version `umxhq`) distributed through [PyTorch Hub](https://pytorch.org/hub/)

-The remixing is performed by summing the amplified VDBO stems.
+Both systems use the same enhancement strategy: using the music separation model, they estimate the
+VDBO (`vocals`, `drums`, `bass` and `others`) stems. Then, they apply a simple NAL-R [[2](#4-references)] fitting amplification to each of them.
+This results in eight mono signals (four from the left channel and four from the right channel). Finally, each signal is downsampled to 24000 Hertz, converted to 16-bit precision and
+encoded using lossless FLAC compression. These eight signals are then used for the objective evaluation (HAAQI).

-The baseline generates a left and right signal for each VDBO stem and a remixed signal, totalling 9 signals per song-listener.
+The baselines also provide a remixing strategy to generate a stereo signal for each listener. This is done by summing
+the amplified VDBO stems, where each output channel (left and right) is the sum of the corresponding
+four stems. This stereo remixed signal is then used for the subjective evaluation (listening panel).

To run the baseline enhancement system first, make sure that `paths.root` in `config.yaml` points to where you have
installed the Cadenza data. This parameter defaults to the working directory.
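The enhancement steps described above correspond to the helper functions introduced later in this diff (`decompose_signal`, `process_stems_for_listener`, `remix_signal` and `save_flac_signal`). The sketch below condenses that flow into a single hypothetical helper for one song-listener pair; the function name `enhance_one_song` is illustrative, the sample rates follow `config.yaml` (44.1 kHz input, 24 kHz stems, 32 kHz remix), and the Demucs normalisation path is assumed.

```python
from pathlib import Path

from numpy import ndarray

from recipes.cad1.task1.baseline.enhance import (
    decompose_signal,
    process_stems_for_listener,
    remix_signal,
    save_flac_signal,
)


def enhance_one_song(
    separation_model, enhancer, compressor, listener, device,
    mixture_signal: ndarray, song_name: str, out_dir: Path,
) -> None:
    """Condensed per-song flow of the baseline (illustrative helper only)."""
    # 1. Separate the stereo mixture into left/right VDBO stems
    stems = decompose_signal(
        model=separation_model,
        model_sample_rate=44100,
        signal=mixture_signal,
        signal_sample_rate=44100,
        device=device,
        sources_list=["vocals", "drums", "bass", "other"],
        listener=listener,
        normalise=True,  # Demucs path; the Open-Unmix path uses False
    )

    # 2. Apply NAL-R amplification independently to each of the eight mono stems
    processed = process_stems_for_listener(stems, enhancer, compressor, listener, False)

    # 3. Stems are written at 24 kHz / 16-bit / FLAC for the objective (HAAQI) evaluation
    for stem_name, stem_signal in processed.items():
        save_flac_signal(
            signal=stem_signal,
            filename=out_dir / f"{listener.id}_{song_name}_{stem_name}.flac",
            signal_sample_rate=44100,
            output_sample_rate=24000,
            do_scale_signal=True,
        )

    # 4. The remix (left/right sums of the four stems) is written at 32 kHz for the panel
    save_flac_signal(
        signal=remix_signal(processed),
        filename=out_dir / f"{listener.id}_{song_name}_remix.flac",
        signal_sample_rate=44100,
        output_sample_rate=32000,
        do_clip_signal=True,
        do_soft_clip=True,
    )
```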
-You can also define your own `path.exp_folder` to store enhanced
-signals and evaluated results.
+You can also define your own `path.exp_folder` to store the enhanced signals and evaluated results, and select which
+music separation model you want to employ.

Then run:

@@ -158,9 +166,8 @@ The folder `enhanced_signals` will appear in the `exp` folder.

 ### 2.2 Evaluation

-The `evaluate.py` simply takes the signals stored in `enhanced_signals` and computes the HAAQI [[3](#references)] score
-for each of the eight left and right VDBO stems.
-The average of these eight scores is computed and returned for each signal.
+The `evaluate.py` script takes the eight VDBO signals stored in `enhanced_signals` and computes the
+HAAQI [[3](#4-references)] score. The final score for each sample is the average of the scores of its eight stems.

To run the evaluation stage, make sure that `path.root` is set in the `config.yaml` file and then run

@@ -172,13 +179,19 @@ A csv file containing the eight HAAQI scores and the combined score will be gene

 To check the HAAQI code, see [here](../../../../clarity/evaluator/haaqi).

-Please note: you will not get identical HAAQI scores for the same signals if the random seed is not defined
-(in the given recipe, the random seed for each signal is set as the last eight digits of the song md5).
-As there are random noises generated within HAAQI, but the differences should be sufficiently small.
+Please note: you will not get identical HAAQI scores for the same signals if the random seed is not defined.
+This is due to random noise generated within HAAQI, but the differences should be sufficiently small.
+For reproducibility, in the given recipe, the random seed for each signal is set as the last eight digits
+of the song md5.
+
+## 3. Results
+
+The overall HAAQI score for each baseline is:

-The score for the baseline is 0.3608 HAAQI overall.
+* Demucs: **0.2592**
+* Open-Unmix: **0.2273**

-## References
+## 4. References

* [1] Défossez, A. "Hybrid Spectrogram and Waveform Source Separation". Proceedings of the ISMIR 2021 Workshop on Music Source Separation. [doi:10.48550/arXiv.2111.03600](https://arxiv.org/abs/2111.03600)
* [2] Byrne, Denis, and Harvey Dillon. "The National Acoustic Laboratories'(NAL) new procedure for selecting the gain and frequency response of a hearing aid." Ear and hearing 7.4 (1986): 257-265.
[doi:10.1097/00003446-198608000-00007](https://doi.org/10.1097/00003446-198608000-00007) diff --git a/recipes/cad1/task1/baseline/config.yaml b/recipes/cad1/task1/baseline/config.yaml index 88de32bf8..bfac09086 100644 --- a/recipes/cad1/task1/baseline/config.yaml +++ b/recipes/cad1/task1/baseline/config.yaml @@ -2,18 +2,20 @@ path: root: ../../cadenza_data_demo/cad1/task1 metadata_dir: ${path.root}/metadata music_dir: ${path.root}/audio/musdb18hq - music_train_file: ${path.metadata_dir}/musdb18.train.json - music_valid_file: ${path.metadata_dir}/musdb18.valid.json - listeners_train_file: ${path.metadata_dir}/listeners.train.json - listeners_valid_file: ${path.metadata_dir}/listeners.valid.json - exp_folder: ./exp # folder to store enhanced signals and final results + music_file: ${path.metadata_dir}/musdb18.valid.json + listeners_file: ${path.metadata_dir}/listeners.valid.json + music_segments_test_file: ${path.metadata_dir}/musdb18.segments.test.json + exp_folder: ./exp_${separator.model} # folder to store enhanced signals and final results +team_id: T001 -sample_rate: 44100 +sample_rate: 44100 # sample rate of the input mixture +stem_sample_rate: 24000 # sample rate output stems +remix_sample_rate: 32000 # sample rate for output remixed signal nalr: nfir: 220 - fs: ${sample_rate} + sample_rate: ${sample_rate} apply_compressor: False compressor: @@ -27,7 +29,6 @@ soft_clip: True separator: model: demucs # demucs or openunmix - sources: [drums, bass, other, vocals] device: ~ evaluate: diff --git a/recipes/cad1/task1/baseline/enhance.py b/recipes/cad1/task1/baseline/enhance.py index 1daf24a92..610ebd67d 100644 --- a/recipes/cad1/task1/baseline/enhance.py +++ b/recipes/cad1/task1/baseline/enhance.py @@ -15,12 +15,19 @@ from omegaconf import DictConfig from scipy.io import wavfile from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB -from torchaudio.transforms import Fade, Resample +from torchaudio.transforms import Fade from clarity.enhancer.compressor import Compressor from clarity.enhancer.nalr import NALR from clarity.utils.audiogram import Audiogram, Listener -from clarity.utils.signal_processing import denormalize_signals, normalize_signal +from clarity.utils.flac_encoder import FlacEncoder +from clarity.utils.signal_processing import ( + clip_signal, + denormalize_signals, + normalize_signal, + resample, + to_16bit, +) from recipes.cad1.task1.baseline.evaluate import make_song_listener_list logger = logging.getLogger(__name__) @@ -28,7 +35,7 @@ def separate_sources( model: torch.nn.Module, - mix: torch.Tensor, + mix: torch.Tensor | ndarray, sample_rate: int, segment: float = 10.0, overlap: float = 0.1, @@ -119,11 +126,11 @@ def get_device(device: str) -> tuple: raise ValueError(f"Unsupported device type: {device}") -def map_to_dict(sources: np.ndarray, sources_list: list[str]) -> dict: +def map_to_dict(sources: ndarray, sources_list: list[str]) -> dict: """Map sources to a dictionary separating audio into left and right channels. Args: - sources (np.ndarray): Signal to be mapped to dictionary. + sources (ndarray): Signal to be mapped to dictionary. sources_list (list): List of strings used to index dictionary. 
Returns: @@ -142,13 +149,15 @@ def map_to_dict(sources: np.ndarray, sources_list: list[str]) -> dict: # pylint: disable=unused-argument def decompose_signal( - config: DictConfig, model: torch.nn.Module, - signal: np.ndarray, - sample_rate: int, + model_sample_rate: int, + signal: ndarray, + signal_sample_rate: int, device: torch.device, + sources_list: list[str], listener: Listener, -) -> dict[str, np.ndarray]: + normalise: bool = True, +) -> dict[str, ndarray]: """ Decompose signal into 8 stems. @@ -158,47 +167,46 @@ def decompose_signal( HDEMUCS model trained on the MUSDB18 dataset. Args: - config (DictConfig): Configuration object. model (torch.nn.Module): Torch model. - signal (np.ndarray): Signal to be decomposed. - sample_rate (int): Sample frequency. + model_sample_rate (int): Sample rate of the model. + signal (ndarray): Signal to be decomposed. + signal_sample_rate (int): Sample frequency. device (torch.device): Torch device to use for processing. - listener (Listener). + sources_list (list): List of strings used to index dictionary. + listener (Listener): Listener object. + normalise (bool): Whether to normalise the signal. Returns: Dictionary: Indexed by sources with the associated model as values. """ - if config.separator.model == "demucs": - signal, ref = normalize_signal(signal) - - model_sample_rate = ( - model.sample_rate if config.separator.model == "openunmix" else 44100 - ) + # Resample mixture signal to model sample rate + if signal_sample_rate != model_sample_rate: + signal = resample(signal, signal_sample_rate, model_sample_rate) - if sample_rate != model_sample_rate: - resampler = Resample(sample_rate, model_sample_rate) - signal = resampler(signal) + if normalise: + signal, ref = normalize_signal(signal) sources = separate_sources( - model, torch.from_numpy(signal), sample_rate, device=device + model, torch.from_numpy(signal), signal_sample_rate, device=device ) # only one element in the batch sources = sources[0] - if config.separator.model == "demucs": + + if normalise: sources = denormalize_signals(sources, ref) - signal_stems = map_to_dict(sources, config.separator.sources) + signal_stems = map_to_dict(sources, sources_list) return signal_stems def apply_baseline_ha( enhancer: NALR, compressor: Compressor, - signal: np.ndarray, + signal: ndarray, audiogram: Audiogram, apply_compressor: bool = False, -) -> np.ndarray: +) -> ndarray: """ Apply NAL-R prescription hearing aid to a signal. @@ -206,14 +214,12 @@ def apply_baseline_ha( enhancer: A NALR object that enhances the signal. compressor: A Compressor object that compresses the signal. signal: An ndarray representing the audio signal. - listener_audiogram: An ndarray representing the listener's audiogram. - cfs: An ndarray of center frequencies. + audiogram: An Audiogram object representing the listener's audiogram. apply_compressor: A boolean indicating whether to include the compressor. Returns: An ndarray representing the processed signal. """ - print("XXX", audiogram) nalr_fir, _ = enhancer.build(audiogram) proc_signal = enhancer.apply(nalr_fir, signal) if apply_compressor: @@ -234,8 +240,7 @@ def process_stems_for_listener( stems (dict) : Dictionary of stems enhancer (NALR) : NAL-R prescription hearing aid compressor (Compressor) : Compressor - listener: Listener object - cfs (np.ndarray) : Center frequencies + listener (Listener) : Listener object. 
apply_compressor (bool) : Whether to apply the compressor
    Returns:
        processed_sources (dict) : Dictionary of processed stems
@@ -261,27 +266,84 @@ def process_stems_for_listener(
     return processed_stems


-def clip_signal(signal: np.ndarray, soft_clip: bool = False) -> tuple[np.ndarray, int]:
-    """Clip and save the processed stems.
+def remix_signal(stems: dict) -> ndarray:
+    """
+    Function to remix signal. It takes the eight stems
+    and combines them into a stereo signal.

     Args:
-        signal (np.ndarray): Signal to be clipped and saved.
-        soft_clip (bool): Whether to use soft clipping.
+        stems (dict) : Dictionary of stems

     Returns:
-        signal (np.ndarray): Clipped signal.
-        n_clipped (int): Number of samples clipped.
+        (ndarray) : Remixed signal
+
+    """
+    n_samples = stems[list(stems.keys())[0]].shape[0]
+    out_left, out_right = np.zeros(n_samples), np.zeros(n_samples)
+    for stem_str, stem_signal in stems.items():
+        if stem_str.startswith("l"):
+            out_left += stem_signal
+        else:
+            out_right += stem_signal
+
+    return np.stack([out_left, out_right], axis=1)
+
+
+def save_flac_signal(
+    signal: ndarray,
+    filename: Path,
+    signal_sample_rate: int,
+    output_sample_rate: int,
+    do_clip_signal: bool = False,
+    do_soft_clip: bool = False,
+    do_scale_signal: bool = False,
+) -> None:
     """
+    Function to save output signals.
+
+    - The output signal will be resampled to ``output_sample_rate``
+    - The output signal will be clipped to [-1, 1] if ``do_clip_signal`` is True,
+      using soft clipping if ``do_soft_clip`` is also True. Note that if
+      ``do_clip_signal`` is False, ``do_soft_clip`` will be ignored.
+    - The output signal will be scaled to [-1, 1] if ``do_scale_signal`` is True.
+      If the signal is scaled, the scale factor will be saved in a TXT file.
+      Note that if ``do_scale_signal`` is True, ``do_clip_signal`` and
+      ``do_soft_clip`` will be ignored.
+    - The output signal will be saved as a FLAC file.
- if soft_clip: - signal = np.tanh(signal) - n_clipped = np.sum(np.abs(signal) > 1.0) - np.clip(signal, -1.0, 1.0, out=signal) - return signal, int(n_clipped) + Args: + signal (np.ndarray) : Signal to save + filename (Path) : Path to save signal + signal_sample_rate (int) : Sample rate of the input signal + output_sample_rate (int) : Sample rate of the output signal + do_clip_signal (bool) : Whether to clip signal + do_soft_clip (bool) : Whether to apply soft clipping + do_scale_signal (bool) : Whether to scale signal + """ + # Resample signal to expected output sample rate + if signal_sample_rate != output_sample_rate: + signal = resample(signal, signal_sample_rate, output_sample_rate) + + if do_scale_signal: + # Scale stem signal + max_value = np.max(np.abs(signal)) + signal = signal / max_value + + # Save scale factor + with open(filename.with_suffix(".txt"), "w", encoding="utf-8") as file: + file.write(f"{max_value}") + + elif do_clip_signal: + # Clip the signal + signal, n_clipped = clip_signal(signal, do_soft_clip) + if n_clipped > 0: + logger.warning(f"Writing {filename}: {n_clipped} samples clipped") + # Convert signal to 16-bit integer + signal = to_16bit(signal) -def to_16bit(signal: np.ndarray) -> np.ndarray: - return (32768.0 * signal).astype(np.int16) + # Create flac encoder object to compress and save the signal + FlacEncoder().encode(signal, output_sample_rate, filename) @hydra.main(config_path="", config_name="config") @@ -298,47 +360,39 @@ def enhance(config: DictConfig) -> None: - right channel vocal, drums, bass, and other stems """ + if config.separator.model not in ["demucs", "openunmix"]: + raise ValueError(f"Separator model {config.separator.model} not supported.") + enhanced_folder = Path("enhanced_signals") enhanced_folder.mkdir(parents=True, exist_ok=True) - # Training stage - # - # The baseline is using an off-the-shelf model trained on the MUSDB18 dataset - # Training listeners and song are not necessary in this case. 
- # - # Training songs and audiograms can be read like this: - # - # with open(config.path.listeners_train_file, "r", encoding="utf-8") as file: - # listener_train_audiograms = json.load(file) - # - # with open(config.path.music_train_file, "r", encoding="utf-8") as file: - # song_data = json.load(file) - # songs_train = pd.DataFrame.from_dict(song_data) - # - # train_song_listener_pairs = make_song_listener_list( - # songs_train['Track Name'], listener_train_audiograms - # ) - if config.separator.model == "demucs": separation_model = HDEMUCS_HIGH_MUSDB.get_model() - else: + model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate + sources_order = separation_model.sources + normalise = True + elif config.separator.model == "openunmix": separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0) + model_sample_rate = separation_model.sample_rate + sources_order = ["vocals", "drums", "bass", "other"] + normalise = False + else: + raise ValueError(f"Separator model {config.separator.model} not supported.") + device, _ = get_device(config.separator.device) separation_model.to(device) # Processing Validation Set # Load listener audiograms and songs - listener_dict = Listener.load_listener_dict(config.path.listeners_valid_file) + listener_dict = Listener.load_listener_dict(config.path.listeners_file) - with open(config.path.music_valid_file, encoding="utf-8") as file: + with open(config.path.music_file, encoding="utf-8") as file: song_data = json.load(file) - songs_valid = pd.DataFrame.from_dict(song_data) + songs_df = pd.DataFrame.from_dict(song_data) - valid_song_listener_pairs = make_song_listener_list( - songs_valid["Track Name"], listener_dict - ) + song_listener_pairs = make_song_listener_list(songs_df["Track Name"], listener_dict) # Select a batch to process - valid_song_listener_pairs = valid_song_listener_pairs[ + song_listener_pairs = song_listener_pairs[ config.evaluate.batch :: config.evaluate.batch_size ] @@ -348,8 +402,8 @@ def enhance(config: DictConfig) -> None: # Decompose each song into left and right vocal, drums, bass, and other stems # and process each stem for the listener prev_song_name = None - num_song_list_pair = len(valid_song_listener_pairs) - for idx, song_listener in enumerate(valid_song_listener_pairs, 1): + num_song_list_pair = len(song_listener_pairs) + for idx, song_listener in enumerate(song_listener_pairs, 1): song_name, listener_name = song_listener logger.info( f"[{idx:03d}/{num_song_list_pair:03d}] " @@ -361,14 +415,15 @@ def enhance(config: DictConfig) -> None: # Find the music split directory split_directory = ( "test" - if songs_valid.loc[songs_valid["Track Name"] == song_name, "Split"].iloc[0] + if songs_df.loc[songs_df["Track Name"] == song_name, "Split"].iloc[0] == "test" else "train" ) - # Read the mixture signal - # Convert to 32-bit floating point and transpose - # from [samples, channels] to [channels, samples] + # Baseline Steps + # 1. 
Decompose the mixture signal into vocal, drums, bass, and other stems + # We validate if 2 consecutive signals are the same to avoid + # decomposing the same song multiple times if prev_song_name != song_name: # Decompose song only once prev_song_name = song_name @@ -383,17 +438,20 @@ def enhance(config: DictConfig) -> None: assert sample_rate == config.sample_rate stems: dict[str, ndarray] = decompose_signal( - config, separation_model, + model_sample_rate, mixture_signal, sample_rate, device, + sources_order, listener, + normalise, ) - # Baseline applies NALR prescription to each stem instead of using the - # listener's audiograms in the decomposition. This stem can be skipped - # if the listener's audiograms are used in the decomposition + # 2. Apply NAL-R prescription to each stem + # Baseline applies NALR prescription to each stem instead of using the + # listener's audiograms in the decomposition. This step can be skipped + # if the listener's audiograms are used in the decomposition processed_stems = process_stems_for_listener( stems, enhancer, @@ -402,42 +460,41 @@ def enhance(config: DictConfig) -> None: config.apply_compressor, ) - # save processed stems - n_samples = processed_stems[list(processed_stems.keys())[0]].shape[0] - output_left, output_right = np.zeros(n_samples), np.zeros(n_samples) + # 3. Save processed stems for stem_str, stem_signal in processed_stems.items(): - if stem_str.startswith("l"): - output_left += stem_signal - else: - output_right += stem_signal - filename = ( enhanced_folder / f"{listener.id}" / f"{song_name}" - / f"{listener.id}_{song_name}_{stem_str}.wav" + / f"{listener.id}_{song_name}_{stem_str}.flac" ) filename.parent.mkdir(parents=True, exist_ok=True) + save_flac_signal( + signal=stem_signal, + filename=filename, + signal_sample_rate=config.sample_rate, + output_sample_rate=config.stem_sample_rate, + do_scale_signal=True, + ) - # Clip and save stem signals - clipped_signal, n_clipped = clip_signal(stem_signal, config.soft_clip) - if n_clipped > 0: - logger.warning(f"Writing {filename}: {n_clipped} samples clipped") - wavfile.write(filename, config.sample_rate, to_16bit(clipped_signal)) + # 4. Remix Signal + enhanced = remix_signal(processed_stems) - enhanced = np.stack([output_left, output_right], axis=1) + # 5. 
Save enhanced (remixed) signal filename = ( enhanced_folder / f"{listener.id}" / f"{song_name}" - / f"{listener.id}_{song_name}_remix.wav" + / f"{listener.id}_{song_name}_remix.flac" + ) + save_flac_signal( + signal=enhanced, + filename=filename, + signal_sample_rate=config.sample_rate, + output_sample_rate=config.remix_sample_rate, + do_clip_signal=True, + do_soft_clip=config.soft_clip, ) - - # clip and save enhanced signal - clipped_signal, n_clipped = clip_signal(enhanced, config.soft_clip) - if n_clipped > 0: - logger.warning(f"Writing {filename}: {n_clipped} samples clipped") - wavfile.write(filename, config.sample_rate, to_16bit(clipped_signal)) # pylint: disable = no-value-for-parameter diff --git a/recipes/cad1/task1/baseline/evaluate.py b/recipes/cad1/task1/baseline/evaluate.py index 7c41dd285..1543a9b96 100644 --- a/recipes/cad1/task1/baseline/evaluate.py +++ b/recipes/cad1/task1/baseline/evaluate.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-locals # pylint: disable=import-error -import csv import hashlib import itertools import json @@ -18,87 +17,13 @@ from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener -from clarity.utils.signal_processing import compute_rms +from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.results_support import ResultsFile +from clarity.utils.signal_processing import compute_rms, resample logger = logging.getLogger(__name__) -class ResultsFile: - """A utility class for writing results to a CSV file. - - Attributes: - file_name (str): The name of the file to write results to. - """ - - def __init__(self, file_name: str): - """Initialize the ResultsFile instance. - - Args: - file_name (str): The name of the file to write results to. - """ - self.file_name = file_name - - def write_header(self): - """Write the header row to the CSV file.""" - with open(self.file_name, "w", encoding="utf-8", newline="") as csv_file: - csv_writer = csv.writer( - csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL - ) - csv_writer.writerow( - [ - "song", - "listener", - "score", - "left_bass", - "right_bass", - "left_drums", - "right_drums", - "left_other", - "right_other", - "left_vocals", - "right_vocals", - ] - ) - - def add_result( - self, - listener_id: str, - song: str, - score: float, - instruments_scores: dict[str, float], - ): - """Add a result to the CSV file. - - Args: - listener_id (str): The name of the listener who submitted the result. - song (str): The name of the song that the result is for. - score (float): The combined score for the result. - instruments_scores (dict): A dictionary of scores for each instrument - channel in the result. 
- """ - logger.info(f"The combined score is {score}") - - with open(self.file_name, "a", encoding="utf-8", newline="") as csv_file: - csv_writer = csv.writer( - csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL - ) - csv_writer.writerow( - [ - song, - listener_id, - str(score), - str(instruments_scores["left_bass"]), - str(instruments_scores["right_bass"]), - str(instruments_scores["left_drums"]), - str(instruments_scores["right_drums"]), - str(instruments_scores["left_other"]), - str(instruments_scores["right_other"]), - str(instruments_scores["left_vocals"]), - str(instruments_scores["right_vocals"]), - ] - ) - - def set_song_seed(song: str) -> None: """Set a seed that is unique for the given song""" song_encoded = hashlib.md5(song.encode("utf-8")).hexdigest() @@ -161,47 +86,61 @@ def _evaluate_song_listener( ) reference_signal = (reference_signal / 32768.0).astype(np.float32) - # Load enhanced instrument signals - # Load left channel - sample_rate_left_enhanced_signal, left_enhanced_signal = wavfile.read( + # Read left instrument enhanced + left_enhanced_signal, sample_rate_left_enhanced_signal = read_flac_signal( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_left_{instrument}.wav" + / f"{listener.id}_{song}_left_{instrument}.flac" ) - left_enhanced_signal = (left_enhanced_signal / 32768.0).astype(np.float32) - # Load right channel - sample_rate_right_enhanced_signal, right_enhanced_signal = wavfile.read( + # Read right instrument enhanced + right_enhanced_signal, sample_rate_right_enhanced_signal = read_flac_signal( enhanced_folder / f"{listener.id}" / f"{song}" - / f"{listener.id}_{song}_right_{instrument}.wav" + / f"{listener.id}_{song}_right_{instrument}.flac" ) - right_enhanced_signal = (right_enhanced_signal / 32768.0).astype(np.float32) - assert ( - sample_rate_reference_signal - == sample_rate_left_enhanced_signal - == sample_rate_right_enhanced_signal - == config.sample_rate - ) + if ( + sample_rate_left_enhanced_signal + != sample_rate_right_enhanced_signal + != config.stem_sample_rate + ): + raise ValueError( + "The sample rates of the left and right enhanced signals are not " + "the same" + ) + + if sample_rate_reference_signal != config.sample_rate: + raise ValueError( + f"The sample rate of the reference signal is not {config.sample_rate}" + ) per_instrument_score[f"left_{instrument}"] = compute_haaqi( - left_enhanced_signal, - reference_signal[:, 0], - config.sample_rate, - config.sample_rate, - listener.audiogram_left, + processed_signal=left_enhanced_signal, + reference_signal=resample( + reference_signal[:, 0], + sample_rate_reference_signal, + config.stem_sample_rate, + ), + processed_sample_rate=config.stem_sample_rate, + reference_sample_rate=config.stem_sample_rate, + audiogram=listener.audiogram_left, equalisation=1, level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 0])), ) + per_instrument_score[f"right_{instrument}"] = compute_haaqi( - right_enhanced_signal, - reference_signal[:, 1], - config.sample_rate, - config.sample_rate, - listener.audiogram_right, + processed_signal=right_enhanced_signal, + reference_signal=resample( + reference_signal[:, 1], + sample_rate_reference_signal, + config.stem_sample_rate, + ), + processed_sample_rate=config.stem_sample_rate, + reference_sample_rate=config.stem_sample_rate, + audiogram=listener.audiogram_right, equalisation=1, level1=65 - 20 * np.log10(compute_rms(reference_signal[:, 1])), ) @@ -216,23 +155,43 @@ def _evaluate_song_listener( def run_calculate_aq(config: 
DictConfig) -> None: """Evaluate the enhanced signals using the HAAQI-RMS metric.""" # Load test songs - with open(config.path.music_valid_file, encoding="utf-8") as fp: + with open(config.path.music_file, encoding="utf-8") as fp: songs = json.load(fp) - songs = pd.DataFrame.from_dict(songs) + songs_df = pd.DataFrame.from_dict(songs) # Load listener data - listener_dict = Listener.load_listener_dict(config.path.listeners_valid_file) + listener_dict = Listener.load_listener_dict(config.path.listeners_file) enhanced_folder = Path("enhanced_signals") logger.info(f"Evaluating from {enhanced_folder} directory") + scores_headers = [ + "song", + "listener", + "score", + "left_bass", + "right_bass", + "left_drums", + "right_drums", + "left_other", + "right_other", + "left_vocals", + "right_vocals", + ] + + results_file_name = "scores.csv" + if config.evaluate.batch_size > 1: + results_file_name = ( + f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv" + ) + results_file = ResultsFile( - f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv" + file_name=results_file_name, + header_columns=scores_headers, ) - results_file.write_header() song_listener_pair = make_song_listener_list( - songs["Track Name"].tolist(), listener_dict, config.evaluate.small_test + songs_df["Track Name"].tolist(), listener_dict, config.evaluate.small_test ) song_listener_pair = song_listener_pair[ @@ -241,21 +200,23 @@ def run_calculate_aq(config: DictConfig) -> None: for song, listener_id in song_listener_pair: split_dir = "train" - if songs[songs["Track Name"] == song]["Split"].tolist()[0] == "test": + if songs_df[songs_df["Track Name"] == song]["Split"].tolist()[0] == "test": split_dir = "test" listener = listener_dict[listener_id] combined_score, per_instrument_score = _evaluate_song_listener( - song, - listener, - config, - split_dir, - enhanced_folder, + song=song, + listener=listener, + config=config, + split_dir=split_dir, + enhanced_folder=enhanced_folder, ) results_file.add_result( - listener.id, - song, - score=combined_score, - instruments_scores=per_instrument_score, + { + "song": song, + "listener": listener.id, + "score": combined_score, + **per_instrument_score, + } ) diff --git a/recipes/cad1/task1/baseline/test.py b/recipes/cad1/task1/baseline/test.py new file mode 100644 index 000000000..47592ad17 --- /dev/null +++ b/recipes/cad1/task1/baseline/test.py @@ -0,0 +1,241 @@ +""" Run the baseline enhancement. """ +from __future__ import annotations + +# pylint: disable=import-error +# pylint: disable=too-many-function-args +import json +import logging +import shutil +from pathlib import Path + +import hydra +import numpy as np +import pandas as pd +import torch +from omegaconf import DictConfig +from scipy.io import wavfile +from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB + +from clarity.enhancer.compressor import Compressor +from clarity.enhancer.nalr import NALR +from clarity.utils.audiogram import Listener +from recipes.cad1.task1.baseline.enhance import ( + decompose_signal, + get_device, + process_stems_for_listener, + remix_signal, + save_flac_signal, +) +from recipes.cad1.task1.baseline.evaluate import make_song_listener_list + +# pylint: disable=too-many-locals + +logger = logging.getLogger(__name__) + + +def pack_submission( + team_id: str, + root_dir: str | Path, + base_dir: str | Path = ".", +) -> None: + """ + Pack the submission files into an archive file. + + Args: + team_id (str): Team ID. + root_dir (str | Path): Root directory of the archived file. 
+ base_dir (str | Path): Base directory to archive. Defaults to ".". + """ + # Pack the submission files + logger.info(f"Packing submission files for team {team_id}...") + shutil.make_archive( + f"submission_{team_id}", + "zip", + root_dir=root_dir, + base_dir=base_dir, + ) + + +@hydra.main(config_path="", config_name="config") +def enhance(config: DictConfig) -> None: + """ + Run the music enhancement. + The system decomposes the music into vocal, drums, bass, and other stems. + Then, the NAL-R prescription procedure is applied to each stem. + Args: + config (dict): Dictionary of configuration options for enhancing music. + + Returns 8 stems for each song: + - left channel vocal, drums, bass, and other stems + - right channel vocal, drums, bass, and other stems + """ + + if config.separator.model not in ["demucs", "openunmix"]: + raise ValueError(f"Separator model {config.separator.model} not supported.") + + enhanced_folder = Path("enhanced_signals") / "evaluation" + enhanced_folder.mkdir(parents=True, exist_ok=True) + + if config.separator.model == "demucs": + separation_model = HDEMUCS_HIGH_MUSDB.get_model() + model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate + sources_order = separation_model.sources + normalise = True + elif config.separator.model == "openunmix": + separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0) + model_sample_rate = separation_model.sample_rate + sources_order = ["vocals", "drums", "bass", "other"] + normalise = False + else: + raise ValueError(f"Separator model {config.separator.model} not supported.") + + device, _ = get_device(config.separator.device) + separation_model.to(device) + + # Processing Validation Set + # Load listener audiograms and songs + listener_dict = Listener.load_listener_dict(config.path.listeners_file) + + with open(config.path.music_file, encoding="utf-8") as file: + song_data = json.load(file) + songs_details = pd.DataFrame.from_dict(song_data) + + with open(config.path.music_segments_test_file, encoding="utf-8") as file: + songs_segments = json.load(file) + + song_listener_pairs = make_song_listener_list( + songs_details["Track Name"], listener_dict + ) + # Select a batch to process + song_listener_pairs = song_listener_pairs[ + config.evaluate.batch :: config.evaluate.batch_size + ] + + # Create hearing aid objects + enhancer = NALR(**config.nalr) + compressor = Compressor(**config.compressor) + + # Decompose each song into left and right vocal, drums, bass, and other stems + # and process each stem for the listener + prev_song_name = None + num_song_list_pair = len(song_listener_pairs) + for idx, song_listener in enumerate(song_listener_pairs, 1): + song_name, listener_name = song_listener + logger.info( + f"[{idx:03d}/{num_song_list_pair:03d}] " + f"Processing {song_name} for {listener_name}..." + ) + # Get the listener's audiogram + listener = listener_dict[listener_name] + + # Find the music split directory + split_directory = ( + "test" + if songs_details.loc[ + songs_details["Track Name"] == song_name, "Split" + ].iloc[0] + == "test" + else "train" + ) + + # Baseline Steps + # 1. 
Decompose the mixture signal into vocal, drums, bass, and other stems + # We validate if 2 consecutive signals are the same to avoid + # decomposing the same song multiple times + if prev_song_name != song_name: + # Decompose song only once + prev_song_name = song_name + + sample_rate, mixture_signal = wavfile.read( + Path(config.path.music_dir) + / split_directory + / song_name + / "mixture.wav" + ) + mixture_signal = (mixture_signal / 32768.0).astype(np.float32).T + assert sample_rate == config.sample_rate + + # Decompose mixture signal into stems + stems = decompose_signal( + separation_model, + model_sample_rate, + mixture_signal, + sample_rate, + device, + sources_order, + listener, + normalise, + ) + + # 2. Apply NAL-R prescription to each stem + # Baseline applies NALR prescription to each stem instead of using the + # listener's audiograms in the decomposition. This step can be skipped + # if the listener's audiograms are used in the decomposition + processed_stems = process_stems_for_listener( + stems, + enhancer, + compressor, + listener, + config.apply_compressor, + ) + + # 3. Save processed stems + for stem_str, stem_signal in processed_stems.items(): + filename = ( + enhanced_folder + / f"{listener.id}" + / f"{song_name}" + / f"{listener.id}_{song_name}_{stem_str}.flac" + ) + filename.parent.mkdir(parents=True, exist_ok=True) + start = songs_segments[song_name]["objective_evaluation"]["start"] + end = songs_segments[song_name]["objective_evaluation"]["end"] + save_flac_signal( + signal=stem_signal[ + int(start * config.sample_rate) : int(end * config.sample_rate) + ], + filename=filename, + signal_sample_rate=config.sample_rate, + output_sample_rate=config.stem_sample_rate, + do_scale_signal=True, + ) + + # 3. Remix Signal + enhanced = remix_signal(processed_stems) + + # 5. Save enhanced (remixed) signal + filename = ( + enhanced_folder + / f"{listener.id}" + / f"{song_name}" + / f"{listener.id}_{song_name}_remix.flac" + ) + start = songs_segments[song_name]["subjective_evaluation"]["start"] + end = songs_segments[song_name]["subjective_evaluation"]["end"] + save_flac_signal( + signal=enhanced[ + int(start * config.sample_rate) : int(end * config.sample_rate) + ], + filename=filename, + signal_sample_rate=config.sample_rate, + output_sample_rate=config.remix_sample_rate, + do_clip_signal=True, + do_soft_clip=config.soft_clip, + ) + + pack_submission( + team_id=config.team_id, + root_dir=enhanced_folder.parent, + base_dir=enhanced_folder.name, + ) + + logger.info("Evaluation complete.!!") + logger.info( + f"Please, submit the file submission_{config.team_id}.zip to the challenge " + "using the link provided. Thank you.!!" + ) + + +# pylint: disable = no-value-for-parameter +if __name__ == "__main__": + enhance() diff --git a/recipes/cad1/task2/baseline/README.md b/recipes/cad1/task2/baseline/README.md index 0dfe7aaaa..bcb910074 100644 --- a/recipes/cad1/task2/baseline/README.md +++ b/recipes/cad1/task2/baseline/README.md @@ -8,8 +8,8 @@ For more information please visit the [challenge website](https://cadenzachallen ### 1.1 Obtaining the CAD1 - Task2 data -The music dataset for the First Cadenza Challenge - Task 2 is based on the small subset of the FMA [2] dataset -(FMA-small) and the MTG-Jamendo dataset [4]. The dataset contains 1000 samples from seven musical genres, +The music dataset for the First Cadenza Challenge - Task 2 is based on the small subset of the FMA [[2](#4-references)] dataset +(FMA-small) and the MTG-Jamendo dataset [[4](#4-references)]. 
The dataset contains 1000 samples from seven musical genres,
totalling 7000 songs with a distribution of 80% / 10% / 10% for `train`, `valid` and `test`.

From FMA small:
@@ -82,17 +82,18 @@ If you have an Anaconda or Miniconda environment, you can install them as:
 * conda install -c conda-forge ffmpeg
 * conda install -c conda-forge libsndfile

-```bash
-
 ### 2.1 Enhancement

The objective of the enhancement stage is takes a song and optimise it to a listener hearing characteristics
-knowing metadata information about the car noise scenario (you won't have access to noise signal), head
+knowing metadata about the car noise scenario (note that you won't have access to the noise signal), head
rotation of the listener and the SNR of the enhanced music and the noise at the hearing aid microphones.

-In the baseline, we simply attenuate the song according to the average hearing loss and save it in 16-bit PCM WAV format.
+In the baseline, we attenuate the song according to the average hearing loss. The outputs are stereo signals
+that we save using a 32000 Hertz sample rate and 16-bit precision, encoded with lossless FLAC compression.
This attenuation prevents some clipping in the hearing aid output signal.

+The resulting signals are used for both the objective (HAAQI) and the subjective (listening panel) evaluation.
+
To run the baseline enhancement system first, make sure that `paths.root` in `config.yaml` points to where you have
installed the Cadenza data foer the task2. This parameter defaults to one level above the recipe for the demo data.
You can also define your own `path.exp_folder` to store enhanced and evaluated signal results.
@@ -120,8 +121,9 @@ The folder `enhanced_signals` will appear in the `exp` folder.

 ### 2.2 Evaluation

The `evaluate.py` module takes the enhanced signals and adds the room impulses and the car noise using
-the expected SNR. It then passes that signal through a fixed hearing aid. The hearing aid output and
-the reference song are used to compute the HAAQI [2] score.
+the expected SNR. It then passes that signal through a fixed hearing aid. The hearing aid is composed of a
+NAL-R [[1](#4-references)] prescription and a compressor. The hearing aid output signal and
+the reference song are used to compute the HAAQI [[2](#4-references)] score.

To run the evaluation stage, make sure that `path.root` is set in the `config.yaml` file and then run

@@ -138,9 +140,9 @@ Please note: you will not get identical HAAQI scores for the same signals if the
(in the given recipe, the random seed for each signal is set as the last eight digits of the song md5).
As there are random noises generated within HAAQI, but the differences should be sufficiently small.

-The overall HAAQI score for baseline is 0.1248.
+The overall HAAQI score for the baseline is **0.1423**.

-## References
+## 4. References

* [1] Byrne, Denis, and Harvey Dillon. "The National Acoustic Laboratories'(NAL) new procedure for selecting the gain and frequency response of a hearing aid." Ear and hearing 7.4 (1986): 257-265.
[doi:10.1097/00003446-198608000-00007](https://doi.org/10.1097/00003446-198608000-00007)
* [2] Kates J M, Arehart K H. "The Hearing-Aid Audio Quality Index (HAAQI)". IEEE/ACM transactions on audio, speech, and language processing, 24(2), 354–365.
[doi:10.1109/TASLP.2015.2507858](https://doi.org/10.1109%2FTASLP.2015.2507858) diff --git a/recipes/cad1/task2/baseline/baseline_utils.py b/recipes/cad1/task2/baseline/baseline_utils.py index fe67c02f3..5a458fc31 100644 --- a/recipes/cad1/task2/baseline/baseline_utils.py +++ b/recipes/cad1/task2/baseline/baseline_utils.py @@ -1,6 +1,7 @@ """Utility functions for the baseline model.""" from __future__ import annotations +# pylint: disable=import-error import json import logging import warnings @@ -13,9 +14,6 @@ from clarity.utils.audiogram import Listener -# pylint: disable=import-error - - logger = logging.getLogger(__name__) @@ -40,7 +38,7 @@ def read_mp3( str(file_path), sr=sample_rate, mono=False, - res_type="kaiser_best", + res_type="soxr_hq", dtype=np.float32, ) except Exception as error: @@ -95,12 +93,9 @@ def load_listeners_and_scenes( df_scenes = pd.read_json(fp, orient="index") # Load audiograms and scene data for the corresponding split - if config.evaluate.split == "train": - listeners = Listener.load_listener_dict(config.path.listeners_train_file) - scenes = df_scenes[df_scenes["split"] == "train"].to_dict("index") - elif config.evaluate.split == "valid": - listeners = Listener.load_listener_dict(config.path.listeners_valid_file) - scenes = df_scenes[df_scenes["split"] == "valid"].to_dict("index") + listeners = Listener.load_listener_dict(config.path.listeners_file) + if config.evaluate.split in ["train", "valid", "test"]: + scenes = df_scenes[df_scenes["split"] == config.evaluate.split].to_dict("index") else: raise ValueError(f"Unknown split {config.evaluate.split}") diff --git a/recipes/cad1/task2/baseline/car_scene_acoustics.py b/recipes/cad1/task2/baseline/car_scene_acoustics.py index 9fe8cc3e1..8bf5510ac 100644 --- a/recipes/cad1/task2/baseline/car_scene_acoustics.py +++ b/recipes/cad1/task2/baseline/car_scene_acoustics.py @@ -112,9 +112,7 @@ def apply_hearing_aid(self, signal: np.ndarray, audiogram: Audiogram) -> np.ndar Args: signal (np.ndarray): The audio signal to be enhanced. - audiogram (np.ndarray): An audiogram used to configure the NALR object. - center_frequencies (np.ndarray): An array of center frequencies - used to configure the NALR object. + audiogram (Audiogram): The audiogram of the listener. Returns: np.ndarray: The enhanced audio signal. 
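The Task 2 README above states that the enhanced songs are written at a 32 kHz sample rate with 16-bit precision and FLAC encoding. The hypothetical helper below sketches that save path using the `clarity.utils` functions this diff relies on (`resample`, `clip_signal`, `to_16bit` and `FlacEncoder`); the function name and default argument values are illustrative only.

```python
from pathlib import Path

import numpy as np

from clarity.utils.flac_encoder import FlacEncoder
from clarity.utils.signal_processing import clip_signal, resample, to_16bit


def save_enhanced_signal(
    enhanced: np.ndarray,  # stereo signal, shape (n_samples, 2), floats in [-1, 1]
    filename: Path,
    input_sample_rate: int = 44100,  # sample_rate in config.yaml
    output_sample_rate: int = 32000,  # enhanced_sample_rate in config.yaml
    soft_clip: bool = True,
) -> None:
    """Resample to the output rate, clip, convert to 16-bit and FLAC-encode."""
    enhanced = resample(enhanced, input_sample_rate, output_sample_rate)
    clipped, n_clipped = clip_signal(enhanced, soft_clip)
    if n_clipped > 0:
        print(f"{filename}: {n_clipped} samples clipped")
    FlacEncoder().encode(to_16bit(clipped), output_sample_rate, filename)
```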
diff --git a/recipes/cad1/task2/baseline/config.yaml b/recipes/cad1/task2/baseline/config.yaml index 36a5ceb2f..ff84fe766 100644 --- a/recipes/cad1/task2/baseline/config.yaml +++ b/recipes/cad1/task2/baseline/config.yaml @@ -4,18 +4,20 @@ path: metadata_dir: ${path.root}/metadata music_dir: ${path.audio_dir}/music hrtf_dir: ${path.audio_dir}/eBrird - listeners_train_file: ${path.metadata_dir}/listeners.train.json - listeners_valid_file: ${path.metadata_dir}/listeners.valid.json + listeners_file: ${path.metadata_dir}/listeners.valid.json scenes_file: ${path.metadata_dir}/scenes.json scenes_listeners_file: ${path.metadata_dir}/scenes_listeners.json hrtf_file: ${path.metadata_dir}/eBrird_BRIR.json exp_folder: ./exp # folder to store enhanced signals and final results -sample_rate: 44100 +team_id: T001 + +sample_rate: 44100 # sample rate of the input signal +enhanced_sample_rate: 32000 # sample rate for the enhanced output signal nalr: nfir: 220 - fs: ${sample_rate} + sample_rate: ${sample_rate} compressor: threshold: 0.7 diff --git a/recipes/cad1/task2/baseline/enhance.py b/recipes/cad1/task2/baseline/enhance.py index eda3ac47e..6c85a6b88 100644 --- a/recipes/cad1/task2/baseline/enhance.py +++ b/recipes/cad1/task2/baseline/enhance.py @@ -12,10 +12,11 @@ import numpy as np import pyloudnorm as pyln from omegaconf import DictConfig -from scipy.io import wavfile from tqdm import tqdm from clarity.utils.audiogram import Listener +from clarity.utils.flac_encoder import FlacEncoder +from clarity.utils.signal_processing import clip_signal, resample, to_16bit from recipes.cad1.task2.baseline.baseline_utils import ( make_scene_listener_list, read_mp3, @@ -30,7 +31,7 @@ def compute_average_hearing_loss(listener: Listener) -> float: Compute the average hearing loss of a listener. Args: - listener (dict): The audiogram of the listener. + listener (Listener): The listener. Returns: average_hearing_loss (float): The average hearing loss of the listener. @@ -56,7 +57,7 @@ def enhance_song( Args: waveform (np.ndarray): The waveform of the song. - listener_dict (dict): The audiograms of the listener. + listener (Listener): The listener. config (dict): Dictionary of configuration options for enhancing music. 
Returns: @@ -110,6 +111,7 @@ def enhance(config: DictConfig) -> None: config.evaluate.batch :: config.evaluate.batch_size ] + flac_encoder = FlacEncoder() for scene_id, listener_id in tqdm(scene_listener_pairs): current_scene = scenes[scene_id] listener = listener_dict[listener_id] @@ -121,23 +123,26 @@ def enhance(config: DictConfig) -> None: out_l, out_r = enhance_song( waveform=song_waveform, listener=listener, config=config ) - enhanced = np.stack([out_l, out_r], axis=1) - filename = f"{scene_id}_{listener.id}_{current_scene['song']}.wav" + # Save the enhanced song enhanced_folder_listener = enhanced_folder / f"{listener.id}" enhanced_folder_listener.mkdir(parents=True, exist_ok=True) + filename = ( + enhanced_folder_listener + / f"{scene_id}_{listener.id}_{current_scene['song']}.flac" + ) - # Clip and save - if config.soft_clip: - enhanced = np.tanh(enhanced) - n_clipped = np.sum(np.abs(enhanced) > 1.0) + # - Resample to 32 kHz sample rate + # - Clip signal + # - Convert to 16bit + # - Compress using flac + enhanced = resample(enhanced, config.sample_rate, config.enhanced_sample_rate) + clipped_signal, n_clipped = clip_signal(enhanced, config.soft_clip) if n_clipped > 0: logger.warning(f"Writing {filename}: {n_clipped} samples clipped") - np.clip(enhanced, -1.0, 1.0, out=enhanced) - signal_16 = (32768.0 * enhanced).astype(np.int16) - wavfile.write( - enhanced_folder_listener / filename, config.sample_rate, signal_16 + flac_encoder.encode( + to_16bit(clipped_signal), config.enhanced_sample_rate, filename ) diff --git a/recipes/cad1/task2/baseline/evaluate.py b/recipes/cad1/task2/baseline/evaluate.py index a434b3f30..b3a202b6c 100644 --- a/recipes/cad1/task2/baseline/evaluate.py +++ b/recipes/cad1/task2/baseline/evaluate.py @@ -3,7 +3,6 @@ # pylint: disable=import-error from __future__ import annotations -import csv import hashlib import logging from pathlib import Path @@ -11,11 +10,12 @@ import hydra import numpy as np from omegaconf import DictConfig -from scipy.io import wavfile from tqdm import tqdm from clarity.evaluator.haaqi import compute_haaqi from clarity.utils.audiogram import Listener +from clarity.utils.flac_encoder import read_flac_signal +from clarity.utils.results_support import ResultsFile from recipes.cad1.task2.baseline.audio_manager import AudioManager from recipes.cad1.task2.baseline.baseline_utils import ( load_hrtf, @@ -28,81 +28,6 @@ logger = logging.getLogger(__name__) -class ResultsFile: - """A utility class for writing results to a CSV file. - - Attributes: - file_name (str): The name of the file to write results to. - """ - - def __init__(self, file_name): - """Initialize the ResultsFile instance. - - Args: - file_name (str): The name of the file to write results to. - """ - self.file_name = file_name - - def write_header(self): - """Write the header row to the CSV file.""" - with open(self.file_name, "w", encoding="utf-8") as csv_f: - csv_writer = csv.writer( - csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL - ) - csv_writer.writerow( - [ - "scene", - "song", - "genre", - "listener", - "score", - "haaqi_left", - "haaqi_right", - ] - ) - - # pylint: disable=too-many-arguments - def add_result( - self, - scene: str, - song: str, - genre: str, - listener: str, - score: float, - haaqi_left: float, - haaqi_right: float, - ): - """Add a result to the CSV file. - - Args: - scene (str): The name of the scene that the result is for. - song (str): The name of the song that the result is for. - genre (str): The genre of the song that the result is for. 
- listener (str): The name of the listener who submitted the result. - score (float): The combined score for the result. - haaqi_left (float): The HAAQI score for the left channel. - haaqi_right (float): The HAAQI score for the right channel. - """ - - logger.info(f"The combined score for scene {scene}: {score:.4f}") - - with open(self.file_name, "a", encoding="utf-8") as csv_f: - csv_writer = csv.writer( - csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL - ) - csv_writer.writerow( - [ - scene, - song, - genre, - listener, - str(score), - str(haaqi_left), - str(haaqi_right), - ] - ) - - def set_scene_seed(scene: str): """Set a seed that is unique for the given song based on the last 8 characters of the 'md5' @@ -228,10 +153,26 @@ def run_calculate_audio_quality(config: DictConfig) -> None: enhanced_folder = Path("enhanced_signals") logger.info(f"Evaluating from {enhanced_folder} directory") + scores_headers = [ + "scene", + "song", + "genre", + "listener", + "score", + "haaqi_left", + "haaqi_right", + ] + + results_file_name = "scores.csv" + if config.evaluate.batch_size > 1: + results_file_name = ( + f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv" + ) + results_file = ResultsFile( - f"scores_{config.evaluate.batch}-{config.evaluate.batch_size}.csv" + file_name=results_file_name, + header_columns=scores_headers, ) - results_file.write_header() # Initialize acoustic scene model car_scene_acoustic = CarSceneAcoustics( @@ -264,15 +205,14 @@ def run_calculate_audio_quality(config: DictConfig) -> None: # Load enhanced signal enhanced_folder = Path("enhanced_signals") / config.evaluate.split - enhanced_song_id = f"{scene_id}_{listener.id}_{current_scene['song']}" - enhanced_song_path = ( - enhanced_folder / f"{listener.id}" / f"{enhanced_song_id}.wav" + # Read WAV enhanced signal using scipy.io.wavfile + enhanced_signal, enhanced_sample_rate = read_flac_signal( + enhanced_folder + / f"{listener.id}" + / f"{scene_id}_{listener.id}_{current_scene['song']}.flac" ) - # Read WAV enhanced signal using scipy.io.wavfile - enhanced_sample_rate, enhanced_signal = wavfile.read(enhanced_song_path) - enhanced_signal = enhanced_signal / 32768.0 - assert enhanced_sample_rate == config.sample_rate + assert enhanced_sample_rate == config.enhanced_sample_rate # Evaluate scene aq_score_l, aq_score_r = evaluate_scene( @@ -290,13 +230,15 @@ def run_calculate_audio_quality(config: DictConfig) -> None: # Compute combined score and save score = np.mean([aq_score_r, aq_score_l]) results_file.add_result( - scene_id, - current_scene["song"], - current_scene["song_path"].split("/")[-2], - listener.id, - score=float(score), - haaqi_left=aq_score_l, - haaqi_right=aq_score_r, + { + "scene": scene_id, + "song": current_scene["song"], + "genre": current_scene["song_path"].split("/")[-2], + "listener": listener.id, + "score": float(score), + "haaqi_left": aq_score_l, + "haaqi_right": aq_score_r, + } ) diff --git a/recipes/cad1/task2/baseline/test.py b/recipes/cad1/task2/baseline/test.py new file mode 100644 index 000000000..8e5e26638 --- /dev/null +++ b/recipes/cad1/task2/baseline/test.py @@ -0,0 +1,67 @@ +""" Run the dummy enhancement. 
""" +# pylint: disable=too-many-locals +# pylint: disable=import-error +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +import hydra +from omegaconf import DictConfig + +from recipes.cad1.task2.baseline.enhance import enhance as enhance_set + +logger = logging.getLogger(__name__) + + +def pack_submission( + team_id: str, + root_dir: str | Path, + base_dir: str | Path = ".", +) -> None: + """ + Pack the submission files into an archive file. + + Args: + team_id (str): Team ID. + root_dir (str | Path): Root directory of the archived file. + base_dir (str | Path): Base directory to archive. Defaults to ".". + """ + # Pack the submission files + logger.info(f"Packing submission files for team {team_id}...") + shutil.make_archive( + f"submission_{team_id}", + "zip", + root_dir=root_dir, + base_dir=base_dir, + ) + + +@hydra.main(config_path="", config_name="config") +def enhance(config: DictConfig) -> None: + """ + Run the music enhancement. + The baseline system is a dummy processor that returns the input signal. + + Args: + config (dict): Dictionary of configuration options for enhancing music. + """ + enhance_set(config) + + pack_submission( + team_id=config.team_id, + root_dir=Path("enhanced_signals"), + base_dir=config.evaluate.split, + ) + + logger.info("Evaluation complete.!!") + logger.info( + f"Please, submit the file submission_{config.team_id}.zip to the challenge " + "using the link provided. Thank you.!!" + ) + + +# pylint: disable = no-value-for-parameter +if __name__ == "__main__": + enhance() diff --git a/tests/recipes/cad1/task1/baseline/test_enhance_task1.py b/tests/recipes/cad1/task1/baseline/test_enhance_task1.py index b8cf8e8db..7a5f4e592 100644 --- a/tests/recipes/cad1/task1/baseline/test_enhance_task1.py +++ b/tests/recipes/cad1/task1/baseline/test_enhance_task1.py @@ -1,10 +1,10 @@ """Tests for the enhance module""" +# pylint: disable=import-error from pathlib import Path import numpy as np import pytest import torch -from omegaconf import DictConfig from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB from clarity.enhancer.compressor import Compressor @@ -44,16 +44,17 @@ def test_map_to_dict(): @pytest.mark.parametrize( - "separation_model", + "separation_model,normalise", [ - pytest.param("demucs"), - pytest.param("openunmix", marks=pytest.mark.slow), + (pytest.param("demucs"), True), + (pytest.param("openunmix", marks=pytest.mark.slow), True), ], ) -def test_decompose_signal(separation_model): +def test_decompose_signal(separation_model, normalise): """Takes a signal and decomposes it into VDBO sources using the HDEMUCS model""" np.random.seed(123456789) # Load Separation Model + separation_model = separation_model.values[0] if separation_model == "demucs": model = HDEMUCS_HIGH_MUSDB.get_model().double() elif separation_model == "openunmix": @@ -67,27 +68,19 @@ def test_decompose_signal(separation_model): duration = 0.5 signal = np.random.uniform(size=(1, 2, int(sample_rate * duration))) - # config - config = DictConfig( - { - "sample_rate": sample_rate, - "separator": { - "model": "demucs", - "sources": ["drums", "bass", "other", "vocals"], - }, - } - ) # Call the decompose_signal function and check that the output has the expected keys cfs = np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000]) audiogram = Audiogram(levels=np.ones(9), frequencies=cfs) listener = Listener(audiogram, audiogram) output = decompose_signal( - config, - model, - signal, - sample_rate, - device, - listener, + model=model, + 
+        model_sample_rate=sample_rate,
+        signal=signal,
+        signal_sample_rate=sample_rate,
+        device=device,
+        sources_list=["drums", "bass", "other", "vocals"],
+        listener=listener,
+        normalise=normalise,
     )
     expected_results = np.load(
         RESOURCES / f"test_enhance.test_decompose_signal_{separation_model}.npy",
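A note on the parametrisation above: because each pytest.param(...) is wrapped inside a plain tuple, pytest appears to treat the whole tuple as the parameter value, which is why the test unwraps it with separation_model.values[0]; as far as I can tell, the marks=pytest.mark.slow attached to the inner param is not applied in that form. If that behaviour is unintended, the conventional spelling would be along these lines (a sketch, not part of the diff):

import pytest

@pytest.mark.parametrize(
    "separation_model,normalise",
    [
        pytest.param("demucs", True),
        pytest.param("openunmix", True, marks=pytest.mark.slow),
    ],
)
def test_decompose_signal(separation_model, normalise):
    # separation_model is now the plain string, so no .values[0] unwrap is needed
    ...
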
diff --git a/tests/recipes/cad1/task1/baseline/test_evaluate.py b/tests/recipes/cad1/task1/baseline/test_evaluate.py
index 544094eb4..8a3977d69 100644
--- a/tests/recipes/cad1/task1/baseline/test_evaluate.py
+++ b/tests/recipes/cad1/task1/baseline/test_evaluate.py
@@ -8,42 +8,14 @@
 from scipy.io import wavfile
 
 from clarity.utils.audiogram import Audiogram, Listener
+from clarity.utils.flac_encoder import FlacEncoder
 from recipes.cad1.task1.baseline.evaluate import (
-    ResultsFile,
     _evaluate_song_listener,
     make_song_listener_list,
     set_song_seed,
 )
 
 
-def test_results_file(tmp_path):
-    """Test the class ResultsFile"""
-    results_file = tmp_path / "results.csv"
-    result_file = ResultsFile(results_file.as_posix())
-    result_file.write_header()
-    result_file.add_result(
-        listener_id="My listener",
-        song="My favorite song",
-        score=0.9,
-        instruments_scores={
-            "left_bass": 0.8,
-            "right_bass": 0.8,
-            "left_drums": 0.9,
-            "right_drums": 0.9,
-            "left_other": 0.8,
-            "right_other": 0.8,
-            "left_vocals": 0.95,
-            "right_vocals": 0.95,
-        },
-    )
-    with open(results_file, encoding="utf-8") as file:
-        contents = file.read()
-    assert (
-        "My favorite song,My listener,0.9,0.8,0.8,0.9,0.9,0.8,0.8,0.95,0.95"
-        in contents
-    )
-
-
 @pytest.mark.parametrize(
     "song,expected_result",
     [("my favorite song", 83), ("another song", 3)],
@@ -88,10 +60,11 @@ def test_make_song_listener_list():
            "punk_is_not_dead",
            "my_music_listener",
            {
+                "stem_sample_rate": 44100,
+                "sample_rate": 44100,
                "evaluate": {"set_random_seed": True},
                "path": {"music_dir": None},
-                "sample_rate": 44100,
-                "nalr": {"sample_rate": 44100},
+                "nalr": {"fs": 44100},
            },
            "test",
            {
@@ -102,14 +75,14 @@
                }
            },
            {
-                "left_drums": 0.14229422779265366,
-                "right_drums": 0.15044965630960655,
-                "left_bass": 0.1333774836344767,
-                "right_bass": 0.14541827476097585,
-                "left_other": 0.16310480582621734,
-                "right_other": 0.15427835764875864,
-                "left_vocals": 0.12291980372806624,
-                "right_vocals": 0.1368378217706031,
+                "left_drums": 0.14229280292204488,
+                "right_drums": 0.15044867874762802,
+                "left_bass": 0.13337685099485902,
+                "right_bass": 0.14541734646032817,
+                "left_other": 0.16310385596493193,
+                "right_other": 0.1542791489799909,
+                "left_vocals": 0.12291878218281638,
+                "right_vocals": 0.13683790592287856,
            },
        )
    ],
@@ -143,22 +116,22 @@ def test_evaluate_song_listener(
     instruments = ["drums", "bass", "other", "vocals"]
 
     # Create reference and enhanced wav samples
+    flac_encoder = FlacEncoder()
     for lr_instrument in list(expected_results.keys()):
         # enhanced signals are mono
         enh_file = (
             enhanced_folder
             / f"{listener.id}"
             / f"{song}"
-            / f"{listener.id}_{song}_{lr_instrument}.wav"
+            / f"{listener.id}_{song}_{lr_instrument}.flac"
         )
         enh_file.parent.mkdir(exist_ok=True, parents=True)
+        with open(Path(enh_file).with_suffix(".txt"), "w", encoding="utf-8") as file:
+            file.write("1.0")
 
         # Using very short 100 ms signals to speed up the test
-        wavfile.write(
-            enh_file,
-            44100,
-            np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768,
-        )
+        enh_signal = np.random.uniform(-1, 1, 4410).astype(np.float32) * 32768
+        flac_encoder.encode(enh_signal.astype(np.int16), config.sample_rate, enh_file)
 
     for instrument in instruments:
         # reference signals are stereo
@@ -183,7 +156,7 @@ def test_evaluate_song_listener(
     # Combined score
     assert isinstance(combined_score, float)
     assert combined_score == pytest.approx(
-        0.14358505393391977, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
+        0.14358442152193474, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
     )
 
     # Per instrument score
diff --git a/tests/recipes/cad1/task2/baseline/test_baseline_utils.py b/tests/recipes/cad1/task2/baseline/test_baseline_utils.py
index 2ae029d68..7f0d527b6 100644
--- a/tests/recipes/cad1/task2/baseline/test_baseline_utils.py
+++ b/tests/recipes/cad1/task2/baseline/test_baseline_utils.py
@@ -1,6 +1,7 @@
 """Test for baseline_utils.py"""
 from pathlib import Path
 
+# pylint: disable=import-error
 import librosa
 import numpy as np
 import pytest
@@ -35,7 +36,7 @@ def test_load_listeners_and_scenes():
         {
             "path": {
                 "scenes_file": (RESOURCES / "scenes.json").as_posix(),
-                "listeners_train_file": (RESOURCES / "listeners.json").as_posix(),
+                "listeners_file": (RESOURCES / "listeners.json").as_posix(),
                 "scenes_listeners_file": (
                     RESOURCES / "scenes_listeners.json"
                 ).as_posix(),
diff --git a/tests/recipes/cad1/task2/baseline/test_enhance_task2.py b/tests/recipes/cad1/task2/baseline/test_enhance_task2.py
index 27e31d830..aaa340c46 100644
--- a/tests/recipes/cad1/task2/baseline/test_enhance_task2.py
+++ b/tests/recipes/cad1/task2/baseline/test_enhance_task2.py
@@ -1,4 +1,6 @@
 """Test the enhance module."""
+# pylint: disable=import-error
+
 from pathlib import Path
 
 import numpy as np
diff --git a/tests/utils/test_signal_processing.py b/tests/utils/test_signal_processing.py
index 0b7942c99..f0bb22470 100644
--- a/tests/utils/test_signal_processing.py
+++ b/tests/utils/test_signal_processing.py
@@ -9,6 +9,7 @@
     denormalize_signals,
     normalize_signal,
     resample,
+    to_16bit,
 )
 
 
@@ -166,6 +167,41 @@ def test_compute_rms():
     )
 
 
+@pytest.mark.parametrize(
+    "signal,soft_clip,expected_output",
+    [
+        (
+            np.array([0.5, 2.0, -1.5, 0.8]),
+            True,
+            (np.array([0.46211716, 0.96402758, -0.90514825, 0.66403677]), 0),
+        ),
+        (np.array([0.5, 2.0, -1.5, 0.8]), False, (np.array([0.5, 1.0, -1.0, 0.8]), 2)),
+    ],
+)
+def test_clip_signal(signal, soft_clip, expected_output):
+    """Test the clip_signal function"""
+    # Clip using the parametrised soft_clip setting
+    output = clip_signal(signal, soft_clip=soft_clip)
+    assert np.allclose(output[0], expected_output[0])
+    assert output[1] == expected_output[1]
+
+
+@pytest.mark.parametrize(
+    "signal,expected_output",
+    [
+        (np.array([0.5, 0.8, 0.2, 1.0]), np.array([16384, 26214, 6553, 32767])),
+        (np.array([-0.5, -0.8, -0.2, -1.0]), np.array([-16384, -26214, -6553, -32768])),
+        (np.array([0.5, -0.8, 0.2, -1.0]), np.array([16384, -26214, 6553, -32768])),
+    ],
+)
+def test_to_16bit(signal, expected_output):
+    """Test the to_16bit function"""
+    # Scale the float signal to the 16-bit integer range
+    output = to_16bit(signal)
+    print(output)
+    assert np.allclose(output, expected_output)
+
+
 @pytest.mark.parametrize(
     "input_sample_rate, input_shape, output_sample_rate, output_shape",
     [