In [None]:
import os
from scipy.io import wavfile
import pandas as pd
import numpy as np
import scipy.fft
import scipy
from scipy.signal import resample_poly
import matplotlib.pyplot as plt
from IPython.display import Audio


# import my modules (helpers.py where I stored all the functions):
import helpers as hlp
import importlib 
importlib.reload(hlp)
from clarity.evaluator.hasqi import hasqi_v2
from clarity.evaluator.haspi import haspi_v2 
from clarity.evaluator.mbstoi import mbstoi 
import torch
import sys
from multiprocessing import Pool


In [None]:
# All our directories in which we store recordings or dnn-processed recordings.
# The position of a directory in this list becomes its numeric "dnn_applied" label
# (0 = plain recordings, 1..4 = the four dnn-processed variants).
data_source_dirs=[
    '/home/ubuntu/Data/ha_listening_situations/recordings/ku_recordings/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m1_alldata_normal/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m3_alldata_mild/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m4_alldata_normal_causal/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m5_alldata_mild_causal/ku_processed/'
]

# Directory where all recordings used as reference are stored
data_ref_dir='/home/ubuntu/Data/ha_listening_situations/recordings/ku_recordings/'

# Names of hearing devices
ha_models=["si3_s","si3_m","si3_p","ph4_s","ph4_m","ph4_p","ph5_s","ph5_m","ph5_p",
           "gn5_l","gn5_m","gn5_h","gn3_l","gn3_m","gn3_h"]

# Different listening scenes:
scenenames=['party_0deg','party_30deg','restaurant_0deg','restaurant_30deg','meeting_0deg','meeting_30deg']

# Column layout of the table (and of the CSV written at the end of this cell).
# Each row represents one pair of plus and minus recording and all the info associated with that pair.
# 'plus_ref_file' and 'minus_ref_file' are never filled by this cell; they are kept (as empty columns)
# only so that the layout of the CSV stays unchanged.
recording_columns=['device','dnn_applied','plus_file','minus_file','plus_ref_file','minus_ref_file','scene',
                   'plus_ref_bypass','minus_ref_bypass','plus_ref_ku','minus_ref_ku']


def _recording_wav(base_dir, recording, polarity, scene):
    # Path of one wav file: <base_dir>/<recording>/<recording>_<polarity>_<scene>.wav
    return os.path.join(base_dir, recording, recording+'_'+polarity+'_'+scene+'.wav')


# The bypass references are read from data_ref_dir, so that is where their names have to be looked up.
# Listing once (instead of once per recording) and sorting makes the choice deterministic.
ref_dirnames=sorted(dirname for dirname in os.listdir(data_ref_dir)
                    if os.path.isdir(os.path.join(data_ref_dir, dirname)))

records=[]
for dnn_index, data_source_dir in enumerate(data_source_dirs):
    for recording_name in sorted(os.listdir(data_source_dir)):
        if not os.path.isdir(os.path.join(data_source_dir, recording_name)):
            continue
        # this line identifies the current device name:
        device_id=[model for model in ha_models if model in recording_name]
        if len(device_id)>0:
            # find the corresponding bypass recording name:
            bypass_candidates=[dirname for dirname in ref_dirnames
                               if device_id[0] in dirname and "bypass" in dirname]
            if not bypass_candidates:
                raise FileNotFoundError(
                    f'No bypass recording for device "{device_id[0]}" found in {data_ref_dir}')
            ref_recording_name=bypass_candidates[0]
        else:
            # no hearing device in the name: the recording serves as its own reference
            ref_recording_name=recording_name

        for scene in scenenames:
            records.append({
                'device': recording_name,
                # stored as float right away because the fractional labels 0.1/0.2/0.3 are assigned below
                'dnn_applied': float(dnn_index),
                # recorded plus and minus signals (data_source_dir + recording_name)
                'plus_file': _recording_wav(data_source_dir, recording_name, 'plus', scene),
                'minus_file': _recording_wav(data_source_dir, recording_name, 'minus', scene),
                'scene': scene,
                # corresponding bypass plus and minus signals (data_ref_dir + ref_recording_name)
                'plus_ref_bypass': _recording_wav(data_ref_dir, ref_recording_name, 'plus', scene),
                'minus_ref_bypass': _recording_wav(data_ref_dir, ref_recording_name, 'minus', scene),
                # corresponding ku100 plus and minus signals
                'plus_ref_ku': _recording_wav(data_ref_dir, '001_ku100', 'plus', scene),
                'minus_ref_ku': _recording_wav(data_ref_dir, '001_ku100', 'minus', scene),
            })

