In [1]:
import torch
import numpy as np
import pandas as pd
import librosa as lr
import soundfile as sf
import matplotlib.pyplot as plt
import time
import os.path

from librosa import display as lrd
from concurrent.futures import ProcessPoolExecutor
from collections import deque
from functools import partial
from torch.multiprocessing import Process, Queue, cpu_count
from queue import Empty

from torch.utils.data import DataLoader, ConcatDataset, random_split, Subset
from asteroid.data import TimitDataset
from asteroid.data.utils import CachedWavSet, FixedMixtureSet
from tqdm import trange, tqdm

from torch import optim
from pytorch_lightning import Trainer, loggers as pl_loggers
from asteroid_filterbanks.transforms import mag

from asteroid import DCUNet, DCCRNet, DPRNNTasNet, ConvTasNet
from asteroid.utils.notebook_utils import show_wav

from asteroid.metrics import get_metrics

%load_ext autoreload
%autoreload 2

In [2]:
# Data locations. Each can be overridden with an environment variable of the
# same name so the notebook runs outside the original author's machine; the
# defaults are the original hard-coded locations.
TIMIT_DIR_8kHZ = os.environ.get('TIMIT_DIR_8kHZ', '/import/vision-eddydata/dm005_tmp/TIMIT_8kHZ')
TEST_NOISE_DIR = os.environ.get('TEST_NOISE_DIR', '../../../datasets/noises-test-drones')
SAMPLE_RATE    = 8000                              # Hz, used for loading data and computing metrics
TEST_SNRS      = [-30, -25, -20, -15, -10, -5, 0]  # mixture SNRs in dB to evaluate at
SEED           = 42                                # passed as random_seed to FixedMixtureSet

In [3]:
# Clean TIMIT test-subset utterances, loaded at SAMPLE_RATE.
timit_test_clean = TimitDataset(
    TIMIT_DIR_8kHZ,
    subset='test',
    sample_rate=SAMPLE_RATE,
    with_path=False,
)
# Test noise recordings; with precache=True they are presumably loaded when
# the set is constructed (the "Precaching audio" progress bar below).
noises_test = CachedWavSet(
    TEST_NOISE_DIR,
    sample_rate=SAMPLE_RATE,
    precache=True,
)

Precaching audio: 100%|██████████| 2/2 [00:00<00:00, 136.48it/s]


In [4]:
# Seeded noisy test set built from the clean utterances and the test noises
# at every SNR in TEST_SNRS. Downstream code unpacks each item as
# (mix, clean, snr); with_snr=True is presumably what adds the trailing snr.
timit_test = FixedMixtureSet(
    timit_test_clean,
    noises_test,
    snrs=TEST_SNRS,
    random_seed=SEED,
    with_snr=True,
)

In [5]:
def _eval(batch, metrics, including='output', sample_rate=8000):
    """Score one ``(mix, clean, estimate, snr)`` batch with ``get_metrics``.

    Args:
        batch: 4-tuple of tensors as put on the queue by data_feed_process.
            Only ``snr[0]`` is read, so a batch size of 1 is assumed.
        metrics: metric names, forwarded as ``metrics_list``.
        including: forwarded unchanged to ``get_metrics``.
        sample_rate: sampling rate in Hz, forwarded to ``get_metrics``.

    Returns:
        The dict returned by ``get_metrics`` with an added ``'snr'`` entry
        holding the batch's SNR as a Python scalar.
    """
    mix, clean, estimate, snr = batch
    scores = get_metrics(
        mix.numpy(),
        clean.numpy(),
        estimate.numpy(),
        sample_rate=sample_rate,
        metrics_list=metrics,
        including=including,
    )
    scores['snr'] = snr[0].item()
    return scores

