In [None]:
import torch
import numpy as np
import pandas as pd
import librosa as lr
import soundfile as sf
import time

from librosa import display as lrd
from concurrent.futures import ProcessPoolExecutor
from collections import deque
from functools import partial
from multiprocessing import Process, Queue, cpu_count
from queue import Empty

from torch.utils.data import DataLoader, ConcatDataset, random_split
from asteroid.data import TimitDataset
from asteroid.data.utils import CachedWavSet, FixedMixtureSet
from tqdm import trange, tqdm

from torch import optim
from pytorch_lightning import Trainer, loggers as pl_loggers
from asteroid_filterbanks.transforms import mag

from asteroid import DCUNet, DCCRNet
from asteroid.utils.notebook_utils import show_wav

from asteroid.metrics import get_metrics

%load_ext autoreload
%autoreload 2

In [2]:
# Dataset locations, relative to the notebook's working directory.
TIMIT_DIR_8kHZ = '../../datasets/TIMIT_8kHz'
TEST_NOISE_DIR = '../../datasets/noises-test-drones'

# Sample rate handed to the dataset constructors and the metric helpers.
SAMPLE_RATE = 8000

# SNR grid used to build the test mixtures: -30 .. 0 in steps of 5.
TEST_SNRS = list(range(-30, 1, 5))

# Seed passed to FixedMixtureSet so the test mixtures are reproducible.
SEED = 42

In [3]:
# Clean speech: the 'test' subset of the TIMIT copy under TIMIT_DIR_8kHZ.
timit_test_clean = TimitDataset(
    TIMIT_DIR_8kHZ,
    subset='test',
    sample_rate=SAMPLE_RATE,
    with_path=False,
)

# Noise recordings used for the test mixtures, loaded up front (precache=True).
noises_test = CachedWavSet(
    TEST_NOISE_DIR,
    sample_rate=SAMPLE_RATE,
    precache=True,
)

Precaching audio: 100%|██████████| 2/2 [00:00<00:00, 276.94it/s]


In [4]:
# Noisy test set: clean utterances mixed with the test noises at TEST_SNRS.
# with_snr=True makes each item carry its SNR; the evaluation loops below
# unpack items as (mix, clean, snr).
timit_test = FixedMixtureSet(
    timit_test_clean,
    noises_test,
    snrs=TEST_SNRS,
    random_seed=SEED,
    with_snr=True,
)

In [5]:
def _eval(batch, metrics, including='output', sample_rate=8000):
    """Score one ``(mix, clean, estimate, snr)`` item with ``get_metrics``.

    Args:
        batch: 4-tuple ``(mix, clean, estimate, snr)``. The first three items
            are forwarded to ``get_metrics`` unchanged; ``snr`` must be
            indexable and its first element is recorded with the scores.
        metrics: metric names, forwarded as ``metrics_list``.
        including: forwarded to ``get_metrics`` as ``including``.
        sample_rate: forwarded to ``get_metrics`` as ``sample_rate``.

    Returns:
        The mapping returned by ``get_metrics`` with an additional ``'snr'``
        entry.
    """
    mix, clean, estimate, snr = batch
    # Use a separate name so the `metrics` argument is not rebound to the result.
    scores = get_metrics(mix, clean, estimate, sample_rate=sample_rate,
                         metrics_list=metrics, including=including)

    snr_value = snr[0]
    # When the batch came through a DataLoader, `snr[0]` is a tensor element.
    # Tensors hash by identity, so equal SNRs would land in different groups
    # of a later `groupby('snr')`. Store a plain Python number instead.
    if hasattr(snr_value, 'item'):
        snr_value = snr_value.item()
    scores['snr'] = snr_value
    return scores
        
