In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchaudio
from df import enhance, init_df
from os.path import join as pjoin
import matplotlib.pyplot as plt
import os
import tqdm
import time
from IPython.display import Audio #listen: ipd.Audio(real.detach().cpu().numpy(), rate=FS)
import numpy as np
import scipy.signal as sig
import pandas as pd
import torchmetrics.audio as M
from speechmos import dnsmos
from datetime import datetime
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
import evaluation_metrics.calculate_intrusive_se_metrics as intru
import evaluation_metrics.NISQA.nisqa.NISQA_lib as NL
from evaluation_metrics.nisqa_utils import load_nisqa_model
from evaluation_metrics.nisqa_utils import predict_nisqa

import evaluation_metrics.calculate_phoneme_similarity as phon
from evaluation_metrics.calculate_phoneme_similarity import LevenshteinPhonemeSimilarity

from espnet2.bin.spk_inference import Speech2Embedding
import evaluation_metrics.calculate_speaker_similarity as spksim

from discrete_speech_metrics import SpeechBERTScore
import evaluation_metrics.calculate_speechbert_score as sbert

import evaluation_metrics.calculate_wer as wer

  from torchaudio.backend.common import AudioMetaData
[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# HELPER FUNCTIONS
def rep_list(short, long):
    """Cycle ``short`` until it has as many items as ``long``.

    Args:
        short (list): items to repeat (a tuple/str also works).
        long (Sized): only ``len(long)`` is used.

    Returns:
        A new sequence of ``len(long)`` items taken cyclically from ``short``.
        ``short`` itself is never modified.

    Raises:
        ValueError: if ``short`` is empty while ``long`` is not.
    """
    target = len(long)
    if not short:
        if target:
            raise ValueError('cannot repeat an empty list to a non-zero length')
        return short[:0]
    # Integer ceiling division (exact, unlike float np.ceil for huge lengths).
    reps = -(-target // len(short))
    # Plain `*` builds a new sequence; `*=` would extend the caller's list in place.
    return (short * reps)[:target]
    
def plot_tensor(x):
    """Plot the values of tensor ``x`` on the current matplotlib axes.

    The tensor is moved to CPU and detached from the autograd graph before
    being converted to a NumPy array for ``plt.plot``.
    """
    values = x.cpu().detach().numpy()
    plt.plot(values)



def extend_signal(signal, target_length):
    """
    Extend a signal by tiling it if it's shorter than the target length.

    Args:
    signal (torch.Tensor): Input signal; its length is read from dimension 0
        (``Tensor.repeat`` with a single count only accepts 1-D tensors).
    target_length (int): Desired length of the extended signal.

    Returns:
    torch.Tensor: ``signal`` itself if it already has at least
    ``target_length`` samples, otherwise a new tensor of exactly
    ``target_length`` samples made of whole copies of ``signal`` followed by a
    leading slice of it.

    Raises:
    ValueError: if ``signal`` is empty and must be extended (there is nothing
    to repeat).
    """
    current_length = signal.size(0)
    if current_length >= target_length:
        return signal
    if current_length == 0:
        raise ValueError('cannot extend an empty signal')
    # Tile enough whole copies to cover target_length, then trim the excess.
    repetitions = -(-target_length // current_length)
    return signal.repeat(repetitions)[:target_length]

def load_audio(apath):
    """Load an audio file as a tensor at the global sampling rate ``FS``.

    Files stored at another rate are resampled with
    ``torchaudio.transforms.Resample``. If the loaded tensor has more than one
    dimension, only its first row (the first channel) is kept; channels are
    not mixed down.
    """
    waveform, file_fs = torchaudio.load(apath)
    if file_fs != FS:
        waveform = torchaudio.transforms.Resample(file_fs, FS)(waveform)
    if waveform.dim() > 1:
        waveform = waveform[0, :]
    return waveform

def power(signal):
    """Return the mean power (mean of squared samples) of a NumPy signal."""
    squared = signal ** 2
    return np.mean(squared)
    
    

In [3]:
FS = 48000  # sampling rate (Hz) every loaded file is resampled to
DURATION = 4  # time in seconds of the eval chunk

# The dataset draws one random SNR per file with np.random; seeding here makes
# the evaluation conditions reproducible across kernel restarts.
SEED = 0
np.random.seed(SEED)

# Names of the model folders (checkpoints/..) mapped to their aliases.
# Currently disabled: 'D05_mb_srcrec_NH_left': 'recsourcedirectivity',
#                     'D00_DNS5': 'DNS5', 'D09_SSmp3d_left': 'soundspaces'
TRAINRIR_NAMES = {
    'D01_sb_none_NH_mono': 'singleband',
    'D02_mb_none_NH_mono': 'multiband',
    'D03_mb_rec_NH_left': 'recdirectivity',
}

use_gpu = True
TORCH_DEVICE = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"

batch_size = 1
num_workers = 8
reverberant_noises = True  # also convolve the noise with the RIR
# Text files listing one audio path per line.
speech_path = '/home/ubuntu/Data/DFN/textfiles/readspeech_set.txt'
noise_path = '/home/ubuntu/Data/DFN/textfiles/test_set_noise.txt'
dns_mos_path = '/home/ubuntu/enric/DNS-Challenge/DNSMOS/DNSMOS'
rir_paths = ['/home/ubuntu/enric/guso_interspeech24/real_rirs.txt']
rir_path = rir_paths[0]

In [9]:
class DFN_dataset(Dataset):
    """Evaluation dataset that synthesises noisy/clean speech pairs on the fly.

    Every speech file listed in ``speech_path`` gets a target SNR drawn once,
    uniformly in [0, 30) dB, at construction time. It is paired with a noise
    file and a room impulse response (RIR). The noise and RIR lists are cycled
    when they are shorter than the speech list.

    Each item is ``(noisy, clean, meta)``:
      * ``noisy``: float32 tensor of ``FS * DURATION`` samples, reverberant
        speech plus noise (the noise is reverberant too when
        ``reverberant_noises`` is True);
      * ``clean``: float32 tensor of ``FS * DURATION`` samples, the dry speech
        chunk (same gain as in the mixture);
      * ``meta``: ``[speech_file, noise_file, rir_file, snr_db]``.
    """

    def __init__(self, speech_path, noise_path, rir_path, reverberant_noises):
        """
        Args:
            speech_path (str): text file with one speech audio path per line.
            noise_path (str): text file with one noise audio path per line.
            rir_path (str): text file with one RIR audio path per line.
            reverberant_noises (bool): if True, the noise is convolved with
                the same RIR as the speech.
        """
        print('Initializing dataset...')
        self.speech_path = speech_path
        self.noise_path = noise_path
        self.rir_path = rir_path
        self.reverberant_noises = reverberant_noises

        self.speech_paths = self._read_lines(speech_path)
        print(f'speech set loaded. contains {len(self.speech_paths)} files.')

        # One target SNR per speech file, drawn once so that an index always
        # maps to the same mixing condition.
        self.snrs = np.random.uniform(low = 0, high = 30, size = len(self.speech_paths))

        # Cycle the noise and RIR lists to the length of the speech list.
        self.noise_paths = rep_list(self._read_lines(noise_path), self.speech_paths)
        self.rir_paths = rep_list(self._read_lines(rir_path), self.speech_paths)
        print('All paths loaded.')

    def __len__(self):
        return len(self.speech_paths)

    @staticmethod
    def _read_lines(path):
        """Return the lines of text file ``path`` with trailing whitespace removed."""
        with open(path, 'r') as file:
            return [line.rstrip() for line in file]

    @staticmethod
    def _dirac_rir(length, delay=300):
        """Unit impulse of ``length`` samples at sample ``delay`` (no reverberation)."""
        rir = torch.zeros(length)
        rir[delay] = 1.
        return rir

    @staticmethod
    def _loudest_chunk(signal, chunk_len):
        """Return the ``chunk_len``-sample chunk of ``signal`` with the highest mean power.

        Only whole chunks are considered; trailing samples that do not fill a
        chunk are ignored. Ties resolve to the earliest chunk (``np.argmax``).
        """
        n_chunks = len(signal) // chunk_len
        chunks = np.split(signal[:chunk_len * n_chunks], n_chunks)
        best = int(np.argmax([np.mean(chunk ** 2) for chunk in chunks]))
        return signal[best * chunk_len:(best + 1) * chunk_len]

    @staticmethod
    def _reverberate(signal, rir, lag, n_samples):
        """Convolve ``signal`` with ``rir``, crop ``n_samples`` from ``lag`` on,
        and rescale the result to the mean power of the dry ``signal``."""
        wet = sig.fftconvolve(signal, rir, 'full')[lag:n_samples + lag]
        wet *= np.sqrt(power(signal) / power(wet))
        return wet

    def _load_clean(self, idx, n_samples):
        """Load speech file ``idx``; if it is silent, load speech file 0 instead."""
        clean = load_audio(self.speech_paths[idx])
        # Only the first n_samples are checked, for speed on long files.
        if torch.mean(clean[:n_samples] ** 2) == 0:
            clean = load_audio(self.speech_paths[0])
        return clean

    def _load_rir(self, rir_file):
        """Load an RIR, replacing unreadable or all-zero ones with a delayed Dirac."""
        try:
            rir = load_audio(rir_file)
        except Exception:
            return self._dirac_rir(FS)
        if torch.mean(rir ** 2) == 0:
            return self._dirac_rir(FS)
        return rir

    def __getitem__(self, idx):
        """Build the ``(noisy, clean, meta)`` triple for speech file ``idx``."""
        n_samples = FS * DURATION
        noise_file = self.noise_paths[int(idx % len(self.noise_paths))]
        rir_file = self.rir_paths[int(idx % len(self.rir_paths))]

        clean = self._load_clean(idx, n_samples)
        noise = load_audio(noise_file)
        rir = self._load_rir(rir_file)

        # Tile speech/noise that are shorter than one chunk.
        if len(clean) < n_samples:
            clean = extend_signal(clean, n_samples)
        if len(noise) < n_samples:
            noise = extend_signal(noise, n_samples)

        # Keep the most energetic chunk of each signal to avoid silent segments.
        clean = self._loudest_chunk(clean.numpy(), n_samples)
        noise = self._loudest_chunk(noise.numpy(), n_samples)
        rir = rir.numpy()

        # Silent noise would make the SNR undefined: substitute white noise.
        if power(noise) == 0.:
            noise = np.random.randn(n_samples)

        # Scale the noise so that 10*log10(P_clean / P_noise) == self.snrs[idx].
        ini_snr = 10 * np.log10(power(clean) / power(noise))
        noise_gain_db = ini_snr - self.snrs[idx]
        noise *= np.power(10, noise_gain_db/20)

        # Peak-normalise the dry mixture: down to 0.9 if close to clipping, or
        # up to -18 dBFS if quieter than that. Both signals get the same gain,
        # so the SNR is unchanged.
        peak = np.max(np.abs(clean + noise))
        if peak >= 0.9:
            gain = peak / 0.9
            noise /= gain
            clean /= gain
        elif peak <= 10**(-18/20):
            gain = peak / 10**(-18/20)
            noise /= gain
            clean /= gain

        # Align the reverberant signals with the dry one: the first RIR sample
        # reaching half the peak magnitude is taken as the direct-path delay.
        lag = np.where(np.abs(rir) >= 0.5*np.max(np.abs(rir)))[0][0]
        revspeech = self._reverberate(clean, rir, lag, n_samples)
        if self.reverberant_noises:
            noise = self._reverberate(noise, rir, lag, n_samples)
        noisy = revspeech + noise

        # Report NaNs (possible if an energy ratio above divides zero by zero)
        # without interrupting iteration.
        if np.any(np.isnan(noisy)):
            print('noisy nan')
        if np.any(np.isnan(clean)):
            print('clean nan')

        # NOTE(review): meta[0] is the requested speech file even when
        # _load_clean substituted speech_paths[0] for a silent file.
        meta = [self.speech_paths[idx], noise_file, rir_file, self.snrs[idx].item()]
        return torch.from_numpy(noisy).float(), torch.from_numpy(clean).float(), meta

In [10]:
# Build the evaluation dataset from the path lists configured above.
dataset = DFN_dataset(
    speech_path=speech_path,
    noise_path=noise_path,
    rir_path=rir_path,
    reverberant_noises=reverberant_noises,
)


Initializing dataset...
speech set loaded. contains 41194 files.
All paths loaded.


In [11]:
# Deterministic order (no shuffling) so items line up with dataset indices.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=True,
)


In [12]:
# Iterate over the whole dataloader (e.g. to check that every item loads);
# only the last batch is kept in `noisy`, `clean` and `meta`.
for x in tqdm.tqdm(dataloader, total=len(dataloader)):
    noisy, clean, meta = x

100%|████████████████████████████████████| 41194/41194 [05:06<00:00, 134.32it/s]


In [13]:
# Inspect the noisy mixture of the last batch produced by the loop above.
noisy

tensor([[ 0.0002,  0.0003,  0.0003,  ..., -0.0557, -0.0535, -0.0524]])

In [14]:
# Inspect the clean target of the last batch produced by the loop above.
clean

tensor([[-0.0011, -0.0016, -0.0014,  ..., -0.0097, -0.0094, -0.0077]])

In [15]:
# Inspect the metadata (speech, noise, RIR paths and SNR) of the last batch.
meta

[('/home/ubuntu/Data/DNS-Challenge/datasets_fullband/clean_fullband/vctk_wav48_silence_trimmed/p262/p262_312_mic2.wav',),
 ('/home/ubuntu/Data/DNS-Challenge/datasets_fullband/noise_fullband/cfFBEzId-iI.wav',),
 ('/home/ubuntu/Data/OpenAIR/falkland-palace-royal-tennis-court/mono/falkland_tennis_court_omni.wav',),
 tensor([13.1425], dtype=torch.float64)]

In [None]:
# Standalone copy of DFN_dataset.__init__'s path loading, used below to step
# through one item by hand. `speech_paths` must be loaded first: it was never
# defined at top level, so this cell raised NameError on a fresh kernel.
# NOTE: this rebinds `rir_paths` (from the config cell's list of RIR list files)
# to the list of RIR audio paths.
with open(speech_path, 'r') as file:
    speech_paths = [line.rstrip() for line in file]

snrs = np.random.uniform(low = 0, high = 30, size = len(speech_paths))

# load noise paths
with open(noise_path, 'r') as file:
    noise_paths = [line.rstrip() for line in file]

# load rir paths
with open(rir_path, 'r') as file:
    rir_paths = [line.rstrip() for line in file]

noise_paths = rep_list(noise_paths, speech_paths)
rir_paths = rep_list(rir_paths, speech_paths)
print('All paths loaded.')

In [None]:
# Index of the item to inspect step by step; load its speech file.
idx = 0
clean = load_audio(speech_paths[idx])

In [None]:
clean.size()

In [None]:
FS * DURATION <= len(clean)

In [None]:
clean[:FS * DURATION].pow(2).mean()

In [None]:
# Step-by-step replica of DFN_dataset.__getitem__ for item `idx`, using the
# module-level speech_paths / noise_paths / rir_paths / snrs lists.

# handle weird case where speech is silence: fall back to the first speech file
if len(clean) >= FS*DURATION:
    speech_nrgy = torch.mean(clean[:FS*DURATION]**2)
else:
    speech_nrgy = torch.mean(clean **2)
if speech_nrgy == 0:
    clean = load_audio(speech_paths[0])

noise = load_audio(noise_paths[int(idx % len(noise_paths))])

# handle corrupt rir: fall back to a unit impulse at sample 300
# (`except Exception` so Ctrl-C / SystemExit are not swallowed)
try:
    rir = load_audio(rir_paths[int(idx % len(rir_paths))])
except Exception:
    rir = torch.zeros(FS)
    rir[300] = 1.
# handle silent rir
rir_nrgy = torch.mean(rir**2)
if rir_nrgy == 0:
    rir = torch.zeros(FS)
    rir[300] = 1.


# we extend speech and noise if too short
if len(clean) < FS * DURATION:
    clean = extend_signal(clean, FS*DURATION)
if len(noise) < FS * DURATION:
    noise = extend_signal(noise, FS*DURATION)

# back to numpy for easy conv
clean = clean.numpy()
noise = noise.numpy()
rir = rir.numpy()

# we choose the signal chunk with more energy (to avoid silent chunks)
nchunks = len(clean) // (FS*DURATION)
chunks = np.split(clean[: FS * DURATION * nchunks], nchunks)
powers = np.array([power(x) for x in chunks])
clean = clean[np.argmax(powers) * FS * DURATION : (np.argmax(powers) + 1 ) *  FS * DURATION]

nchunks = len(noise) // (FS*DURATION)
chunks = np.split(noise[: FS * DURATION * nchunks], nchunks)
powers = np.array([power(x) for x in chunks])
noise = noise[np.argmax(powers) * FS * DURATION : (np.argmax(powers) + 1 ) *  FS * DURATION]

# handle silent noise: use white noise instead
noise_nrgy = power(noise)
if noise_nrgy == 0.:
    noise = np.random.randn( FS * DURATION )

# we set the SNR: scale noise so that 10*log10(P_clean / P_noise) == snrs[idx]
ini_snr = 10 * np.log10(power(clean) / power(noise))
noise_gain_db = ini_snr - snrs[idx]
noise *= np.power(10, noise_gain_db/20)

# we normalize to 0.9 if mixture is close to clipping
clips = np.max(np.abs(clean + noise))
if clips >= 0.9:
    clips /= 0.9
    noise /= clips
    clean /= clips
# or to -18dBfs if smaller than that:
elif clips <= 10**(-18/20):
    clips /= 10**(-18/20)
    noise /= clips
    clean /= clips

# apply rir
revspeech = sig.fftconvolve(clean, rir, 'full')
# synchronize reverberant with anechoic: first sample at >= half the RIR peak
lag = np.where(np.abs(rir) >= 0.5*np.max(np.abs(rir)))[0][0]

revspeech = revspeech[lag:FS*DURATION + lag]

# enforce energy conservation
revspeech *= np.sqrt(power(clean) / power(revspeech))

# apply RIR to noise too if needed
if reverberant_noises:
    rnoise = sig.fftconvolve(noise, rir, 'full')
    rnoise = rnoise[lag:FS*DURATION + lag]
    rnoise *= np.sqrt(power(noise) / power(rnoise))
    noise = rnoise
noisy = revspeech + noise

# check for Nans
if np.any(np.isnan(noisy)):
    print('noisy nan')
if np.any(np.isnan(clean)):
    print('clean nan')
noisy = torch.from_numpy(noisy)
clean = torch.from_numpy(clean)
meta = [speech_paths[idx], noise_paths[int(idx % len(noise_paths))], rir_paths[int(idx % len(rir_paths))], snrs[idx].item()]

# A top-level `return` is a SyntaxError in a notebook cell: keep the float32
# tensors, as __getitem__ returns them, and display the item instead.
noisy, clean = noisy.float(), clean.float()
noisy, clean, meta