Skip to content

Commit

Permalink
changed default in wpe command line interface, refactored power estim…
Browse files Browse the repository at this point in the history
…ation in tf_wpe
  • Loading branch information
danielhkl committed Jul 13, 2018
1 parent 5377fb6 commit 4e3a010
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/WPE_Tensorflow_online.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
" Q_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels * taps))\n",
" G_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels))\n",
" \n",
" results = online_wpe_step(Y_tf, get_power_online(Y_tf), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n",
" results = online_wpe_step(Y_tf, get_power_online(tf.transpose(Y_tf, (1, 0, 2))), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n",
" for Y_step in tqdm(aquire_framebuffer()):\n",
" feed_dict = {Y_tf: Y_step, Q_tf: Q, G_tf: G}\n",
" Z, Q, G = session.run(results, feed_dict)\n",
Expand Down
93 changes: 78 additions & 15 deletions nara_wpe/tf_wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,43 +57,106 @@ def _slice(x):


def get_power_online(signal):
"""Calculates power over last to frames for `signal`
"""Calculates power for `signal`
Args:
signal (tf.Tensor): Single frequency signal with shape (T, F, D).
signal (tf.Tensor): Signal with shape (F, D, T).
Returns:
tf.Tensor: Inverse power with shape (F,)
tf.Tensor: Power with shape (F,)
"""
power_estimate = tf.real(signal) ** 2 + tf.imag(signal) ** 2
power_estimate += tf.pad(
power_estimate,
((1, 0), (0, 0), (0, 0))
)[:-1, :]
power_estimate /= 2
power_estimate = tf.reduce_mean(power_estimate, axis=(0, -1))
power_estimate = get_power(signal)
power_estimate = tf.reduce_mean(power_estimate, axis=-1)
return power_estimate


def get_power_inverse(signal, channel_axis=0):
def get_power_inverse(signal):
"""Calculates inverse power for `signal`
Args:
signal (tf.Tensor): Single frequency signal with shape (D, T).
channel_axis (int): Axis of the channel dimension. Will be averaged.
psd_context: context for power estimation
Returns:
tf.Tensor: Inverse power with shape (T,)
"""
power = tf.reduce_mean(
tf.real(signal) ** 2 + tf.imag(signal) ** 2, axis=channel_axis)
power = get_power(signal)
eps = 1e-10 * tf.reduce_max(power)
inverse_power = tf.reciprocal(tf.maximum(power, eps))
return inverse_power


def get_power(signal, axis=-2):
"""Calculates power for `signal`
Args:
signal (tf.Tensor): Single frequency signal with shape (D, T) or (F, D, T).
axis: reduce_mean axis
Returns:
tf.Tensor: Power with shape (T,) or (F, T)
"""
power = tf.real(signal) ** 2 + tf.imag(signal) ** 2
power = tf.reduce_mean(power, axis=axis)

return power


#def get_power(signal, psd_context=0):
# """
# Calculates power for single frequency signal.
# In case psd_context is an tuple the two values
# are describing the left and right hand context.
#
# Args:
# signal: (D, T)
# psd_context: tuple or int
# """
# shape = tf.shape(signal)
# if len(signal.get_shape()) == 2:
# signal = tf.reshape(signal, (1, shape[0], shape[1]))
#
# power = tf.reduce_mean(
# tf.real(signal) ** 2 + tf.imag(signal) ** 2,
# axis=-2
# )
#
# if psd_context is not 0:
# if isinstance(psd_context, tuple):
# context = psd_context[0] + 1 + psd_context[1]
# else:
# context = 2 * psd_context + 1
# psd_context = (psd_context, psd_context)
#
# power = tf.pad(
# power,
# ((0, 0), (psd_context[0], psd_context[1])),
# mode='constant'
# )
# print(power)
# power = tf.nn.convolution(
# power,
# tf.ones(context),
# padding='VALID'
# )[psd_context[1]:-psd_context[0]]
#
# denom = tf.nn.convolution(
# tf.zeros_like(power) + 1.,
# tf.ones(context),
# padding='VALID'
# )[psd_context[1]:-psd_context[0]]
# print(power)
# power /= denom
#
# elif psd_context == 0:
# pass
# else:
# raise ValueError(psd_context)
#
# return tf.squeeze(power, axis=0)


def get_correlations(Y, inverse_power, taps, delay):
"""Calculates weighted correlations of a window of length taps
Expand Down
5 changes: 2 additions & 3 deletions nara_wpe/wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def get_power_online(signal):
"""
Args:
signal : Single frequency signal with shape (F, D, T).
signal : Signal with shape (F, D, T).
Returns:
Inverse power with shape (F,)
Expand Down Expand Up @@ -884,8 +884,7 @@ def perform_filter_operation_v5(Y, Y_tilde, filter_matrix):
)
@click.option(
'--file_template',
default='AMI_WSJ20-Array1-{}_T10c0201.wav',
help='Audio example. Full path required'
help='Audio example. Full path required. Included example: AMI_WSJ20-Array1-{}_T10c0201.wav'
)
@click.option(
'--taps_frequency_dependent',
Expand Down

0 comments on commit 4e3a010

Please sign in to comment.