def model_eval_iterator(model, test_set, num_workers=4):
    """Yield ``(mix, clean, enhanced, snr)`` for every item of ``test_set``.

    Items are read through a ``DataLoader`` with default batching and are
    expected to unpack as ``(mix, clean, snr)``.

    Args:
        model: callable module applied to ``mix`` on the GPU, or ``None`` to
            pass the mixture through unchanged as the "enhanced" signal.
        test_set: dataset to iterate over.
        num_workers: number of ``DataLoader`` worker processes.

    Yields:
        ``(mix, clean, enhanced)`` as NumPy arrays plus ``snr`` exactly as the
        loader produced it.
    """
    loader = DataLoader(test_set, num_workers=num_workers)

    was_training = False
    if model is not None:
        model = model.cuda()
        # Inference must run in eval mode (dropout off, normalisation layers
        # using running statistics). The previous mode is restored at the end.
        was_training = model.training
        model.eval()

    def model_eval(mix):
        if model is None:
            return mix
        # no_grad is entered here rather than around the `yield` below, so the
        # disabled-grad state never leaks into the consumer of this generator.
        with torch.no_grad():
            return model(mix.cuda()).detach().cpu()

    try:
        for mix, clean, snr in tqdm(loader, 'Loading/enhancing data'):
            enh = model_eval(mix)
            yield mix.numpy(), clean.numpy(), enh.numpy(), snr
    finally:
        if model is not None:
            model.train(was_training)
            

def evaluate_model(model, test_set, num_workers=4, metrics=['pesq', 'stoi', 'si_sdr'],
                   sample_rate=8000):
    """Enhance ``test_set`` with ``model`` and score every item in a process pool.

    NOTE: a later cell of this notebook defines ``evaluate_model`` again
    (queue-based); once that cell has run it shadows this definition.

    Args:
        model: model passed to ``model_eval_iterator``; ``None`` scores the
            unprocessed mixtures (``including='input'``).
        test_set: dataset to evaluate.
        num_workers: used both for the ``DataLoader`` workers and for the size
            of the scoring ``ProcessPoolExecutor``.
        metrics: metric names forwarded to ``_eval``.
        sample_rate: forwarded to ``_eval``.

    Returns:
        ``pandas.DataFrame`` with one row per evaluated item. The columns
        ``['snr'] + metrics`` come first; any further keys returned by
        ``_eval`` are kept as extra columns.
    """
    columns = ['snr'] + list(metrics)
    ds_len = len(test_set)
    including = 'input' if model is None else 'output'
    eval_iter = model_eval_iterator(model, test_set, num_workers=num_workers)
    eval_func = partial(_eval, metrics=metrics, including=including, sample_rate=sample_rate)

    # `DataFrame.append` returns a new frame (and no longer exists in pandas 2),
    # so the original loop threw every result away. Collect the records and
    # build the frame once instead.
    with ProcessPoolExecutor(num_workers) as pool:
        records = list(tqdm(pool.map(eval_func, eval_iter),
                            'Evaluating and calculating scores', total=ds_len))

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records)
    # Requested columns first, then whatever else the scorer returned.
    ordered = columns + [c for c in df.columns if c not in columns]
    return df.reindex(columns=ordered)


def evaluate_input(*args, **kwargs):
    """Call ``evaluate_model`` with no model.

    Every positional and keyword argument is forwarded unchanged after a
    leading ``None`` model argument.
    """
    no_model = None
    return evaluate_model(no_model, *args, **kwargs)

In [None]:
def _eval(batch, metrics, including='output', sample_rate=8000):
    """Score one ``(mix, clean, estimate, snr)`` item with ``get_metrics``.

    Args:
        batch: 4-tuple ``(mix, clean, estimate, snr)``. The first three items
            are forwarded to ``get_metrics`` unchanged; ``snr`` must be
            indexable and its first element is recorded with the scores.
        metrics: metric names, forwarded as ``metrics_list``.
        including: forwarded to ``get_metrics`` as ``including``.
        sample_rate: forwarded to ``get_metrics`` as ``sample_rate``.

    Returns:
        The mapping returned by ``get_metrics`` with an additional ``'snr'``
        entry.
    """
    mix, clean, estimate, snr = batch
    # Use a separate name so the `metrics` argument is not rebound to the result.
    scores = get_metrics(mix, clean, estimate, sample_rate=sample_rate,
                         metrics_list=metrics, including=including)

    snr_value = snr[0]
    # When the batch came through a DataLoader, `snr[0]` is a tensor element.
    # Tensors hash by identity, so equal SNRs would land in different groups
    # of a later `groupby('snr')`. Store a plain Python number instead.
    if hasattr(snr_value, 'item'):
        snr_value = snr_value.item()
    scores['snr'] = snr_value
    return scores

