In [None]:
from dvae.SE.speech_enhancement import SpeechEnhancement
import os
import torch
import sys
import soundfile as sf

## Test signals

In [None]:
# Demo inputs, relative to the notebook's working directory: the noisy mixture
# to enhance, the clean reference used for scoring, and the video stream (.npy).
mix_file, speech_file, video_file = (
    './data/x.wav',  # noisy speech signal
    './data/s.wav',  # clean speech signal
    './data/v.npy',  # video signal
)

## Define input parameters

In [None]:
verbose = True # show the progress
vae_mode = "AV-VAE" # VAE model. Non-dynamical: A-VAE, AV-VAE. Dynamical: A-DKF, AV-DKF.
algo_type = "peem" # SE algorithm. Choose one of {peem, gpeem} for non-dynamical VAE models, and one of {dpeem, gdpeem} for the dynamical versions.

save_flg = False # forwarded to SpeechEnhancement as save_flg
niter = 100 # number of EM iterations
fs = 16000 # sampling frequency

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SE algorithms each VAE mode supports: non-dynamical models use {peem, gpeem},
# dynamical (DKF) models use {dpeem, gdpeem}.
SUPPORTED_ALGOS = {
    'A-VAE': ('peem', 'gpeem'),
    'AV-VAE': ('peem', 'gpeem'),
    'A-DKF': ('dpeem', 'gdpeem'),
    'AV-DKF': ('dpeem', 'gdpeem'),
}

# Fail fast on a typo. Without this check an unknown mode leaves `saved_model`
# undefined (NameError below) or, on a re-run, silently reuses the model path
# left in the kernel by a previous execution of this cell.
if vae_mode not in SUPPORTED_ALGOS:
    raise ValueError('Unknown vae_mode {!r}; expected one of {}'.format(vae_mode, sorted(SUPPORTED_ALGOS)))
if algo_type not in SUPPORTED_ALGOS[vae_mode]:
    raise ValueError('algo_type {!r} is not valid for {}; expected one of {}'.format(
        algo_type, vae_mode, SUPPORTED_ALGOS[vae_mode]))

# Each mode's checkpoint lives at ./saved_model/<mode>/<mode>.pt next to its config.ini.
saved_model = './saved_model/{0}/{0}.pt'.format(vae_mode)
cfg_file = './saved_model/{0}/config.ini'.format(vae_mode) # not referenced by the cells below; kept for reference

path_model, _ = os.path.split(saved_model)
_, model_name = os.path.split(path_model)

output_dir = "./results"
os.makedirs(output_dir, exist_ok=True)

## Run the speech enhancement (SE) algorithm and evaluate it

In [None]:
# Build the enhancer for the selected checkpoint with the settings from the config cell.
se = SpeechEnhancement(
    saved_model=saved_model,
    output_dir=output_dir,
    nmf_rank=8,  # presumably the rank of the NMF noise model -- TODO confirm in SpeechEnhancement
    niter=niter,
    device=device,
    save_flg=save_flg,
    verbose=verbose,
    demo=True,
)

# Run SE & evaluations: `run` takes the mixture, clean reference and video paths
# plus the algorithm name as a single list.
info = se.run([mix_file, speech_file, video_file, algo_type])

# Each score entry is ordered (SI-SDR, PESQ, STOI).
input_scores = info['input_scores']
print('Input scores - SI-SDR: {} -- PESQ: {} --- STOI: {}'.format(input_scores[0], input_scores[1], input_scores[2]))
output_scores = info['output_scores']
print('Output scores - SI-SDR: {} -- PESQ: {} --- STOI: {}'.format(output_scores[0], output_scores[1], output_scores[2]))

## Save the estimated speech and noise signals

In [None]:
# Results are grouped by model and algorithm: <output_dir>/<vae_mode>/<algo_type>/.
save_dir = os.path.join(output_dir, vae_mode, algo_type)
os.makedirs(save_dir, exist_ok=True)

# Write the estimated speech first, then the estimated noise, both at `fs` Hz.
for wav_name, info_key in (('est_speech.wav', "S_hat_wave"), ('est_noise.wav', "N_hat_wave")):
    sf.write(os.path.join(save_dir, wav_name), info[info_key], fs)