In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
from nara_wpe import project_root
import soundfile as sf
from nara_wpe import wpe
from nara_wpe.utils import stft
from nara_wpe.utils import istft
from tqdm import tqdm
from librosa.core.audio import resample
import IPython
import matplotlib.pyplot as plt

# Minimal example with random data

In [None]:
import numpy as np

def aquire_audio_data(D=4, T=10000, seed=None):
    """Return a random stand-in for a multichannel time-domain recording.

    The (misspelled) name is kept so existing calls keep working.

    Args:
        D: Number of channels (rows). Defaults to 4, as before.
        T: Number of samples per channel (columns). Defaults to 10000.
        seed: Optional integer seed. ``None`` (default) draws from NumPy's
            global random state, exactly like the original implementation;
            an int yields a reproducible array via ``np.random.default_rng``.

    Returns:
        ndarray of shape (D, T) with i.i.d. standard-normal entries.
    """
    # Both the legacy ``np.random`` module and a ``Generator`` expose
    # ``.normal(size=...)``, so one call covers both cases.
    rng = np.random if seed is None else np.random.default_rng(seed)
    return rng.normal(size=(D, T))

In [None]:
from nt.utils.timer import TimerDictEntry

## NumPy

In [None]:
import numpy as np
from nara_wpe.wpe import wpe
from nara_wpe.utils import stft, istft

# Random multichannel test signal -> STFT. stft returns (D, T, F); wpe is applied
# to the frequency-first layout (F, D, T) and the result is moved back for istft.
y = aquire_audio_data()
Y = stft(y).transpose(2, 0, 1)

with TimerDictEntry(style='float') as t:
    Z = wpe(Y)
print(t)

z_np = istft(Z.transpose(1, 2, 0))

## TensorFlow

In [None]:
import tensorflow as tf
from nara_wpe.tf_wpe import wpe
from nara_wpe.utils import stft, istft

# Same input as the NumPy example: stft gives (D, T, F); move frequency first.
y = aquire_audio_data()
Y = stft(y).transpose(2, 0, 1)

# Build the graph in a fresh tf.Graph: otherwise every re-run of this cell keeps
# adding placeholder/WPE ops to TensorFlow's global default graph.
with tf.Graph().as_default(), tf.Session() as session:
    # Fully unspecified shape so the graph accepts any (F, D, T) array.
    # NOTE(review): dtype assumes stft returns complex128 for float64 input — confirm.
    Y_tf = tf.placeholder(
        tf.complex128, shape=(None, None, None))
    Z_tf = wpe(Y_tf)
    # Timed twice: the first run presumably includes one-off setup cost,
    # the second shows steady-state execution time.
    with TimerDictEntry(style='float') as t:
        Z = session.run(Z_tf, {Y_tf: Y})
    print(t)
    with TimerDictEntry(style='float') as t:
        Z = session.run(Z_tf, {Y_tf: Y})
    print(t)
z_tf = istft(Z.transpose(1, 2, 0))

# Example with real audio recordings

In [None]:
import numpy as np
from nara_wpe import project_root
import soundfile as sf
from nara_wpe.wpe import wpe_v8 as wpe
from nara_wpe.utils import stft
from nara_wpe.utils import istft
from tqdm import tqdm
from librosa.core.audio import resample
import IPython
import matplotlib.pyplot as plt

In [None]:
# Recording parameters
channels = 8  # number of AMI array-1 microphone files that are loaded
sampling_rate = 16000  # Hz; used for playback and as the resampling target

# STFT parameters (size and shift, in samples)
stft_size = 512
stft_shift = 128

# WPE parameters, passed as delay=, iterations=, K= to wpe_v8
delay = 3
iterations = 5
K = 10

In [None]:
# Load one file per microphone of AMI array 1 and stack into (channels, samples).
file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'
signal_list = []
for channel_index in range(channels):
    samples, file_rate = sf.read(
        str(project_root / 'data' / file_template.format(channel_index + 1)))
    # Playback and STFT settings assume `sampling_rate`; fail loudly on mismatch
    # instead of silently playing/processing at the wrong rate.
    assert file_rate == sampling_rate, (file_rate, sampling_rate)
    signal_list.append(samples)
y = np.stack(signal_list, axis=0)

