# Posterior sampling algorithms for RVAE-based speech enhancement

In [12]:
%load_ext autoreload
%autoreload 2

import os
from datetime import datetime
from src.utils import get_logger, EvalMetrics
from src.se.enhancement import enhance
import soundfile as sf
from IPython.display import Audio
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Test signals

In [13]:
# Directory the enhanced signal is written into.
enhance_dir = './data/'

# Input recordings used by this demo.
clean_file = './data/s.wav'  # clean speech signal (reference)
mix_file = './data/x.wav'  # noisy speech signal (to be enhanced)

## Define SE parameters

In [23]:
# --- Model and hardware ---
# Path to the trained RVAE checkpoint.
ckpt_path = "./trained_model/RVAE_final_epoch294.pt"
# Device the enhancement runs on.
# NOTE(review): presumably "cpu" also works on machines without a GPU -- confirm.
device = "cuda"

# --- Posterior sampling algorithm ---
# One of "ldem" (fast & efficient), "vem", "mhem", or "malaem".
algo_type = "ldem"
# Only used by "ldem": one of ["sgld", "psgld", "adam"].
# "sgld" works much better than "adam".
optimizer = "sgld"
# Learning rate of the posterior sampling methods.
lr = 5e-3

# --- EM schedule and noise model ---
# NMF rank.
nmf_rank = 8
# Number of EM iterations.
num_iter = 100
# Number of posterior sampling iterations at each E-step.
num_E_step = 1

# Show the progress while enhancing.
verbose = True

In [24]:
# Init evaluation
eval_metrics = EvalMetrics(metric="all")

# The enhanced signal is written to enhance_dir, tagged with the algorithm name.
recon_file = os.path.join(enhance_dir, f"enh_{algo_type}.wav")

# SE parameters: single source of truth for the enhance() call below, so the
# dictionary and the actual call arguments cannot drift apart.
se_params = {
    "mix_file": mix_file,
    "video_file": [],
    "output_file": recon_file,
    "ckpt_path": ckpt_path,
    "algo_type": algo_type,
    "nmf_rank": nmf_rank,
    "niter": num_iter,
    "nepochs_E_step": num_E_step,
    "verbose": verbose,
    "device": device,
    "optimizer": optimizer,
    "lr": lr,
}

# Run the enhancement algorithm. Per the original note, clean_file is only
# used by enhance() when monitoring performance, so it is left empty here.
x_recon, time_consume = enhance(clean_file="", **se_params)

# Load the reference and the noisy input. Keep the two sample rates separate
# so that a mismatch is detected instead of being silently overwritten.
x_ref, fs_ref = sf.read(clean_file)
x_noisy, fs_x = sf.read(mix_file)
if fs_ref != fs_x:
    raise ValueError(
        f"Sample rate mismatch: {clean_file} is {fs_ref} Hz "
        f"but {mix_file} is {fs_x} Hz"
    )

# Output metrics: enhanced signal against the clean reference.
(
    rmse_out,
    sisdr_out,
    pesq_out,
    pesq_wb_out,
    pesq_nb_out,
    estoi_out,
) = eval_metrics.eval(x_est=x_recon, x_ref=x_ref, fs=fs_x)
# Input metrics: unprocessed noisy signal against the clean reference.
(
    rmse_in,
    sisdr_in,
    pesq_in,
    pesq_wb_in,
    pesq_nb_in,
    estoi_in,
) = eval_metrics.eval(x_est=x_noisy, x_ref=x_ref, fs=fs_x)

# Both report lines share one template; "len" is the reference duration in
# seconds. "time" is the value returned by enhance() and is printed on both
# lines, including the input line.
duration = len(x_ref) / fs_x
log_template = "{} metrics: \t len: {:.4f} rmse: {:.4f}\t sisdr: {:.2f}\t pypesq: {:.2f}\t estoi: {:.2f}\t time: {:.4f}s"

# Input metrics (label padded with a leading space to align with "Output").
log_message = log_template.format(
    " Input",
    duration,
    rmse_in,
    sisdr_in,
    pesq_in,
    estoi_in,
    time_consume,
)
print(log_message)
# Output metrics
log_message = log_template.format(
    "Output",
    duration,
    rmse_out,
    sisdr_out,
    pesq_out,
    estoi_out,
    time_consume,
)

print(log_message)

iter: 1/100 - loss: -0.0464
iter: 11/100 - loss: -0.8137
iter: 21/100 - loss: -0.8449
iter: 31/100 - loss: -0.8575
iter: 41/100 - loss: -0.8664
iter: 51/100 - loss: -0.8734
iter: 61/100 - loss: -0.8787
iter: 71/100 - loss: -0.8815
iter: 81/100 - loss: -0.8848
iter: 91/100 - loss: -0.8810
 Input metrics: 	 len: 4.8000 rmse: 0.0177	 sisdr: -0.54	 pypesq: 1.61	 estoi: 0.36	 time: 1.9699s
Output metrics: 	 len: 4.8000 rmse: 0.0074	 sisdr: 9.86	 pypesq: 2.42	 estoi: 0.47	 time: 1.9699s


## Play the audio signals

### Clean speech signal

In [9]:
# Player widget for the clean reference signal (no autoplay).
Audio(x_ref, rate=fs_x, autoplay=False)

### Input (unprocessed) speech signal

In [10]:
# Player widget for the noisy, unprocessed input signal (no autoplay).
Audio(x_noisy, rate=fs_x, autoplay=False)

### Enhanced speech signal

In [11]:
# Player widget for the enhanced signal returned by enhance() (no autoplay).
Audio(x_recon, rate=fs_x, autoplay=False)