recordings_df=pd.DataFrame(records, columns=recording_columns).astype({'dnn_applied': float})

# Make sure to remove the lines where the "enabled" recording is processed and lines where ku100 recording is processed
# (we are only interested in dnn-processed bypass recordings); "fulldenoising" recordings are dropped altogether.
print(f'before: {len(recordings_df)=}')
device_names=recordings_df["device"]
dnn_processed=recordings_df["dnn_applied"]>0
rows_to_drop=(
    (device_names.str.contains("enabled") & dnn_processed)
    | (device_names.str.contains("001_ku100") & dnn_processed)
    | device_names.str.contains("fulldenoising")
)
# .copy() so that the label assignments below modify an independent frame, not a slice of the unfiltered one
recordings_df=recordings_df[~rows_to_drop].copy()
print(f'after: {len(recordings_df)=}')

# Within the group that is not processed by any dnn model we have to distinguish 3 categories and give them labels:
# ---> unprocessed reference: ku100 recordings without a hearing aid
recordings_df.loc[recordings_df["device"].str.contains("001_ku100"), "dnn_applied"]=0.1
# ---> unprocessed bypass: recordings with hearing aids in bypass
recordings_df.loc[((recordings_df["device"].str.contains("bypass")) & (recordings_df["dnn_applied"]==0)), "dnn_applied"]=0.2
# ---> unprocessed enabled: recordings with hearing aids enabled
recordings_df.loc[((recordings_df["device"].str.contains("enabled")) & (recordings_df["dnn_applied"]==0)), "dnn_applied"]=0.3