# stft yields (D, T, F); wpe (wpe_v8 here) runs on the frequency-first layout
# (F, D, T). Use the configured WPE parameters rather than wpe's defaults.
Y = stft(y, size=stft_size, shift=stft_shift)
Z = wpe(
    Y.transpose(2, 0, 1), K=K, delay=delay, iterations=iterations
).transpose(1, 2, 0)

z = istft(Z, size=stft_size, shift=stft_shift)

In [None]:
# Listen to the first channel of the WPE output `z`.
IPython.display.Audio(data=z[0], rate=sampling_rate)

In [None]:
# Variant: run WPE separately on each frequency bin.
# `wpe` is already nara_wpe.wpe.wpe_v8 (imported as `wpe` in this section), so it
# is called directly — `wpe.wpe_v8(...)` would raise AttributeError.
X = np.copy(Y)
D, T, F = Y.shape
for f in tqdm(range(F)):
    # Y[None, :, :, f] -> (1, D, T): a batch of one bin; [0] drops that axis again.
    X[:, :, f] = wpe(
        Y[None, :, :, f], K=K, delay=delay, iterations=iterations)[0]

x = istft(X, size=stft_size, shift=stft_shift)

In [None]:
# STFT layout check: (channels D, frames T, frequency bins F)
np.shape(Y)

In [None]:
# Alternative pipeline: resample every channel to `sampling_rate` (a no-op when the
# files already use that rate), then run WPE on all frequency bins in one call.
file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'
signal_list = []
for channel_index in range(channels):
    samples, file_rate = sf.read(
        str(project_root / 'data' / file_template.format(channel_index + 1)))
    # Resample from the file's actual rate (not a hard-coded 16000). Keyword
    # arguments are required: librosa >= 0.10 makes orig_sr/target_sr keyword-only.
    signal_list.append(
        resample(samples, orig_sr=file_rate, target_sr=sampling_rate))
y = np.stack(signal_list, axis=0)

# Center frequency in Hz of each one-sided STFT bin: k * sampling_rate / stft_size,
# k = 0 .. stft_size // 2 (replaces the undefined get_stft_center_frequencies).
center_frequencies = np.fft.rfftfreq(stft_size, d=1 / sampling_rate)

Y = stft(y, size=stft_size, shift=stft_shift)
D, T, F = Y.shape

# `wpe` is wpe_v8 in this section; process all bins at once in the (F, D, T)
# layout and transpose back to (D, T, F) for istft.
X = wpe(
    Y.transpose((2, 0, 1)), K=K, delay=delay, iterations=iterations
).transpose(1, 2, 0)

x = istft(X, size=stft_size, shift=stft_shift)

In [None]:
# Time-domain layout: (channels, samples)
np.shape(y)

In [None]:
# Recording length in seconds (samples per channel / sampling rate)
np.size(y, axis=1) / sampling_rate

# One of the input channels

In [None]:
# Listen to the first microphone channel of the observed (reverberant) input.
IPython.display.Audio(data=y[0], rate=sampling_rate)

In [None]:
# Shape of the STFT plotted below: (D, T, F)
np.shape(Y)

In [None]:
# Log-magnitude spectrogram (dB) of the first observed channel, STFT frames 200-399.
# Y is (channels, frames, bins); .T puts frequency on the vertical axis.
fig, ax = plt.subplots(figsize=(20, 8))
image = ax.imshow(20 * np.log10(np.abs(Y[0, 200:400, :])).T, origin='lower')
colorbar = fig.colorbar(image, ax=ax)
colorbar.set_label('Magnitude [dB]')
ax.set(
    title='Observed signal, channel 0',
    xlabel='STFT frame (relative to frame 200)',
    ylabel='Frequency bin',
)
plt.show()

# Dereverberated signal

In [None]:
# Listen to the first channel of the WPE output `z` for comparison with the input.
IPython.display.Audio(data=z[0], rate=sampling_rate)

In [None]:
# Log-magnitude spectrogram (dB) of the first WPE output channel, STFT frames 200-399.
# Z is (channels, frames, bins); .T puts frequency on the vertical axis.
fig, ax = plt.subplots(figsize=(20, 8))
image = ax.imshow(20 * np.log10(np.abs(Z[0, 200:400, :])).T, origin='lower')
colorbar = fig.colorbar(image, ax=ax)
colorbar.set_label('Magnitude [dB]')
ax.set(
    title='Dereverberated signal (WPE), channel 0',
    xlabel='STFT frame (relative to frame 200)',
    ylabel='Frequency bin',
)
plt.show()