def data_feed_process(queue, signal_queue, model, test_set, device='cuda', num_loader_workers=2):
    """Producer: run ``model`` over ``test_set`` and put the results on ``queue``.

    Items are read one at a time through a DataLoader (default batch size 1,
    so every tensor gains a leading batch dimension of 1). For each item a
    ``(mix, clean, enhanced, snr)`` tuple is put on ``queue``. With
    ``model=None`` the noisy mixture itself is passed through as "enhanced",
    which is how the unprocessed input is scored.

    After the last item this blocks on ``signal_queue.get()`` and returns only
    once the consumer sends a signal. Presumably this is because tensors sent
    through torch.multiprocessing queues must stay owned by a live producer
    until consumers are done with them (see PyTorch multiprocessing notes).

    Args:
        queue: output queue for the result tuples.
        signal_queue: queue whose first item tells this process to exit.
        model: separation model, or None to pass the mixture through.
        test_set: dataset whose items unpack as (mix, clean, snr).
        device: device the model runs on (the original code hard-coded CUDA).
        num_loader_workers: DataLoader worker processes.
    """
    loader = DataLoader(test_set, num_workers=num_loader_workers)

    if model is not None:
        model = model.to(device)
        # Inference mode. Modules are in training mode by default: dropout
        # would stay active, and BatchNorm would normalise each size-1 batch
        # with its own statistics (and update running stats) while scoring.
        model.eval()

    # No autograd graph is needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for mix, clean, snr in loader:
            if model is None:
                enh = mix
            else:
                enh = model(mix.to(device)).squeeze(1).cpu()
            queue.put((mix, clean, enh, snr))

    # Wait for the consumer's signal before letting the process end.
    signal_queue.get()
        
def eval_process(proc_idx, input_queue, output_queue, **kwargs):
    """Consumer: score batches from ``input_queue`` until a ``None`` sentinel arrives.

    Each batch is passed to ``_eval(batch, **kwargs)``, and the resulting
    metrics dict is put on ``output_queue``, one result per batch.

    Args:
        proc_idx: worker index. Not used by the loop; kept so the
            ``Process(args=...)`` call site stays unchanged.
        input_queue: queue of (mix, clean, estimate, snr) batches, terminated
            by a ``None`` sentinel (evaluate_model sends one per worker).
        output_queue: receives one metrics dict per batch.
        **kwargs: forwarded to ``_eval`` (evaluate_model passes metrics,
            including and sample_rate).

    Exceptions raised by ``_eval`` are not caught; they end this worker.
    """
    # input_queue.get() blocks until an item is available and never raises
    # queue.Empty, so the former `except Empty` / sleep branch was unreachable.
    while True:
        batch = input_queue.get()
        if batch is None:
            break
        output_queue.put(_eval(batch, **kwargs))