# Check if numbers match:
print(f'Number of dnn-processed recordings should be {15*4*6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]>=1)]))


print(f'Number of enabled recordings should be {15*6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]==0.3)]))


print(f'Number of bypass recordings should be {15*6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]==0.2)]))


print(f'Number of ku100 recordings should be {6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]==0.1)]))

# Keep only recordings processed with HA or DNN (labels 0.3 and 1..4; ku100 and bypass rows are left out)
recordings_only_processed=recordings_df[(recordings_df["dnn_applied"]>0.2)]
print(f'Number of all processed recordings should be {15*5*6}')
print(len(recordings_only_processed))

recordings_only_processed.to_csv('recordings_processed.csv')




In [None]:
def SISDR(s, s_hat):
    """Scale-invariant signal-to-distortion ratio (SI-SDR) in dB, as in [1]_.

    The targets are stacked and flattened into a single vector, and so are the
    estimates; one SI-SDR value is computed over those two vectors.

    Parameters
    ----------
    s : list of torch.Tensor
        Targets; all entries must have the same shape so they can be stacked.
    s_hat : list of torch.Tensor
        Corresponding estimates, same shapes as the targets.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the SI-SDR in dB.

    References
    ----------
    .. [1] Le Roux, Jonathan, et al. "SDR–half-baked or well done?." ICASSP 2019-2019 IEEE International Conference on
    Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2019.
    """
    target = torch.stack(s).view(-1)
    estimate = torch.stack(s_hat).view(-1)
    # machine epsilon of the target dtype keeps both divisions finite
    tiny = torch.finfo(target.dtype).eps

    # part of the estimate that lies along the target ...
    target_energy = (target ** 2).sum()
    projection = (torch.dot(estimate, target) * target) / (target_energy + tiny)
    # ... and everything that is left over
    distortion = projection - estimate

    projection_energy = (projection * projection).sum()
    distortion_energy = (distortion * distortion).sum()
    return 10 * torch.log10(projection_energy / (distortion_energy + tiny))

def SNRLoss(target, input):
    """Mean signal-to-noise ratio (in dB) of `input` with respect to `target`.

    Both tensors are made zero-mean along the last dimension first. The SNR is
    then the energy of the target divided by the energy of (input - target),
    computed along the last dimension and averaged over all remaining entries.

    Parameters
    ----------
    target : torch.Tensor
        Reference signal(s), time along the last dimension.
    input : torch.Tensor
        Signal(s) to evaluate, same shape as `target`.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the mean SNR in dB (higher means closer to the target).
    """
    tiny = torch.finfo(target.dtype).eps

    # remove the DC offset of each signal
    centred_input = input - torch.mean(input, dim=-1, keepdim=True)
    centred_target = target - torch.mean(target, dim=-1, keepdim=True)

    residual = centred_input - centred_target
    target_energy = (centred_target ** 2).sum(-1)
    residual_energy = (residual ** 2).sum(-1)

    snr_db = 10 * torch.log10(target_energy / (residual_energy + tiny) + tiny)
    # apply reduction
    return snr_db.mean()

def _load_wav_float32(path, fs_target):
    # Read one wav file as float32 and resample it to `fs_target` Hz if it was stored at another rate.
    # The cast is done unconditionally: integer PCM data would otherwise overflow in sums/squares.
    fs, signal = wavfile.read(path)
    signal = signal.astype(np.float32)
    if fs != fs_target:
        signal = resample_poly(signal, int(fs_target), fs)
    return signal


def compute_obj_measures(idx, row, ref_method,
                         debug_dir="/home/ubuntu/Data/ha_listening_situations/debug/"):
    """Compute objective measures for one row (one plus/minus recording pair) of the recordings table.

    Parameters
    ----------
    idx :
        Index of the row in the original data frame; printed and returned unchanged.
    row :
        Row holding the file names. 'plus_file' and 'minus_file' are always read;
        'plus_ref_bypass'/'minus_ref_bypass' are read for ref_method "it2" and
        'plus_ref_ku'/'minus_ref_ku' for ref_method "it3".
    ref_method : str
        Method for computing the clean reference signal:
        "it1" - 0.5*(plus+minus) of the recording itself,
        "it2" - 0.5*(plus+minus) of the bypass reference recording,
        "it3" - 0.5*(plus+minus) of the ku100 reference recording.
    debug_dir : str or None
        Directory into which the synchronized plus signal followed by the reference
        is written as one wav file. Pass None to skip writing.

    Returns
    -------
    tuple
        (idx, snr_L, snr_R, hasqi_L, hasqi_R, haspi_L, haspi_R,
         sisdr_L, sisdr_R, snrloss_L, snrloss_R, mbstoi_B)
        - snr_*: SNR in dB estimated with the phase inversion technique
        - hasqi_*: first value returned by hasqi_v2, per ear
        - haspi_*: first value returned by haspi_v2, per ear
        - sisdr_*: SI-SDR of the plus signal against the reference, per ear
        - snrloss_*: SNRLoss of the plus signal against the reference, per ear
        - mbstoi_B: value returned by mbstoi for both ears together

    Raises
    ------
    ValueError
        If `ref_method` is not one of "it1", "it2", "it3".

    Notes
    -----
    All signals are indexed as [:, 0] (left) and [:, 1] (right), i.e. two-channel
    files are expected.
    """
    # print sth when the row is done:
    print( "Processing row " +str(idx)+ ".")

    # columns holding the reference recordings for the methods that need extra files
    reference_columns = {
        "it2": ('plus_ref_bypass', 'minus_ref_bypass'),
        "it3": ('plus_ref_ku', 'minus_ref_ku'),
    }
    if ref_method != "it1" and ref_method not in reference_columns:
        raise ValueError(f'Unknown ref_method {ref_method!r}; expected "it1", "it2" or "it3"')

    FS_TARGET=16e3
    plus = _load_wav_float32(row['plus_file'], FS_TARGET)
    minus = _load_wav_float32(row['minus_file'], FS_TARGET)

    # ------------- Objective measure : SNR -------------
    # plus, minus =hlp.synch_sigs(plus,minus)
    # half-sum and half-difference of the plus/minus pair are taken as signal and noise
    s=0.5*(plus+minus)
    n=0.5*(plus-minus)
    snr_L= 10 * np.log10(hlp.power(s[:,0]) / hlp.power(n[:,0]))
    snr_R = 10 * np.log10(hlp.power(s[:,1]) / hlp.power(n[:,1]))

    # ---------- Compute clean reference signal for further methods ----------
    if ref_method=="it1":
        # Referencing_method 1 ----> our first iteration
        clean_ref=s
    else:
        # Referencing_method 2 / 3 ----> our second / third iteration
        plus_column, minus_column = reference_columns[ref_method]
        refplus = _load_wav_float32(row[plus_column], FS_TARGET)
        refminus = _load_wav_float32(row[minus_column], FS_TARGET)
        # refplus, refminus =hlp.synch_sigs(refplus,refminus)
        clean_ref=0.5*(refplus+refminus)

    # Synchronize signals
    plus, clean_ref =hlp.synch_sigs(plus,clean_ref)
    # reported up front: the measures below compare the two signals sample by sample
    if plus[:,0].shape!=clean_ref[:,0].shape :
        print(f'Plus and clean have different shapes for idx = {idx}')

    # Equalize energy: scale the reference so that it carries the same energy as the plus signal
    nrgy_1 = (plus ** 2).sum()
    clean_ref *= np.sqrt(nrgy_1 /((clean_ref**2).sum()))

    # Save signals (plus followed by the reference) for listening/debugging
    if debug_dir is not None:
        sig2save=np.concatenate((plus,clean_ref), axis=0)
        name2save=os.path.splitext(os.path.basename(row['plus_file']))[0]
        wavfile.write(os.path.join(debug_dir, ref_method+"_"+name2save+".wav"), int(FS_TARGET), sig2save)

    # ------------- Objective measure : SI-SDR -------------
    sisdr_L=SISDR([torch.from_numpy(clean_ref[:,0])],[torch.from_numpy(plus[:,0])]).item()
    sisdr_R=SISDR([torch.from_numpy(clean_ref[:,1])],[torch.from_numpy(plus[:,1])]).item()
    # ------------- Objective measure : SNR Loss -------------
    snrloss_L=SNRLoss(torch.from_numpy(clean_ref[:,0]),torch.from_numpy(plus[:,0])).item()
    snrloss_R=SNRLoss(torch.from_numpy(clean_ref[:,1]),torch.from_numpy(plus[:,1])).item()

    # ------------- Objective measure : MBSTOI -------------
    # NOTE(review): these analysis settings look different from the library defaults - confirm they are intended
    mbstoi_B = mbstoi(
        left_ear_clean=clean_ref[:,0],
        right_ear_clean=clean_ref[:,1],
        left_ear_noisy=plus[:,0],
        right_ear_noisy=plus[:,1],
        fs_signal=FS_TARGET,  # signal sample rate
        sample_rate=9000,  # operating sample rate
        fft_size_in_samples=64,
        n_third_octave_bands=5,
        centre_freq_first_third_octave_hz=500,
        dyn_range=60,
    )
    # ------------- Objective measures : HASQI and HASPI -------------
    hearing_loss = np.array([0, 0, 0, 0, 0, 0])
    equalisation_mode=1
    level1=70
    hasqi_L, _, _, _ = hasqi_v2(clean_ref[:,0], FS_TARGET, plus[:,0], FS_TARGET, hearing_loss, equalisation_mode, level1)
    hasqi_R, _, _, _ = hasqi_v2(clean_ref[:,1], FS_TARGET, plus[:,1], FS_TARGET, hearing_loss, equalisation_mode, level1)
    # the processed input for HASPI is built by hlp.add_signals from the plus signal and the reference
    haspi_L, _ = haspi_v2(clean_ref[:,0], FS_TARGET, hlp.add_signals(plus[:,0],clean_ref[:,0]) , FS_TARGET, hearing_loss, level1)
    haspi_R, _ = haspi_v2(clean_ref[:,1], FS_TARGET, hlp.add_signals(plus[:,1],clean_ref[:,1]) , FS_TARGET, hearing_loss, level1)

    return idx, snr_L, snr_R, hasqi_L, hasqi_R, haspi_L, haspi_R, sisdr_L, sisdr_R, snrloss_L, snrloss_R, mbstoi_B 

# Table of recordings for which the measures are computed:
measures=pd.read_csv('recordings_processed.csv')

# Single-row test of compute_obj_measures() (kept for manual debugging):
# idx, snr_L, snr_R, hasqi_L, hasqi_R, haspi_L, haspi_R, sisdr_L, sisdr_R, snrloss_L, snrloss_R, mbstoi_B = compute_obj_measures(120,measures.loc[120],"it1")
# print(f'Measures for {measures.loc[idx,"plus_file"]}\n{snr_L=}\n{hasqi_L=}\n{haspi_L=}\n{sisdr_L=}\n{snrloss_L=}\n{mbstoi_B=}')

# Multiprocessing - compute measures
# The main script has to be in the same cell as the definition of the function
if __name__ == '__main__':
    NUM_OF_WORKERS = 8
    # names of the values that follow the row index in the tuple returned by compute_obj_measures()
    returned_names = ('snr_L', 'snr_R', 'hasqi_L', 'hasqi_R', 'haspi_L', 'haspi_R',
                      'sisdr_L', 'sisdr_R', 'snrloss_L', 'snrloss_R', 'mbstoi_B')
    # order in which the columns are written to the table
    column_order = ('snr_L', 'snr_R', 'haspi_L', 'haspi_R', 'hasqi_L', 'hasqi_R',
                    'sisdr_L', 'sisdr_R', 'snrloss_L', 'snrloss_R', 'mbstoi_B')
    with Pool(NUM_OF_WORKERS) as pool:
        # one asynchronous task per row, all using the first referencing method
        results = [
            pool.apply_async(compute_obj_measures, [row_index, row_data, "it1"])
            for row_index, row_data in measures.iterrows()
        ]
        # collect the results in submission order; .get() blocks until the task is done
        for pending in results:
            finished_index, *values = pending.get()
            value_of = dict(zip(returned_names, values))
            for column in column_order:
                measures.loc[finished_index, column] = value_of[column]

measures.to_csv('bla.csv')



In [None]:
# Scratch cell: evaluates log10(1); the value is shown as the cell output and is not used elsewhere in the visible code.
np.log10(1)

In [None]:
# Load the table with the objective measures to plot
measures_computed=pd.read_csv('objective_measures2.csv')

# For better readability of the plots, add a column "device_group" that holds the hearing aid model
# only (without the division into different receivers). A row gets the group whose name occurs in its
# "device" entry; rows matching none of the groups are left without a group.
ha_models=["gn3","gn5","ph5","ph4","si3"]
for group_name in ha_models:
    in_group=measures_computed['device'].str.contains(group_name)
    measures_computed.loc[in_group, "device_group"]=group_name

import seaborn as sns
import warnings

# From here on all warnings are suppressed
warnings.filterwarnings('ignore')

def plot_1_measure(df,measure):
    """Show one objective measure, grouped by "dnn_applied", in three side-by-side panels.

    Left: pandas boxplot. Middle: strip plot colored by "device_group".
    Right: strip plot colored by "scene". The panels share the y axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns "dnn_applied", "device_group", "scene" and `measure`.
    measure : str
        Name of the column to plot.

    Notes
    -----
    The tick labels are a fixed list of five names, so they only describe the
    data correctly if "dnn_applied" has exactly those five groups in that order.
    """
    fig, axes = plt.subplots(1, 3, figsize=(25, 5), sharey=True)
    box_ax, device_ax, scene_ax = axes
    sns.set(font_scale=0.8)

    # the two strip plots differ only in the column used for coloring
    for strip_ax, color_by in ((device_ax, "device_group"), (scene_ax, "scene")):
        sns.stripplot(ax=strip_ax, data=df, x="dnn_applied", y=measure, hue=color_by)
    df.boxplot(ax=box_ax, column=measure, by="dnn_applied")

    group_labels=['HA','DNN-normal','DNN-mild','DNN-normal_C','DNN-mild_C']
    panel_titles=('Standard boxplot', 'Colored by devices', 'Colored by scene')
    for panel, title in zip(axes, panel_titles):
        panel.set_xlabel('')
        panel.set_title(title)
        panel.set_xticklabels(group_labels, rotation=45)

    fig.suptitle('Measure: '+ measure)
    plt.show()

# Plot the SNR-loss and SI-SDR measures for both ears:
for measure_name in ("snrloss_L", "snrloss_R", "sisdr_L", "sisdr_R"):
    plot_1_measure(measures_computed, measure_name)

In [None]:
# SI-SDR first, then SNR loss, each for the left and the right ear
for measure_name in ("sisdr_L", "sisdr_R", "snrloss_L", "snrloss_R"):
    plot_1_measure(measures_computed, measure_name)
# Currently not plotted:
#plot_1_measure(measures_computed,"hasqi_R")
#plot_1_measure(measures_computed,"haspi_R")
#plot_1_measure(measures_computed,"mbstoi_B")