
Commit

And it's done! MD file needs writing.
dr-costas committed Jan 25, 2018
1 parent fc731bf commit a105130
Showing 6 changed files with 277 additions and 31 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -14,6 +14,8 @@ outputs/metrics/*
outputs/states/*
!outputs/states/.dummy.txt

*.wav

# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
88 changes: 60 additions & 28 deletions helpers/data_feeder.py
@@ -89,7 +89,7 @@ def epoch_it():


def data_feeder_testing(window_size, fft_size, hop_size, seq_length, context_length,
batch_size, debug):
batch_size, debug, sources_list=None):
"""Provides an iterator over the testing examples.
:param window_size: The window size to be used for the time-frequency transformation.
@@ -106,11 +106,17 @@ def data_feeder_testing(window_size, fft_size, hop_size, seq_length, context_len
:type batch_size: int
:param debug: A flag to indicate debugging mode.
:type debug: bool
:param sources_list: The file list provided for using the MaD-TwinNet.
:type sources_list: list[str]
:return: An iterator that will provide the input and target values.\
The iterator yields (mix, mix magnitude, mix phase, voice true, bg true) values.
:rtype: callable
"""
sources_list = _get_files_lists('testing')[-1]
if sources_list is None:
usage_case = False
sources_list = _get_files_lists('testing')[-1]
else:
usage_case = True
hamming_window = hamming(window_size, True)

def testing_it():
@@ -120,7 +126,7 @@ def testing_it():
sources_parent_path=sources_list[index],
window_values=hamming_window, fft_size=fft_size, hop=hop_size,
seq_length=seq_length, context_length=context_length,
batch_size=batch_size
batch_size=batch_size, usage_case=usage_case
)

if debug:
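
A minimal usage sketch of the new sources_list argument (illustrative only, assuming the repository root is on the Python path; the hyper-parameter values and file names below are placeholders, not the project's settings):

from helpers.data_feeder import data_feeder_testing

# Placeholder hyper-parameters and hypothetical WAV paths.
testing_it = data_feeder_testing(
    window_size=2049, fft_size=4096, hop_size=384,
    seq_length=60, context_length=10, batch_size=16,
    debug=False,
    sources_list=['my_song_1.wav', 'my_song_2.wav']  # a non-None list triggers usage_case=True
)

for data in testing_it():
    # Per the docstring, each item is (mix, mix magnitude, mix phase, voice true, bg true);
    # in the usage case the two ground-truth entries come back as None.
    mix, mix_magnitude, mix_phase, voice_true, bg_true = data
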
@@ -130,7 +136,8 @@ def testing_it():


def data_process_results_testing(index, voice_true, bg_true, voice_predicted,
window_size, mix, mix_magnitude, mix_phase, hop, context_length):
window_size, mix, mix_magnitude, mix_phase, hop,
context_length, output_file_name=None):
"""Calculates SDR and SIR and creates the resulting audio files.
:param index: The index of the current source/track.
@@ -153,6 +160,11 @@ def data_process_results_testing(index, voice_true, bg_true, voice_predicted,
:type hop: int
:param context_length: The context length in frames.
:type context_length: int
:param output_file_name: The output file names for the predicted voice\
                         and background music. If this argument is not
                         None, the function only synthesizes and saves the
                         voice and the background music, skipping the metrics.
:type output_file_name: list[str] | None
:return: The values of SDR and SIR for each of the frames in\
the current track, for both voice and background music.
:rtype: (list[numpy.core.multiarray.ndarray], list[numpy.core.multiarray.ndarray])
@@ -164,29 +176,42 @@ def data_process_results_testing(index, voice_true, bg_true, voice_predicted,

# Removing the samples for which no estimation exists
mix = mix[context_length * hop:]
voice_true = voice_true[context_length * hop:]
bg_true = bg_true[context_length * hop:]

min_len = min(len(voice_true), len(voice_hat))
if output_file_name is None:
voice_true = voice_true[context_length * hop:]
bg_true = bg_true[context_length * hop:]
min_len = min(len(voice_true), len(voice_hat))
example_index = index + 1
else:
voice_true = None
bg_true = None
example_index = None
min_len = min(len(mix), len(voice_hat))

# Background music estimation
bg_hat = mix[:min_len] - voice_hat[:min_len]

example_index = index + 1
if output_file_name is None:
voice_hat_path = output_audio_paths['voice_predicted'].format(p=example_index)
bg_hat_path = output_audio_paths['bg_predicted'].format(p=example_index)
wav_write(voice_true, file_name=output_audio_paths['voice_true'].format(p=example_index), **wav_quality)
wav_write(bg_true, file_name=output_audio_paths['bg_true'].format(p=example_index), **wav_quality)
wav_write(mix, file_name=output_audio_paths['mix'].format(p=example_index), **wav_quality)

wav_write(voice_true, file_name=output_audio_paths['voice_true'].format(p=example_index), **wav_quality)
wav_write(voice_hat, file_name=output_audio_paths['voice_predicted'].format(p=example_index), **wav_quality)
# Metrics calculation
sdr, sir = _get_me_sdr_and_sir(bss_eval.bss_eval_images_framewise(
[voice_true[:min_len], bg_true[:min_len]],
[voice_hat[:min_len], bg_hat[:min_len]]
))

wav_write(bg_true, file_name=output_audio_paths['bg_true'].format(p=example_index), **wav_quality)
wav_write(bg_hat, file_name=output_audio_paths['bg_predicted'].format(p=example_index), **wav_quality)
else:
voice_hat_path = output_file_name[0]
bg_hat_path = output_file_name[1]

wav_write(mix, file_name=output_audio_paths['mix'].format(p=example_index), **wav_quality)
sdr = None
sir = None

# Metrics calculation
sdr, sir = _get_me_sdr_and_sir(bss_eval.bss_eval_images_framewise(
[voice_true[:min_len], bg_true[:min_len]],
[voice_hat[:min_len], bg_hat[:min_len]]
))
wav_write(voice_hat, file_name=voice_hat_path, **wav_quality)
wav_write(bg_hat, file_name=bg_hat_path, **wav_quality)

return sdr, sir
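
A hedged sketch of the new save-only path (voice_predicted, mix, mix_magnitude, and mix_phase are assumed to come from the MaD TwinNet forward pass and the data feeder; the window_size, hop, and context_length values are placeholders):

sdr, sir = data_process_results_testing(
    index=0,
    voice_true=None, bg_true=None,          # no ground truth in the usage case
    voice_predicted=voice_predicted,
    window_size=2049, mix=mix, mix_magnitude=mix_magnitude,
    mix_phase=mix_phase, hop=384, context_length=10,
    output_file_name=['my_song_voice.wav', 'my_song_bg.wav']  # hypothetical output paths
)
# With output_file_name given, the predicted voice and the residual background
# (mix minus the synthesized voice) are written to the two paths, and sdr and
# sir are returned as None.
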

@@ -347,7 +372,7 @@ def _get_data_training(current_set, set_size, mixtures_list, sources_list,


def _get_data_testing(sources_parent_path, window_values, fft_size, hop,
seq_length, context_length, batch_size):
seq_length, context_length, batch_size, usage_case):
"""Gets the actual input and output data for testing.
:param sources_parent_path: The parent path of the sources
@@ -364,17 +389,24 @@ def _get_data_testing(sources_parent_path, window_values, fft_size, hop,
:type context_length: int
:param batch_size: The batch size.
:type batch_size: int
:param usage_case: A flag to indicate the usage case, i.e. processing a single mixture file instead of the testing set sources.
:type usage_case: bool
:return: The actual input and target value.
:rtype: numpy.core.multiarray.ndarray
"""
bass = wav_read(os.path.join(sources_parent_path, 'bass.wav'), mono=False)[0]
drums = wav_read(os.path.join(sources_parent_path, 'drums.wav'), mono=False)[0]
others = wav_read(os.path.join(sources_parent_path, 'other.wav'), mono=False)[0]
voice = wav_read(os.path.join(sources_parent_path, 'vocals.wav'), mono=False)[0]

bg_true = np.sum(bass + drums + others, axis=-1) * 0.5
mix = np.sum(bass + drums + others + voice, axis=-1) * 0.5
voice_true = np.sum(voice, axis=-1) * 0.5
if not usage_case:
bass = wav_read(os.path.join(sources_parent_path, 'bass.wav'), mono=False)[0]
drums = wav_read(os.path.join(sources_parent_path, 'drums.wav'), mono=False)[0]
others = wav_read(os.path.join(sources_parent_path, 'other.wav'), mono=False)[0]
voice = wav_read(os.path.join(sources_parent_path, 'vocals.wav'), mono=False)[0]

bg_true = np.sum(bass + drums + others, axis=-1) * 0.5
voice_true = np.sum(voice, axis=-1) * 0.5
mix = np.sum(bass + drums + others + voice, axis=-1) * 0.5
else:
mix = wav_read(sources_parent_path, mono=True)[0]
voice_true = None
bg_true = None

mix_magnitude, mix_phase = stft(mix, window_values, fft_size, hop)
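
For reference, a small self-contained sketch of the downmix convention used in the testing branch above, with toy arrays standing in for the stereo stems returned by wav_read (in the usage-case branch the given file is instead read directly as a mono mixture and both targets stay None):

import numpy as np

# Toy stereo stems, shaped (samples, channels), standing in for bass/drums/other/vocals.
rng = np.random.RandomState(0)
bass, drums, others, voice = rng.randn(4, 8, 2)

# Stems are summed across channels and scaled by 0.5, as in the lines above.
bg_true = np.sum(bass + drums + others, axis=-1) * 0.5
voice_true = np.sum(voice, axis=-1) * 0.5
mix = np.sum(bass + drums + others + voice, axis=-1) * 0.5

# By linearity of the sum, the mixture equals voice plus background.
assert np.allclose(mix, bg_true + voice_true)
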

12 changes: 9 additions & 3 deletions helpers/settings.py
@@ -20,9 +20,12 @@
'testing_output_string_all',
'training_constants',
'wav_quality',
'hyper_parameters'
'hyper_parameters',
'usage_output_string_per_example',
'usage_output_string_total'
]


debug = False
_debug_suffix = '_debug' if debug else ''

@@ -62,8 +65,8 @@
}

metrics_paths = {
'sdr': os.path.join(_metrics_path, 'sdr{}.pckl'.format(_debug_suffix)),
'sir': os.path.join(_metrics_path, 'sir{}.pckl'.format(_debug_suffix))
'sdr': os.path.join(_metrics_path, 'sdr{}_p2.pckl'.format(_debug_suffix)),
'sir': os.path.join(_metrics_path, 'sir{}_p2.pckl'.format(_debug_suffix))
}

output_states_path = {
@@ -87,6 +90,9 @@
'Median SIR:{sir:6.2f} dB | ' \
'Total time:{t:6.2f} sec(s)'

usage_output_string_per_example = '-- File {f} processed. Time: {t:6.2f} sec(s)'
usage_output_string_total = '-- All files processed. Total time: {t:6.2f} sec(s)'
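
A quick sketch of how the two templates above are meant to be filled in (the file name and timings are placeholders, assuming the repository root is on the Python path):

from helpers.settings import usage_output_string_per_example, usage_output_string_total

print(usage_output_string_per_example.format(f='my_song.wav', t=12.34))
# -- File my_song.wav processed. Time:  12.34 sec(s)
print(usage_output_string_total.format(t=25.01))
# -- All files processed. Total time:  25.01 sec(s)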

# Process constants
training_constants = {
'epochs': 2 if debug else 100,
2 changes: 2 additions & 0 deletions scripts/testing.py
@@ -4,6 +4,8 @@
"""Testing process module.
"""

from __future__ import print_function

import pickle
import time

2 changes: 2 additions & 0 deletions scripts/training.py
@@ -4,6 +4,8 @@
"""Training process module.
"""

from __future__ import print_function

import time

import torch