def evaluate_model(model, test_set, num_workers=None, metrics=['pesq', 'stoi', 'si_sdr'],
                   sample_rate=8000, max_queue_size=100, poll_interval=5.0):
    """Score ``model`` on every item of ``test_set`` using worker processes.

    One producer process (data_feed_process) runs the model and feeds batches
    to ``num_workers - 1`` scoring processes (eval_process). This process
    collects one metrics dict per dataset item.

    Args:
        model: separation model, or None to score the noisy input itself
            (in that case ``including='input'`` is passed to the metrics).
        test_set: dataset whose items unpack as (mix, clean, snr).
        num_workers: total process budget; defaults to ``cpu_count()``. At
            least one scoring process is always started.
        metrics: metric names passed on to get_metrics.
        sample_rate: sampling rate in Hz passed on to get_metrics.
        max_queue_size: bound on the input and output queues.
        poll_interval: seconds to wait for a result before checking that the
            child processes are still alive.

    Returns:
        DataFrame with one row per item: column 'snr', the metric names, then
        any extra keys returned by get_metrics. Rows are in completion order,
        not dataset order.

    Raises:
        RuntimeError: if a child process exits before all results arrive.
    """
    if num_workers is None:
        num_workers = cpu_count()
    # One process is reserved for the producer. With num_workers <= 1 the
    # previous code started zero scorers and waited forever.
    num_eval_workers = max(1, num_workers - 1)

    ds_len = len(test_set)
    including = 'input' if model is None else 'output'

    signal_queue = Queue()
    input_queue = Queue(maxsize=max_queue_size)
    output_queue = Queue(maxsize=max_queue_size)

    feed_pr = Process(target=data_feed_process, args=(input_queue, signal_queue, model, test_set))
    eval_prs = [
        Process(target=eval_process, args=(i, input_queue, output_queue), kwargs={
            'metrics': metrics,
            'including': including,
            'sample_rate': sample_rate,
        })
        for i in range(num_eval_workers)
    ]
    children = [feed_pr] + eval_prs

    records = []
    try:
        for pr in children:
            pr.start()
        for _ in tqdm(range(ds_len), 'Evaluating and calculating scores'):
            while True:
                try:
                    records.append(output_queue.get(timeout=poll_interval))
                    break
                except Empty:
                    # Children only exit after the shutdown below, so a dead
                    # child here means a crash; without this check the
                    # blocking get() would wait forever for the lost result.
                    dead = [pr for pr in children if not pr.is_alive()]
                    if dead:
                        raise RuntimeError(
                            f'{len(dead)} evaluation subprocess(es) exited before all '
                            f'{ds_len} results arrived (exit codes: '
                            f'{[pr.exitcode for pr in dead]}); check their stderr.'
                        )
    except BaseException:
        # Also covers KeyboardInterrupt: don't leave orphaned children
        # (the producer may hold the GPU) blocked on the queues.
        for pr in children:
            if pr.is_alive():
                pr.terminate()
                pr.join()
        raise

    # Normal shutdown: release the producer, then one sentinel per scorer.
    signal_queue.put(None)
    for _ in eval_prs:
        input_queue.put(None)

    feed_pr.join()
    for pr in eval_prs:
        pr.join()

    # Build the frame once (row-by-row DataFrame.append is quadratic and was
    # removed in pandas 2.0); keep 'snr' + metrics first, extras after.
    columns = ['snr'] + list(metrics)
    df = pd.DataFrame(records)
    extra = [c for c in df.columns if c not in columns]
    return df.reindex(columns=columns + extra)

def evaluate_input(*args, **kwargs):
    """Score the unprocessed noisy mixtures (the "no model" baseline).

    Calls evaluate_model with ``None`` as the model; every other positional
    and keyword argument is forwarded unchanged.
    """
    no_model = None
    return evaluate_model(no_model, *args, **kwargs)

In [6]:
# Switch torch.multiprocessing to the 'file_system' tensor-sharing strategy.
# NOTE(review): presumably chosen because every batch passed between the
# evaluation processes is a shared tensor, and the default strategy can run
# out of open file descriptors on long runs — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

In [7]:
# Display names: metric keys -> plot titles, model keys -> legend labels.
metrics_names = dict(
    pesq='PESQ',
    stoi='STOI',
    si_sdr='SI-SDR',
)

model_labels = dict(
    input='Original noisy input',
    dcunet_20='DCUNet-20',
    dccrn='DCCRN',
    dprnn='DPRNN',
    conv_tasnet='Conv-TasNet',
)

