In [1]:
import torch
import numpy as np
import pandas as pd
import librosa as lr
import soundfile as sf

from librosa import display as lrd
from concurrent.futures import ProcessPoolExecutor
from collections import deque
from queue import Queue, Empty

from torch.utils.data import DataLoader, ConcatDataset, random_split
from asteroid.data import TimitDataset
from asteroid.data.utils import CachedWavSet, FixedMixtureSet
from tqdm import trange, tqdm

from torch import optim
from pytorch_lightning import Trainer, loggers as pl_loggers
from asteroid_filterbanks.transforms import mag

from asteroid import DCUNet, DCCRNet
from asteroid.utils.notebook_utils import show_wav

from asteroid.metrics import get_metrics

%load_ext autoreload
%autoreload 2

In [2]:
TIMIT_DIR_8kHZ = '/import/vision-eddydata/dm005_tmp/TIMIT_8kHZ'
TEST_NOISE_DIR = '../../../datasets/noises-test-drones'
SAMPLE_RATE    = 8000
TEST_SNRS      = [-30, -25, -20, -15, -10, -5, 0]
SEED           = 42

In [3]:
timit_test_clean = TimitDataset(TIMIT_DIR_8kHZ, subset='test', sample_rate=SAMPLE_RATE, with_path=False)
noises_test = CachedWavSet(TEST_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)

Precaching audio: 100%|██████████| 2/2 [00:00<00:00, 202.02it/s]


In [4]:
timit_test = FixedMixtureSet(timit_test_clean, noises_test, snrs=TEST_SNRS, random_seed=SEED, with_snr=True)

In [6]:
def _eval(mix, clean, estimate, snr, metrics, including='output', sample_rate=8000):
    metrics = get_metrics(mix, clean, estimate, sample_rate=sample_rate,
                          metrics_list=metrics, including=including)
    metrics['snr'] = snr[0]
    return metrics

def flush_deque(d):
    while len(d) > 0:
        yield d.popleft()

def evaluate_input(test_set, num_workers=5, metrics=['pesq', 'stoi', 'si_sdr'],
                   sample_rate=8000, max_pending=1000):
    
    df = pd.DataFrame(columns=['snr']+metrics)
    loader = DataLoader(test_set, num_workers=num_workers)
    
    pendings = deque()
    with ProcessPoolExecutor(num_workers) as pool:
        for mix, clean, snr in tqdm(loader, 'Calculating scores (submitting)'):
            
            mix = mix.numpy()
            clean = clean.numpy()
            pendings.append(pool.submit(_eval, mix, clean, None, snr, metrics,
                                        including='input', sample_rate=sample_rate))
        
            if len(pendings) >= max_pending:
                for pending in tqdm(flush_deque(pendings), 'Collecting pending jobs', total=len(pendings)):
                    res = pending.result()
                    df = df.append(res, ignore_index=True)

    return df

In [7]:
input_scores = evaluate_input(timit_test)

Calculating scores (submitting):   4%|▍         | 973/23520 [00:03<01:21, 277.30it/s]
Collecting pending jobs:   0%|          | 0/1000 [00:00<?, ?it/s][A
Collecting pending jobs:   1%|▏         | 13/1000 [00:00<00:07, 129.88it/s][A
Collecting pending jobs:   2%|▏         | 22/1000 [00:00<00:08, 111.20it/s][A
Collecting pending jobs:   3%|▎         | 30/1000 [00:00<00:09, 99.52it/s] [A
Collecting pending jobs:   4%|▍         | 39/1000 [00:00<00:09, 96.40it/s][A
Collecting pending jobs:   6%|▌         | 55/1000 [00:00<00:13, 69.67it/s][A
Collecting pending jobs:   6%|▌         | 62/1000 [00:01<00:25, 37.25it/s][A
Collecting pending jobs:   7%|▋         | 67/1000 [00:01<00:49, 19.04it/s][A
Collecting pending jobs:   7%|▋         | 72/1000 [00:01<00:45, 20.50it/s][A
Collecting pending jobs:   8%|▊         | 77/1000 [00:02<00:41, 22.47it/s][A
Collecting pending jobs:   8%|▊         | 81/1000 [00:02<00:53, 17.18it/s][A
Collecting pending jobs:   9%|▊         | 87/1000 [00:02<00:49

Traceback (most recent call last):
  File "<ipython-input-6-dae390d4a1c5>", line 28, in evaluate_input
    res = pending.result()
  File "/homes/dm005/conda_env/lib/python3.8/concurrent/futures/_base.py", line 434, in result
    self._condition.wait(timeout)
  File "/homes/dm005/conda_env/lib/python3.8/threading.py", line 302, in wait
    waiter.acquire()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/homes/dm005/conda_env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-a5233579e56a>", line 1, in <module>
    input_scores = evaluate_input(timit_test)
  File "<ipython-input-6-dae390d4a1c5>", line 29, in evaluate_input
    df = df.append(res, ignore_index=True)
  File "/homes/dm005/conda_env/lib/python3.8/concurrent/futures/_base.py", line 636, in __exit__
    self.shutdown(wait=True)
  Fi

TypeError: object of type 'NoneType' has no len()

In [None]:
# def evaluate_model(model, test_set, num_workers=10, metrics=['pesq', 'stoi', 'si_sdr']):
#     df = pd.DataFrame(columns=['snr'] + metrics)
#     if model is None: