In [1]:
colab = False  # presumably True when running on Google Colab; when False, the device cell pins CUDA_VISIBLE_DEVICES to one GPU
import librosa
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pysepm
import IPython.display as ipd
import pandas as pd
import torch
from models import DWaveNet
import os

Measures:
* PESQ &rarr; Perceptual evaluation of speech quality, using the wide-band version recommended in ITU-T P.862.2 (from -0.5 to 4.5).
* STOI % &rarr; Short-Time Objective Intelligibility (from 0 to 100). Listed for reference only; it is not computed in the cells below.
* CSIG &rarr; Mean opinion score (MOS) prediction of the signal distortion attending only to the speech signal (from 1 to 5).
* CBAK &rarr; MOS prediction of the intrusiveness of background noise (from 1 to 5).
* COVL &rarr; MOS prediction of the overall effect (from 1 to 5).
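* SSNR &rarr; Segmental SNR, in dB (higher is better). Note: the cells below actually compute the frequency-weighted segmental SNR (`pysepm.fwSNRseg`) and report it under the `ssnr` key.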

In [2]:
# Build the DWaveNet architecture to be evaluated. These hyperparameters must
# match the ones the checkpoint loaded later was trained with: load_state_dict
# is called without strict=False, so any mismatch in parameter names or shapes
# raises an error instead of loading partially.
model = DWaveNet(
    in_channels=1,
    target_field_length=None,
    num_layers=30,
    num_stacks=3,
    residual_channels=128,
    gate_channels=128,
    skip_out_channels=128,
    last_channels=(2048, 256),
)

In [3]:
# Select the compute device. Outside Colab, expose only one GPU of the shared
# machine to this process. CUDA_VISIBLE_DEVICES is only honoured if it is set
# before CUDA is initialised, so this cell must run before anything is moved
# to the GPU.
GPU_ID = '7'  # physical GPU index to expose when not running on Colab
if not colab:
    os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Last top-level expression of the cell, so Jupyter displays the chosen device.
device

In [4]:
# Trained weights evaluated in the rest of the notebook (set to None to skip loading).
pretrain_model_pth = "../log/fine_tuning/wavenetLABVoxCeleb/model_best.pth.tar"


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    The checkpoint keys may carry a 'module.' prefix (presumably because the model
    was wrapped in nn.DataParallel / DistributedDataParallel during training).
    Only the leading occurrence is removed; ``str.replace`` would also mangle any
    key containing 'module.' further in (e.g. 'submodule.weight' -> 'subweight').
    """
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


if pretrain_model_pth is not None:
    if not os.path.isfile(pretrain_model_pth):
        # Fail loudly: continuing would evaluate a randomly initialised model and
        # report its scores as if they belonged to this checkpoint.
        raise FileNotFoundError("===> no checkpoint found at '{}'".format(pretrain_model_pth))
    print('loading pre-trained model from %s' % pretrain_model_pth)
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(pretrain_model_pth, map_location=lambda storage, loc: storage)  # load for cpu
    model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))

loading pre-trained model from ../log/fine_tuning/wavenetLABVoxCeleb/model_best.pth.tar


In [5]:
# Score the model's denoised output against the clean reference for every
# (clean_path, noisy_path) pair listed in the test .scp file, then print the
# mean of each metric.
# NOTE: the 'ssnr' entry holds pysepm.fwSNRseg, i.e. the frequency-weighted
# segmental SNR in dB, not the plain segmental SNR.
TEST_SCP = '/opt/kaldi/egs/Signal-denoising-in-the-wild/data/denoising_test.scp'
SAMPLE_RATE = 16000  # every file is resampled to this rate on load

model = model.to(device)
pesq = []
sig = []
bak = []
ovl = []
ssnr = []
results = {}
model.eval()
with torch.no_grad(), open(TEST_SCP) as scp:
    count = 0
    for line in scp:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline) in the .scp
        clean_path, noisy_path = line.split()
        clean, fs = librosa.load(clean_path, sr=SAMPLE_RATE)
        noisy, fs = librosa.load(noisy_path, sr=SAMPLE_RATE)
        # Add two leading singleton dims -> (1, 1, T); presumably (batch, channel,
        # time) given in_channels=1. Squeeze them off again after the forward pass.
        noisy = torch.from_numpy(noisy).unsqueeze(0).unsqueeze(0).to(device)
        denoised = model(noisy).cpu().squeeze(0).squeeze(0).numpy()
        pesq.append(pysepm.pesq(clean, denoised, fs)[1])
        composite_scores = pysepm.composite(clean, denoised, fs)  # (CSIG, CBAK, COVL)
        sig.append(composite_scores[0])
        bak.append(composite_scores[1])
        ovl.append(composite_scores[2])
        ssnr.append(pysepm.fwSNRseg(clean, denoised, fs))
        count += 1
        if count % 200 == 0:
            print("{} done..".format(count))

results['pesq'] = np.mean(pesq)
results['sig'] = np.mean(sig)
results['bak'] = np.mean(bak)
results['ovl'] = np.mean(ovl)
results['ssnr'] = np.mean(ssnr)
print(results)

200 done..
400 done..
600 done..
800 done..
{'pesq': 2.082445452803547, 'sig': 3.165552952229102, 'bak': 2.7924844344255333, 'ovl': 2.5917959924147356, 'ssnr': 9.464051444553299}


wavenet4: 'pesq': 1.37942910295667, 'sig': 2.34887768774676, 'bak': 2.2893793356515033, 'ovl': 1.8110202581679395, 'ssnr': 7.650913123112853

wavenet4_fine_tuned: 'pesq': 1.185409572952002, 'sig': 1.5681731302899924, 'bak': 2.012994033959743, 'ovl': 1.3119164729902952, 'ssnr': 5.413532761913475

wavenetLAB: 'pesq': 1.9355006084858792, 'sig': 2.8321720834561392, 'bak': 2.7752958417860873, 'ovl': 2.359445965705248, 'ssnr': 11.05712180000405

wavenetLAB_fine_tuned: 'pesq': 1.4400822405965583, 'sig': 2.1563515918133485, 'bak': 2.3556137490605633, 'ovl': 1.7505921768273165, 'ssnr': 8.666557766188415

wavenetLABVoxCeleb: 'pesq': 2.1357708132093394, 'sig': 3.2044454869361156, 'bak': 2.8171929147604873, 'ovl': 2.6386285444938586, 'ssnr': 9.671704434413316

wavenetLABVoxCeleb_fine_tuned: 'pesq': 2.082445452803547, 'sig': 3.165552952229102, 'bak': 2.7924844344255333, 'ovl': 2.5917959924147356, 'ssnr': 9.464051444553299

In [2]:
# Baseline: score the unprocessed noisy input itself against the clean
# reference, using the same metrics and test list as the model evaluation,
# then print the mean of each metric.
# NOTE: the 'ssnr' entry holds pysepm.fwSNRseg (frequency-weighted segmental SNR, dB).
TEST_SCP = '/opt/kaldi/egs/Signal-denoising-in-the-wild/data/denoising_test.scp'
SAMPLE_RATE = 16000  # every file is resampled to this rate on load

pesq = []
sig = []
bak = []
ovl = []
ssnr = []
results = {}

count = 0
with open(TEST_SCP) as scp:
    for line in scp:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline) in the .scp
        clean_path, noisy_path = line.split()
        clean, fs = librosa.load(clean_path, sr=SAMPLE_RATE)
        noisy, fs = librosa.load(noisy_path, sr=SAMPLE_RATE)
        pesq.append(pysepm.pesq(clean, noisy, fs)[1])
        composite_scores = pysepm.composite(clean, noisy, fs)  # (CSIG, CBAK, COVL)
        sig.append(composite_scores[0])
        bak.append(composite_scores[1])
        ovl.append(composite_scores[2])
        ssnr.append(pysepm.fwSNRseg(clean, noisy, fs))
        count += 1
        if count % 200 == 0:
            print("{} done..".format(count))

results['pesq'] = np.mean(pesq)
results['sig'] = np.mean(sig)
results['bak'] = np.mean(bak)
results['ovl'] = np.mean(ovl)
results['ssnr'] = np.mean(ssnr)
print(results)

200 done..
400 done..
600 done..
800 done..
{'pesq': 1.9707932136591197, 'sig': 3.340068088685546, 'bak': 2.442452392195827, 'ovl': 2.6286600366661497, 'ssnr': 10.9741821476963}