def plot_results(dfs, figsize=(15, 5), metrics=['pesq', 'stoi', 'si_sdr'],
                 plot_name=None):
    """Plot mean score vs. SNR for each model, with one subplot per metric.

    Args:
        dfs: mapping of model key -> DataFrame with an 'snr' column and one
            column per entry of ``metrics`` (e.g. the output of evaluate_model).
        figsize: matplotlib figure size.
        metrics: metric columns to plot, one subplot each.
        plot_name: if given, the figure is also saved to this path.

    The shaded band is mean +/- 3 standard errors (std / sqrt(count) * 3).
    The 'input' entry is drawn as a dashed black reference line. Legend and
    title text come from model_labels / metrics_names, falling back to the
    raw key when a name is missing.
    """
    # squeeze=False keeps `axes` 2-D even for a single metric. Otherwise
    # plt.subplots returns a bare Axes, which cannot be indexed.
    fig, axes = plt.subplots(nrows=1, ncols=len(metrics), figsize=figsize, squeeze=False)
    axes = axes[0]

    for model_name, df in dfs.items():
        scores = df.groupby('snr').agg({
            metric: ['mean', 'std', 'count'] for metric in metrics
        })
        label = model_labels.get(model_name, model_name)

        line_kwargs = {'marker': 'o', 'alpha': 0.8}
        fill_kwargs = {}
        if model_name == 'input':
            line_kwargs = {'c': 'black', 'ls': '--'}
            fill_kwargs = {'color': 'black'}

        for ax, metric in zip(axes, metrics):
            means = scores[metric]['mean']
            # Band half-width: 3 standard errors of the mean.
            half_width = scores[metric]['std'].values / np.sqrt(scores[metric]['count'].values) * 3
            xs = means.index
            ax.plot(xs, means, label=label, **line_kwargs)
            ax.fill_between(xs, means - half_width, means + half_width, alpha=0.2, **fill_kwargs)

    for i, (ax, metric) in enumerate(zip(axes, metrics)):
        ax.grid(which='both')
        ax.set_title(metrics_names.get(metric, metric))
        ax.set_xlabel('SNR, dB')
        if i == 0:
            ax.legend()

    if plot_name is not None:
        fig.savefig(plot_name, bbox_inches='tight')

    plt.show()

### Models evaluation

In [8]:
# Directory holding the pretrained checkpoints (previously repeated per entry).
MODELS_DIR = '../../../workspace/models'

# Models to evaluate, keyed like model_labels. 'input' has no model:
# evaluating None scores the noisy mixture itself.
models = {
    'input': None,
    'dcunet_20': DCUNet.from_pretrained(f'{MODELS_DIR}/dcunet_20_random_v2.pt'),
    'dccrn': DCCRNet.from_pretrained(f'{MODELS_DIR}/dccrn_random_v1.pt'),
    'dprnn': DPRNNTasNet.from_pretrained(f'{MODELS_DIR}/dprnn_model.pt'),
    'conv_tasnet': ConvTasNet.from_pretrained(f'{MODELS_DIR}/convtasnet_model.pt'),
}

In [None]:
# Per-model scores are cached as CSV, so re-running this cell only evaluates
# models without results (delete a CSV to force re-evaluation).
RESULTS_DIR = '../../../workspace/eval_results'
# Create the cache directory up front. Otherwise a missing directory would
# only surface as a to_csv error after a full (~25-35 min) evaluation run.
os.makedirs(RESULTS_DIR, exist_ok=True)

results_dfs = {}

for model_name, model in models.items():
    print(f'Evaluating {model_labels[model_name]}')
    csv_path = f'{RESULTS_DIR}/{model_name}.csv'

    if os.path.isfile(csv_path):
        print('Results already available')
        df = pd.read_csv(csv_path)
    else:
        df = evaluate_model(model, timit_test)
        df.to_csv(csv_path, index=False)

    results_dfs[model_name] = df

plot_results(results_dfs)

Evaluating Original noisy input
Results already available
Evaluating DCUNet-20


Evaluating and calculating scores: 100%|██████████| 23520/23520 [24:43<00:00, 15.86it/s]


Evaluating DCCRN


Evaluating and calculating scores: 100%|██████████| 23520/23520 [32:27<00:00, 12.08it/s]


Evaluating DPRNN


Evaluating and calculating scores: 100%|██████████| 23520/23520 [33:20<00:00, 11.76it/s]


Evaluating Conv-TasNet


Evaluating and calculating scores:  96%|█████████▋| 22685/23520 [24:02<00:45, 18.20it/s]