Skip to content

Commit

Permalink
fix wpe_v8 for ndim > 2. Add testcases for multi freq and batched mul…
Browse files Browse the repository at this point in the history
…ti freq
  • Loading branch information
boeddeker committed Oct 9, 2018
1 parent c29ff3d commit d75f4b6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nara_wpe/wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ def wpe_v8(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu
"""
v8 is faster than v7 and offers an optional batch mode.
"""
if Y.ndim == 2:
ndim = Y.ndim
if ndim == 2:
return wpe_v6(
Y,
taps=taps,
Expand Down
66 changes: 66 additions & 0 deletions tests/test_wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,69 @@ def test_wpe_v6_vs_v7(self):
actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
tc.assert_raises(AssertionError, tc.assert_array_equal, desired, actual)

@retry(5)
def test_wpe_v8(self):
desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid')
actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='valid')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='valid')
actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='valid')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

@retry(5)
def test_wpe_multi_freq(self):
desired = wpe.wpe_v0(self.Y, self.K, self.delay, statistics_mode='full')
desired = [desired, desired]
actual = wpe.wpe_v0(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
desired = [desired, desired]
actual = wpe.wpe_v7(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
desired = [desired, desired]
actual = wpe.wpe_v6(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
desired = [desired, desired]
actual = wpe.wpe_v8(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

@retry(5)
def test_wpe_batched_multi_freq(self):
def to_batched_multi_freq(x):
return np.array([
[x, x*2],
[x*3, x*4],
[x*5, x*6],
])
Y_batched_multi_freq = to_batched_multi_freq(self.Y)

tc.assert_raises(NotImplementedError, wpe.wpe_v0, Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')

desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
desired = to_batched_multi_freq(desired)
actual = wpe.wpe_v7(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
desired = to_batched_multi_freq(desired)
actual = wpe.wpe_v6(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

desired = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
desired = to_batched_multi_freq(desired)
actual = wpe.wpe_v8(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')
tc.assert_allclose(actual, desired, atol=1e-10)

0 comments on commit d75f4b6

Please sign in to comment.