In [22]:
import asyncio
import os
import time
import warnings

import numpy as np
import soundfile as sf
import torch
import torchaudio
from pesq import pesq
from speechbrain.pretrained import SpectralMaskEnhancement
from tqdm import tqdm
# import tqdm
# import tqdm.asyncio
# Run everything on the third CUDA GPU (index 2).
device = torch.device('cuda', 2)

In [2]:
# Load the pretrained MetricGAN+ (VoiceBank) enhancer from the SpeechBrain hub,
# caching its files under `savedir`, then move it to `device`.
enhance_model = SpectralMaskEnhancement.from_hparams(
    source="speechbrain/metricgan-plus-voicebank",
    savedir="/workspace/SE_2022/model_experiment/metricgan",
)
enhance_model = enhance_model.to(device)
# Also overwrite the wrapper's own `.device` attribute. NOTE(review): presumably
# the SpeechBrain helper methods move inputs to `self.device` rather than
# following the parameters' device — TODO confirm.
enhance_model.device = device

### Prediction

In [3]:
# Enhance every .flac file in the test split with MetricGAN+ and write the result
# to output_dir as vocal_<number>.flac.
# Assumes names look like "<prefix>_<number>_....flac", i.e. the second
# "_"-separated field is the clip number — TODO confirm for the test split.
root = '/workspace/data/test'
output_dir = '/workspace/output_data/metricgan/test'
# torchaudio.save cannot write into a missing directory; create it up front so
# the cell also works on a fresh machine.
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():  # inference only; no autograd graph is needed
    for flac_name in tqdm(os.listdir(root)):
        if flac_name.endswith('.flac'):
            number = flac_name.split('_')[1]
            clean_file_name = 'vocal_' + number + '.flac'
            noise_file_path = os.path.join(root, flac_name)
            clean_file_path = os.path.join(output_dir, clean_file_name)
            # unsqueeze(0) adds a batch dimension of 1 for enhance_batch.
            noisy = enhance_model.load_audio(noise_file_path, savedir='/workspace/data/temp').unsqueeze(0).to(device)
            # Presumably a relative length of 1.0 = use the whole signal (SpeechBrain convention) — verify.
            clean = enhance_model.enhance_batch(noisy, lengths=torch.tensor([1.]))
            # 16000 Hz is assumed to be the model's sample rate — TODO confirm.
            torchaudio.save(clean_file_path, clean.cpu(), 16000)

100%|██████████| 1000/1000 [00:27<00:00, 36.51it/s]


In [4]:
# Enhance every mixed_<number>_....flac file in the train split with MetricGAN+
# and write the result to output_dir as vocal_<number>.flac. Other files in the
# directory (e.g. the vocal_*.flac references) are skipped.
root = '/workspace/data/train'
output_dir = '/workspace/output_data/metricgan/train'
# torchaudio.save cannot write into a missing directory; create it up front so
# the cell also works on a fresh machine.
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():  # inference only; no autograd graph is needed
    for flac_name in tqdm(os.listdir(root)):
        if flac_name.endswith('.flac'):
            if flac_name.split('_')[0] == 'mixed':
                number = flac_name.split('_')[1]
                clean_file_name = 'vocal_' + number + '.flac'
                noise_file_path = os.path.join(root, flac_name)
                clean_file_path = os.path.join(output_dir, clean_file_name)
                # unsqueeze(0) adds a batch dimension of 1 for enhance_batch.
                noisy = enhance_model.load_audio(noise_file_path, savedir='/workspace/data/temp').unsqueeze(0).to(device)
                # Presumably a relative length of 1.0 = use the whole signal (SpeechBrain convention) — verify.
                clean = enhance_model.enhance_batch(noisy, lengths=torch.tensor([1.]))
                # 16000 Hz is assumed to be the model's sample rate — TODO confirm.
                torchaudio.save(clean_file_path, clean.cpu(), 16000)

100%|██████████| 83656/83656 [12:33<00:00, 111.01it/s]


### Evaluation

In [7]:
async def get_pesq(noise_file_path, clean_file_path):
    """Wideband PESQ of a noisy file against its clean reference.

    The file reads and the PESQ computation are blocking, so they run in the
    event loop's default executor (a thread pool). This lets other tasks on the
    loop make progress while one file is being scored.

    Args:
        noise_file_path: path of the degraded (noisy) audio file.
        clean_file_path: path of the clean reference file. Its sample rate is
            the one passed to PESQ.

    Returns:
        The value of ``pesq(rate, clean, noise, 'wb')``.

    Raises:
        Whatever reading the files or computing PESQ raises. Nothing is caught here.
    """
    def _score():
        noise, _ = sf.read(noise_file_path)
        clean, rate = sf.read(clean_file_path)
        return pesq(rate, clean, noise, 'wb')

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _score)
    
