In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import IPython
from librosa.core.audio import resample
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import time
from tqdm import tqdm
import tensorflow as tf

from nara_wpe.tf_wpe import wpe
from nara_wpe.utils import stft, istft, get_stft_center_frequencies
from nara_wpe import project_root

In [None]:
# Shared STFT settings. Later cells unpack this dict as keyword arguments
# into stft(...) and read 'size' / 'shift' from it when calling istft(...).
stft_options = {
    'size': 1024,
    'shift': 256,
    'window_length': None,
    'fading': True,
    'pad': True,
    'symmetric_window': False,
}

# Minimal example with random data

In [None]:
def aquire_audio_data(channels=4, samples=10000, seed=None):
    """Create a random multi-channel signal for the minimal example.

    Args:
        channels: Number of rows (``D``) of the returned array. Defaults to 4.
        samples: Number of columns (``T``) of the returned array.
            Defaults to 10000.
        seed: Optional seed. When given, the samples are drawn from a
            dedicated ``np.random.RandomState(seed)``, so the result is
            reproducible and NumPy's global random state is left untouched.
            When ``None`` (default), the samples come from NumPy's global
            random state.

    Returns:
        ``np.ndarray`` of shape ``(channels, samples)`` holding samples
        drawn with ``normal(size=...)`` (NumPy's default mean and scale).
    """
    shape = (channels, samples)
    if seed is None:
        # Default path: draw from the global generator.
        return np.random.normal(size=shape)
    return np.random.RandomState(seed).normal(size=shape)

In [None]:
# Random multi-channel signal -> STFT. The transpose moves the last STFT axis
# to the front before the array is fed to wpe.
y = aquire_audio_data()
Y = stft(y, **stft_options).transpose(2, 0, 1)

with tf.Session() as session:
    Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))
    Z_tf = wpe(Y_tf)

    # Evaluate the same graph twice and report the wall-clock time of each
    # run (presumably the first run carries one-time setup cost -- compare
    # the two printed numbers to confirm).
    for run in (1, 2):
        start = time.perf_counter()
        Z = session.run(Z_tf, {Y_tf: Y})
        end = time.perf_counter()
        print(f"Time {run}: {end-start}")

# Undo the transpose and invert the STFT. size/shift are taken from
# stft_options so the inverse always matches the forward transform above,
# even after a later cell has changed stft_options in place.
z_tf = istft(
    Z.transpose(1, 2, 0),
    size=stft_options['size'],
    shift=stft_options['shift'],
)

# Example with real audio recordings

### Setup

In [None]:
# Parameters for the example with real recordings.
channels = 8            # number of per-channel WAV files that get loaded
sampling_rate = 16000   # rate handed to IPython.display.Audio for playback

# Switch the shared STFT settings to a shorter size/shift for this example.
# This changes stft_options in place, so every cell executed afterwards
# sees the new values.
stft_options['size'] = 512
stft_options['shift'] = 128

# NOTE(review): delay, iterations and K are not referenced by any cell
# visible in this notebook section; wpe(...) is called without them.
# Confirm whether they were meant to be passed to wpe.
delay = 3
iterations = 5
K = 10

### Audio data

In [None]:
# One WAV file per channel; the placeholder in the template is filled with
# the 1-based channel number.
file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'
data_dir = project_root / 'data'

signal_list = []
for channel_number in range(1, channels + 1):
    wav_path = data_dir / file_template.format(channel_number)
    # sf.read returns a tuple; only its first element (the samples) is kept.
    samples = sf.read(str(wav_path))[0]
    signal_list.append(samples)

# Stack the per-channel signals along a new leading axis and play the first.
y = np.stack(signal_list, axis=0)
IPython.display.Audio(y[0], rate=sampling_rate)

### Result

In [None]:
# Forward STFT of the recorded channels; the last STFT axis is moved to the
# front before the array is handed to wpe.
Y = stft(y, **stft_options).transpose(2, 0, 1)

with tf.Session() as session:
    Y_tf = tf.placeholder(dtype=tf.complex128, shape=(None, None, None))
    Z_tf = wpe(Y_tf)
    Z = session.run(Z_tf, feed_dict={Y_tf: Y})

# Undo the transpose, invert the STFT with the same size/shift as the forward
# transform, and play back the first channel of the result.
z = istft(
    Z.transpose(1, 2, 0),
    size=stft_options['size'],
    shift=stft_options['shift'],
)
IPython.display.Audio(z[0], rate=sampling_rate)

# Power spectral density (PSD)

In [None]:
# Log-magnitude (20 * log10|.|) of the STFT before and after WPE, for index 0
# of the second axis and the 200:400 slice of the third axis.
reverberated_db = 20 * np.log10(np.abs(Y[:, 0, 200:400]))
dereverberated_db = 20 * np.log10(np.abs(Z[:, 0, 200:400]))

# Use one colour range for both panels so they can be compared directly and
# the single colorbar is valid for both. Non-finite values (e.g. log10(0))
# are ignored when picking the range.
finite_db = np.concatenate([
    panel[np.isfinite(panel)]
    for panel in (reverberated_db, dereverberated_db)
])
if finite_db.size:
    vmin, vmax = finite_db.min(), finite_db.max()
else:
    # Nothing finite to scale by: fall back to matplotlib's autoscaling.
    vmin = vmax = None

fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))
im1 = ax1.imshow(reverberated_db, origin='lower', vmin=vmin, vmax=vmax)
im2 = ax2.imshow(dereverberated_db, origin='lower', vmin=vmin, vmax=vmax)
_ = ax1.set_title('reverberated')
_ = ax2.set_title('dereverberated')

for ax in (ax1, ax2):
    # assumes axis 0 of Y/Z is the frequency bin and axis 2 the time frame
    # (nara_wpe's (F, D, T) layout) -- TODO confirm
    ax.set_xlabel('frame index (relative to frame 200)')
    ax.set_ylabel('frequency bin')

# Attach the colorbar to both axes, since both images share its range.
cb = fig.colorbar(im1, ax=[ax1, ax2])
cb.set_label('magnitude [dB]')