def data_feed_process(queue, model, test_set):
    """Producer: enhance every item of ``test_set`` and put it on ``queue``.

    Items are read through a ``DataLoader`` with default settings and are
    expected to unpack as ``(mix, clean, snr)``. For each one the tuple
    ``(mix, clean, enhanced, snr)`` is put on ``queue``, the first three as
    NumPy arrays. No end-of-stream marker is put on the queue; the consumer
    has to know how many items to expect.

    Args:
        queue: queue receiving the tuples (``put`` blocks when it is full).
        model: callable module applied to ``mix`` on the GPU, or ``None`` to
            pass the mixture through unchanged as the "enhanced" signal.
        test_set: dataset to iterate over.
    """
    loader = DataLoader(test_set)

    was_training = False
    if model is not None:
        model = model.cuda()
        # Inference must run in eval mode (dropout off, normalisation layers
        # using running statistics). The previous mode is restored at the end.
        was_training = model.training
        model.eval()

    def model_eval(mix):
        if model is None:
            return mix
        else:
            return model(mix.cuda()).detach().cpu()

    try:
        # No autograd graph is needed for scoring; skipping it saves memory.
        with torch.no_grad():
            for mix, clean, snr in loader:
                enh = model_eval(mix)
                queue.put((mix.numpy(), clean.numpy(), enh.numpy(), snr))
    finally:
        if model is not None:
            model.train(was_training)
        
def eval_process(input_queue, output_queue, **kwargs):
    """Worker loop: score batches from ``input_queue`` until told to stop.

    Every item taken from ``input_queue`` is passed to ``_eval`` together with
    ``kwargs`` and the result is put on ``output_queue``.

    A ``None`` item is the shutdown signal. A blocking ``get()`` without a
    timeout never raises ``queue.Empty``, so the previous ``except Empty``
    exit path could not be reached and the worker never terminated.

    Args:
        input_queue: queue of ``(mix, clean, estimate, snr)`` batches, or
            ``None`` to stop.
        output_queue: queue receiving one ``_eval`` result per batch.
        **kwargs: forwarded to ``_eval``.
    """
    while True:
        batch = input_queue.get()
        if batch is None:
            break
        output_queue.put(_eval(batch, **kwargs))


def evaluate_model(model, test_set, num_workers=None, metrics=['pesq', 'stoi', 'si_sdr'],
                   sample_rate=8000, max_queue_size=1000):
    """Enhance and score ``test_set`` with a feeder process and scoring workers.

    One child process runs ``data_feed_process`` (data loading plus optional
    enhancement); the others run ``eval_process``. Results are collected in
    this process.

    Args:
        model: model handed to the feeder; ``None`` scores the unprocessed
            mixtures (``including='input'``).
        test_set: dataset to evaluate; ``len(test_set)`` results are awaited.
        num_workers: total number of child processes. Defaults to
            ``cpu_count()``. One is used by the feeder and the rest score,
            but at least one scoring process is always started.
        metrics: metric names forwarded to ``_eval``.
        sample_rate: forwarded to ``_eval``.
        max_queue_size: capacity of both the input and the output queue.

    Returns:
        ``pandas.DataFrame`` with one row per evaluated item. The columns
        ``['snr'] + metrics`` come first; any further keys returned by
        ``_eval`` are kept as extra columns. Rows arrive in completion order,
        not dataset order.

    NOTE(review): if the feeder or a worker dies from an exception, fewer than
    ``len(test_set)`` results arrive and the collection loop waits forever.
    """
    if num_workers is None:
        num_workers = cpu_count()
    # With zero scoring processes nothing would drain the input queue and the
    # result loop below would block forever.
    n_eval_workers = max(1, num_workers - 1)

    columns = ['snr'] + list(metrics)
    ds_len = len(test_set)
    including = 'input' if model is None else 'output'

    input_queue = Queue(maxsize=max_queue_size)
    output_queue = Queue(maxsize=max_queue_size)

    feed_pr = Process(target=data_feed_process, args=(input_queue, model, test_set))
    eval_prs = [
        Process(target=eval_process, args=(input_queue, output_queue), kwargs={
            'metrics': metrics,
            'including': including,
            'sample_rate': sample_rate
        })
        for _ in range(n_eval_workers)
    ]
    children = [feed_pr] + eval_prs

    # `DataFrame.append` returns a new frame (and no longer exists in pandas 2),
    # so the original loop threw every result away. Collect the records and
    # build the frame once instead.
    records = []
    try:
        for pr in children:
            pr.start()

        for _ in tqdm(range(ds_len), 'Evaluating and calculating scores'):
            records.append(output_queue.get())

        feed_pr.join()
        # Every batch has been scored by now; send one stop marker per worker
        # so each of them leaves its loop and can be joined.
        for _ in eval_prs:
            input_queue.put(None)
        for pr in eval_prs:
            pr.join()
    finally:
        # Reached with live children only on an error or interrupt: do not
        # leave orphaned processes behind.
        for pr in children:
            if pr.is_alive():
                pr.terminate()

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records)
    # Requested columns first, then whatever else the scorer returned.
    ordered = columns + [c for c in df.columns if c not in columns]
    return df.reindex(columns=ordered)