async def get_pesq_async(noise_file_path, clean_file_path):
    """Best-effort wideband PESQ of a noisy file against its clean reference.

    The blocking file reads and the PESQ computation run in the event loop's
    default executor (a thread pool), so many of these can be awaited
    concurrently without stalling the loop.

    NOTE(review): real CPU parallelism across threads depends on the `pesq`
    extension releasing the GIL — TODO confirm. A process pool may be needed
    for a true speed-up.

    Args:
        noise_file_path: path of the degraded (noisy) audio file.
        clean_file_path: path of the clean reference file. Its sample rate is
            the one passed to PESQ.

    Returns:
        The value of ``pesq(rate, clean, noise, 'wb')``, or ``None`` if reading
        or scoring raised. In that case a warning naming both files is emitted.
    """
    def _score():
        noise, _ = sf.read(noise_file_path)
        clean, rate = sf.read(clean_file_path)
        # pesq() returns a plain number; it must not be awaited.
        return pesq(rate, clean, noise, 'wb')

    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _score)
    except Exception as exc:
        # Keep the best-effort contract (None on failure), but surface the
        # reason instead of silently discarding it.
        warnings.warn(f'PESQ failed for {noise_file_path} vs {clean_file_path}: {exc!r}')
        return None

async def train_pesq_truth(root='/workspace/data/train', limit=100):
    """Concurrently score mixed_*.flac files against their vocal_*.flac references.

    Takes the first ``limit`` entries of ``os.listdir(root)`` (``None`` = all
    entries). It keeps those named ``mixed_<number>_....flac`` and scores each
    against ``vocal_<number>.flac`` in the same directory using get_pesq_async.

    Args:
        root: directory holding both the mixed and the vocal files.
        limit: number of directory entries to consider before filtering.

    Returns:
        A list with one score per selected file (``None`` where scoring
        failed). The list is in submission order, not completion order, so it
        can be zipped with the selected file names.
    """
    file_list = os.listdir(root)[:limit]

    tasks = []
    for flac_name in file_list:
        if flac_name.endswith('.flac') and flac_name.split('_')[0] == 'mixed':
            number = flac_name.split('_')[1]
            noise_file_path = os.path.join(root, flac_name)
            clean_file_path = os.path.join(root, 'vocal_' + number + '.flac')
            tasks.append(asyncio.create_task(get_pesq_async(noise_file_path, clean_file_path)))

    # asyncio.as_completed only drives the progress bar. The results are read
    # back from `tasks` afterwards so that they stay in submission order.
    for finished in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        await finished

    return [task.result() for task in tasks]

In [8]:
start = time.perf_counter()
score_result = await train_pesq_truth()
print(time.perf_counter() - start)


100%|██████████| 100/100 [00:12<00:00,  7.88it/s]

12.768901099450886





In [None]:
# Separate successful PESQ scores from failed files and report the mean score.
# The list of scored files is rebuilt here, using the same selection as the
# scoring step (first 100 directory entries, mixed_*.flac only). This way the
# cell does not depend on leftover kernel state.
# NOTE(review): the name <-> score pairing is only valid if `score_result` is in
# submission order, not completion order — confirm.
train_root = '/workspace/data/train'
flac_name_list = [
    name for name in os.listdir(train_root)[:100]
    if name.endswith('.flac') and name.split('_')[0] == 'mixed'
]
assert len(flac_name_list) == len(score_result), 'file list and score list are out of sync'

score_result_final = []
error_name_list = []
for flac_name, score in zip(flac_name_list, score_result):
    if score is not None:
        score_result_final.append(score)
    else:
        error_name_list.append(flac_name)

# Average only the successful scores; None entries cannot be averaged.
mean_score = np.mean(score_result_final) if score_result_final else float('nan')
print('train_pesq_truth:', mean_score)

In [15]:
# Synchronous reference pass: PESQ of each mixed_<number>_....flac against its
# vocal_<number>.flac, over the first 100 directory entries of the train split.
root = '/workspace/data/train'
name_list = os.listdir(root)[:100]
truth_score = []

for flac_name in tqdm(name_list):
    is_mixture = flac_name.endswith('.flac') and flac_name.split('_')[0] == 'mixed'
    if not is_mixture:
        continue
    number = flac_name.split('_')[1]
    noise_file_path = os.path.join(root, flac_name)
    clean_file_path = os.path.join(root, 'vocal_' + number + '.flac')
    noise, _ = sf.read(noise_file_path)
    clean, rate = sf.read(clean_file_path)
    truth_score.append(pesq(rate, clean, noise, 'wb'))

100%|██████████| 100/100 [00:12<00:00,  7.88it/s]
