From 4e3a0100de8e7809fe4f62004589ac0820c99195 Mon Sep 17 00:00:00 2001
From: danielha
Date: Fri, 13 Jul 2018 16:44:07 +0200
Subject: [PATCH] changed default in wpe command line interface, refactored
 power estimation in tf_wpe

---
 examples/WPE_Tensorflow_online.ipynb |  2 +-
 nara_wpe/tf_wpe.py                   | 93 +++++++++++++++++++++++-----
 nara_wpe/wpe.py                      |  5 +-
 3 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/examples/WPE_Tensorflow_online.ipynb b/examples/WPE_Tensorflow_online.ipynb
index 20f8321..94e3b96 100644
--- a/examples/WPE_Tensorflow_online.ipynb
+++ b/examples/WPE_Tensorflow_online.ipynb
@@ -146,7 +146,7 @@
     "    Q_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels * taps))\n",
     "    G_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels))\n",
     "    \n",
-    "    results = online_wpe_step(Y_tf, get_power_online(Y_tf), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n",
+    "    results = online_wpe_step(Y_tf, get_power_online(tf.transpose(Y_tf, (1, 0, 2))), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n",
     "    for Y_step in tqdm(aquire_framebuffer()):\n",
     "        feed_dict = {Y_tf: Y_step, Q_tf: Q, G_tf: G}\n",
     "        Z, Q, G = session.run(results, feed_dict)\n",
diff --git a/nara_wpe/tf_wpe.py b/nara_wpe/tf_wpe.py
index ce3990a..44f4324 100644
--- a/nara_wpe/tf_wpe.py
+++ b/nara_wpe/tf_wpe.py
@@ -57,43 +57,106 @@ def _slice(x):
 
 
 def get_power_online(signal):
-    """Calculates power over last to frames for `signal`
+    """Calculates power for `signal`
 
     Args:
-        signal (tf.Tensor): Single frequency signal with shape (T, F, D).
+        signal (tf.Tensor): Signal with shape (F, D, T).
 
     Returns:
-        tf.Tensor: Inverse power with shape (F,)
+        tf.Tensor: Power with shape (F,)
 
     """
-    power_estimate = tf.real(signal) ** 2 + tf.imag(signal) ** 2
-    power_estimate += tf.pad(
-        power_estimate,
-        ((1, 0), (0, 0), (0, 0))
-    )[:-1, :]
-    power_estimate /= 2
-    power_estimate = tf.reduce_mean(power_estimate, axis=(0, -1))
+    power_estimate = get_power(signal)
+    power_estimate = tf.reduce_mean(power_estimate, axis=-1)
     return power_estimate
 
 
-def get_power_inverse(signal, channel_axis=0):
+def get_power_inverse(signal):
     """Calculates inverse power for `signal`
 
     Args:
         signal (tf.Tensor): Single frequency signal with shape (D, T).
-        channel_axis (int): Axis of the channel dimension. Will be averaged.
-
+
     Returns:
         tf.Tensor: Inverse power with shape (T,)
 
     """
-    power = tf.reduce_mean(
-        tf.real(signal) ** 2 + tf.imag(signal) ** 2, axis=channel_axis)
+    power = get_power(signal)
    eps = 1e-10 * tf.reduce_max(power)
     inverse_power = tf.reciprocal(tf.maximum(power, eps))
     return inverse_power
 
 
+def get_power(signal, axis=-2):
+    """Calculates power for `signal`
+
+    Args:
+        signal (tf.Tensor): Single frequency signal with shape (D, T) or (F, D, T).
+        axis (int): Axis of the channel dimension, which is averaged out.
+    Returns:
+        tf.Tensor: Power with shape (T,) or (F, T)
+
+    """
+    power = tf.real(signal) ** 2 + tf.imag(signal) ** 2
+    power = tf.reduce_mean(power, axis=axis)
+
+    return power
+
+
+#def get_power(signal, psd_context=0):
+#    """
+#    Calculates power for single frequency signal.
+#    In case psd_context is a tuple, the two values
+#    describe the left and right hand context.
+#
+#    Args:
+#    signal: (D, T)
+#    psd_context: tuple or int
+#    """
+#    shape = tf.shape(signal)
+#    if len(signal.get_shape()) == 2:
+#        signal = tf.reshape(signal, (1, shape[0], shape[1]))
+#
+#    power = tf.reduce_mean(
+#        tf.real(signal) ** 2 + tf.imag(signal) ** 2,
+#        axis=-2
+#    )
+#
+#    if psd_context is not 0:
+#        if isinstance(psd_context, tuple):
+#            context = psd_context[0] + 1 + psd_context[1]
+#        else:
+#            context = 2 * psd_context + 1
+#            psd_context = (psd_context, psd_context)
+#
+#        power = tf.pad(
+#            power,
+#            ((0, 0), (psd_context[0], psd_context[1])),
+#            mode='constant'
+#        )
+#        print(power)
+#        power = tf.nn.convolution(
+#            power,
+#            tf.ones(context),
+#            padding='VALID'
+#        )[psd_context[1]:-psd_context[0]]
+#
+#        denom = tf.nn.convolution(
+#            tf.zeros_like(power) + 1.,
+#            tf.ones(context),
+#            padding='VALID'
+#        )[psd_context[1]:-psd_context[0]]
+#        print(power)
+#        power /= denom
+#
+#    elif psd_context == 0:
+#        pass
+#    else:
+#        raise ValueError(psd_context)
+#
+#    return tf.squeeze(power, axis=0)
+
+
 def get_correlations(Y, inverse_power, taps, delay):
     """Calculates weighted correlations of a window of length taps
 
diff --git a/nara_wpe/wpe.py b/nara_wpe/wpe.py
index 900a4bd..1e3f43b 100644
--- a/nara_wpe/wpe.py
+++ b/nara_wpe/wpe.py
@@ -489,7 +489,7 @@ def get_power_online(signal):
     """
 
     Args:
-        signal : Single frequency signal with shape (F, D, T).
+        signal : Signal with shape (F, D, T).
 
     Returns:
         Inverse power with shape (F,)
@@ -884,8 +884,7 @@ def perform_filter_operation_v5(Y, Y_tilde, filter_matrix):
 )
 @click.option(
     '--file_template',
-    default='AMI_WSJ20-Array1-{}_T10c0201.wav',
-    help='Audio example. Full path required'
+    help='Audio example. Full path required. Included example: AMI_WSJ20-Array1-{}_T10c0201.wav'
 )
 @click.option(
     '--taps_frequency_dependent',
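
Note (illustration, not part of the diff): the NumPy sketch below mirrors what the
refactored power helpers in nara_wpe/tf_wpe.py compute after this patch, using the
shapes from the docstrings: (D, T) for a single frequency bin or (F, D, T) for a
full STFT. The file name and the __main__ block are hypothetical and serve only as
a usage example.

# power_helpers_sketch.py -- standalone illustration, not imported by nara_wpe
import numpy as np


def get_power(signal, axis=-2):
    # Mean squared magnitude over the channel axis D.
    return np.mean(signal.real ** 2 + signal.imag ** 2, axis=axis)


def get_power_inverse(signal):
    # Per-frame power of a (D, T) block, floored at a fraction of its maximum
    # before inversion to avoid division by zero.
    power = get_power(signal)              # shape (T,)
    eps = 1e-10 * np.max(power)
    return 1.0 / np.maximum(power, eps)    # shape (T,)


def get_power_online(signal):
    # Per-frequency power of an (F, D, T) block: average over channels, then over time.
    return np.mean(get_power(signal), axis=-1)   # shape (F,)


if __name__ == '__main__':
    rng = np.random.RandomState(0)
    F, D, T = 513, 8, 100
    Y = rng.randn(F, D, T) + 1j * rng.randn(F, D, T)
    print(get_power(Y).shape)             # (F, T)
    print(get_power_online(Y).shape)      # (F,)
    print(get_power_inverse(Y[0]).shape)  # (T,)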