In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import IPython
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import time

from nara_wpe.ntt_wpe import ntt_wrapper as wpe
from nara_wpe import project_root
from nara_wpe.utils import stft, istft

# Minimal example with random data

In [None]:
def aquire_audio_data(channels=4, samples=10000, seed=None):
    """Return synthetic multichannel "audio": i.i.d. standard-normal noise.

    Args:
        channels: Number of rows (channels) to generate. Defaults to 4.
        samples: Number of columns (time-domain samples) per channel.
            Defaults to 10000.
        seed: Optional seed. When given, a dedicated
            ``np.random.default_rng(seed)`` is used so the data is
            reproducible. When ``None`` (default), NumPy's global random
            state is used, as in the original implementation.

    Returns:
        np.ndarray of shape ``(channels, samples)``.
    """
    shape = (channels, samples)
    if seed is None:
        return np.random.normal(size=shape)
    return np.random.default_rng(seed).normal(size=shape)

In [None]:
# Time a single WPE call on the synthetic noise signal.
y = aquire_audio_data()

start = time.perf_counter()
x = wpe(y)
end = time.perf_counter()
elapsed = end - start

print(f"Time: {elapsed}")

# Example with real audio recordings

WPE estimates a linear prediction filter that predicts the reverberation tail of the current time frame from K past time frames (`taps`). The most recent of these past frames lies 3 (`delay`) time frames before the current one. This predicted reverberation tail is then subtracted from the observed signal.

### Setup

In [None]:
# Data: how many array channels to load and the rate used for playback.
channels = 8            # one WAV file is read per channel
sampling_rate = 16000   # Hz, passed to the audio players below

# WPE parameters.
taps = 10               # filter length: number of past frames used for prediction
delay = 3               # gap, in frames, between the current frame and the prediction frames
iterations = 5          # number of WPE iterations

### Audio data
Shape: (channels, samples)

In [None]:
# Load one WAV file per array channel (files are numbered from 1).
file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'
signal_list = []
for d in range(channels):
    path = project_root / 'data' / file_template.format(d + 1)
    signal, file_rate = sf.read(str(path))
    # Playback below uses `sampling_rate`; fail loudly instead of playing
    # back at the wrong speed if a file was recorded at a different rate.
    assert file_rate == sampling_rate, (
        f'{path} has sampling rate {file_rate}, expected {sampling_rate}'
    )
    signal_list.append(signal)
# Stack along a new leading axis: (channels, samples), assuming each file is mono.
y = np.stack(signal_list, axis=0)
IPython.display.Audio(y[0], rate=sampling_rate)

### iterative WPE
The `wpe` function is fed with the time-domain signal `y`. The STFT and ISTFT are performed inside NTT's Matlab package, so no explicit STFT is needed here.

In [None]:
# Forward all WPE settings from the Setup cell. Previously only `iterations`
# was passed, so edits to `taps` and `delay` had no effect.
x = wpe(y, taps=taps, delay=delay, iterations=iterations)
IPython.display.Audio(x[0], rate=sampling_rate)

## Power spectrum
Log-magnitude spectrogram (in dB) of the first channel, before and after applying NTT WPE

In [None]:
stft_options = dict(
    size=512,
    shift=128,
    window_length=None,
    fading=True,
    pad=True,
    symmetric_window=False
)
# After the transpose the axes are (frequency bins, channels, frames), so
# [:, 0, 200:400] selects channel 0, frames 200-399.
Y = stft(y, **stft_options).transpose(2, 0, 1)
X = stft(x, **stft_options).transpose(2, 0, 1)

# Floor magnitudes at machine epsilon so exactly-zero bins do not give -inf dB.
eps = np.finfo(float).eps
# Use one dB range for both panels so their colours are directly comparable
# and the single colorbar applies to both images.
imshow_kwargs = dict(origin='lower', vmin=-120, vmax=0)

fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
im1 = ax1.imshow(
    20 * np.log10(np.maximum(np.abs(Y[:, 0, 200:400]), eps)), **imshow_kwargs
)
ax1.set_xlabel('frames')
ax1.set_ylabel('frequency bins')
_ = ax1.set_title('reverberated')
im2 = ax2.imshow(
    20 * np.log10(np.maximum(np.abs(X[:, 0, 200:400]), eps)), **imshow_kwargs
)
ax2.set_xlabel('frames')
ax2.set_ylabel('frequency bins')
_ = ax2.set_title('dereverberated')
cb = fig.colorbar(im2)
cb.set_label('magnitude [dB]')