def evaluate_input(*args, **kwargs):
    """Call ``evaluate_model`` with no model.

    Every positional and keyword argument is forwarded unchanged after a
    leading ``None`` model argument.
    """
    no_model = None
    return evaluate_model(no_model, *args, **kwargs)

In [8]:
#input_scores = evaluate_input(timit_test, num_workers=4)

In [None]:
def plot_results(dfs, te_snr, plot_name=None, figsize=(6,4), ax=None, legend=True):
    """Plot mean PESQ against SNR, one line per entry of ``dfs``.

    The previous body rebuilt ``dfs`` from names that are neither parameters
    nor defined in this notebook (``model_names``, ``labels``,
    ``calculate_pesq``, ``workspace``, ``library``, ...) and therefore raised
    ``NameError`` on every call. The frames passed in are now plotted directly.

    Args:
        dfs: mapping ``name -> DataFrame``; every frame needs ``'snr'`` and
            ``'pesq'`` columns. The key is used as the legend label. The entry
            named ``'input'`` is drawn as a dashed black line.
        te_snr: not used by the plot; kept so existing calls keep working.
        plot_name: if given, the figure is saved to this path.
        figsize: size of the new figure when ``ax`` is ``None``.
        ax: existing axes to draw into; when given, no new figure is created
            and ``plt.show()`` is left to the caller.
        legend: whether to draw a legend.
    """
    # NOTE(review): `plt` is not imported in the visible imports cell; this
    # function needs `import matplotlib.pyplot as plt` to be in scope.
    if ax is None:
        plt.figure(figsize=figsize)
    else:
        plt.sca(ax)

    # Mean PESQ per SNR level for every model.
    pesqs = {}
    for model_name, df in dfs.items():
        pesqs[model_name] = df.groupby('snr')['pesq'].mean()

    for model_name, series in pesqs.items():
        line_kwargs = {'marker': 'o', 'alpha': 0.8}
        if model_name == 'input':
            # Baseline (unprocessed mixture) stands out as a dashed black line.
            line_kwargs = {'c': 'black', 'ls': '--'}
        plt.plot(series.index, series, label=model_name, **line_kwargs)

    plt.grid(which='both')
    plt.title('PESQ')
    plt.xlabel('SNR, dB')
    if legend:
        plt.legend()
    if plot_name is not None:
        plt.savefig(plot_name, bbox_inches='tight')

    if ax is None:
        plt.show()

In [None]:
# def evaluate_model(model, test_set, num_workers=10, metrics=['pesq', 'stoi', 'si_sdr']):
#     df = pd.DataFrame(columns=['snr'] + metrics)
#     if model is None: