In [None]:
import os
from scipy.io import wavfile
import pandas as pd
import numpy as np
import scipy.fft
from scipy.signal import resample_poly
import matplotlib.pyplot as plt
from IPython.display import Audio

import tqdm
import time
import seaborn as sns
# import my modules (helpers.py where I stored all the functions):
import helpers as hlp
import importlib 
importlib.reload(hlp)
from clarity.evaluator.hasqi import hasqi_v2
from clarity.evaluator.haspi import haspi_v2 
from clarity.evaluator.mbstoi import mbstoi 
import torch
import sys
from multiprocessing import Pool
import copy
import soundfile as sf

In [None]:
def get_spl(x):
    """Return the equivalent SPL (in dB) of one of our studio recordings.

    The level in dB relative to an RMS of 1 is offset by the fixed
    calibration constant of 127.9 dB used for these recordings.
    """
    calibration_db = 127.9
    rms = np.sqrt(np.mean(x ** 2))
    return 20 * np.log10(rms) + calibration_db

def norm_rms(x):
    """Return ``x`` scaled so that its RMS value is 1.

    The input is left untouched: dividing a NumPy array allocates a new
    array, so no explicit copy is needed.
    """
    rms = np.sqrt(np.mean(x ** 2))
    return x / rms

def SISDR(s, s_hat):
    """Computes the Scale-Invariant SDR (in dB) as in [1]_.

    Both inputs are flattened before the computation, so any shapes with the
    same number of elements are accepted, including non-contiguous or
    reversed NumPy slices such as ``x[:, 0]`` or ``x[::-1]``. No mean removal
    is applied to either signal.

    References
    ----------
    .. [1] Le Roux, Jonathan, et al. "SDR–half-baked or well done?." ICASSP 2019-2019 IEEE International Conference on
    Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2019.

    Parameters
    ----------
    s : array_like
        Target (reference) signal of any shape.
    s_hat : array_like
        Estimated signal; must contain as many elements as ``s``.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the SI-SDR in dB (use ``.item()`` for a float).
    """
    # ascontiguousarray: torch cannot wrap negative-stride arrays, and the
    # flatten below must work for arbitrarily strided slices.
    s = torch.as_tensor(np.ascontiguousarray(s)).reshape(-1)
    s_hat = torch.as_tensor(np.ascontiguousarray(s_hat)).reshape(-1)

    # torch.dot rejects mixed dtypes (e.g. a resampled float32 signal against a
    # float64 one) and finfo() needs a floating dtype, so use a common float type.
    dtype = torch.promote_types(s.dtype, s_hat.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float64
    s = s.to(dtype)
    s_hat = s_hat.to(dtype)
    EPS = torch.finfo(dtype).eps

    # Projection of the estimate onto the target ("scaled target").
    a = (torch.dot(s_hat, s) * s) / ((s ** 2).sum() + EPS)
    # Residual; the sign is flipped w.r.t. the paper but it is squared below.
    b = a - s_hat
    return 10 * torch.log10(((a * a).sum()) / ((b * b).sum() + EPS))

In [None]:
# DATAFRAME: one row per (ha_model, scene, degree, processing) with columns:
#   ha_model    -> the HA device model and receiver combination, e.g. 'gn5_l'
#   scene       -> one of `scenes` ('party', 'restaurant', 'meeting')
#   'degree '   -> one of `degrees` ('0', '30'). NOTE: this column key has a
#                  trailing space; later cells index it as 'degree ', so it is
#                  kept as-is.
#   processing  -> one of `processings` ('bypass', 'ha' or a DNN variant)
#   meas_path   -> absolute path to that measurement wav file
#   target_path -> absolute path to the target or reference (the anechoic
#                  binaural downmix of the situation)

# Directory where all recordings used as reference are stored
data_ref_dir = '/home/ubuntu/Data/ha_listening_situations/'
ku_recordings_dir = data_ref_dir + 'recordings/ku_recordings/'

# Key each recording folder by its name without the first '_'-separated token,
# so lookups below can use '<ha_model>_<condition>_ku100'.
folder_names = os.listdir(ku_recordings_dir)
recnames = ['_'.join(name.split('_')[1:]) for name in folder_names]
folders = dict(zip(recnames, folder_names))

FS_DNN = 16000

# Names of hearing devices
ha_models = ["si3_s", "si3_m", "si3_p", "ph4_s", "ph4_m", "ph4_p", "ph5_s", "ph5_m", "ph5_p",
             "gn5_l", "gn5_m", "gn5_h", "gn3_l", "gn3_m", "gn3_h"]

# Different listening scenes:
scenes = ['party', 'restaurant', 'meeting']

degrees = ['0', '30']

processings = ['bypass', 'ha', 'dnn_normal_noncausal', 'dnn_normal_causal', 'dnn_mild_noncausal', 'dnn_mild_causal']

# Output folder (under data_ref_dir) of each DNN variant
process_folders = {'dnn_normal_noncausal': 'processed_m1_alldata_normal',
                   'dnn_normal_causal': 'processed_m4_alldata_normal_causal',
                   'dnn_mild_noncausal': 'processed_m3_alldata_mild',
                   'dnn_mild_causal': 'processed_m5_alldata_mild_causal'}


def _recording_file(folder, scene, degree):
    """Relative path '<folder>/<folder>_plus_<scene>_<degree>deg.wav' of a 'plus' recording."""
    return folder + '/' + folder + '_plus_' + scene + '_' + degree + 'deg.wav'


def _bypass_folder(ha_model):
    """Folder of the bypass recording of `ha_model`.

    Some devices are stored under '<ha_model>_bypass_ku100', others under
    '<ha_model>_bypassed_ku100'; the first one found is returned.
    """
    try:
        return folders[ha_model + '_bypass_ku100']
    except KeyError:
        return folders[ha_model + '_bypassed_ku100']


records = []

for ha_model in ha_models:
    for scene in scenes:
        for degree in degrees:
            # The reference depends only on scene and degree.
            target_path = (data_ref_dir + 'SH_versions/normal/'
                           + 'sharvard-target1-' + degree + 'deg/'
                           + scene + '_bin_deg' + degree + '_snr5_ane_48000hz.wav')
            for processing in processings:
                if processing == 'ha':
                    # the plus recording from that device, "enabled"
                    meas_path = ku_recordings_dir + _recording_file(
                        folders[ha_model + '_enabled_ku100'], scene, degree)
                elif processing == 'bypass':
                    # the bypass (or "bypassed") plus recording from that device
                    meas_path = ku_recordings_dir + _recording_file(
                        _bypass_folder(ha_model), scene, degree)
                else:
                    # the bypass recording from that device after processing by this DNN variant
                    meas_path = (data_ref_dir + process_folders[processing] + '/ku_processed/'
                                 + _recording_file(_bypass_folder(ha_model), scene, degree))

                records.append({'ha_model': ha_model,
                                'scene': scene,
                                'degree ': degree,
                                'processing': processing,
                                'meas_path': meas_path,
                                'target_path': target_path})

df = pd.DataFrame(records)

In [None]:
def compute_obj_measures(idx, row):
    """Compute objective measures for one row of the measurement data frame.

    Loads the measurement (``row['meas_path']``) and the reference target
    (``row['target_path']``), resamples each to ``FS_DNN`` when its file rate
    differs, time-aligns them with ``hlp.synch_sigs``, matches their length and
    energy, then evaluates SI-SDR, MBSTOI, HASQI v2 and HASPI v2. Both signals
    are indexed as two-channel arrays (column 0 = left ear, column 1 = right ear).

    Parameters
    ----------
    idx : hashable
        Index of ``row`` in the original data frame; returned unchanged so the
        caller can write the results back with ``df.loc[idx, ...]``.
    row : pandas.Series (assumed)
        Must provide the ``'meas_path'`` and ``'target_path'`` entries.

    Returns
    -------
    tuple
        ``(idx, hasqi_L, hasqi_R, hasqi, haspi_L, haspi_R, haspi,
        sisdr_L, sisdr_R, sisdr, mbstoi_B)``:

        - hasqi_L / hasqi_R: hearing aid speech quality index, left / right ear
        - hasqi: the maximum (best ear) of hasqi_L and hasqi_R
        - haspi_L / haspi_R: hearing aid speech perception index, left / right ear
        - haspi: the maximum (best ear) of haspi_L and haspi_R
        - sisdr_L / sisdr_R: scale-invariant signal-to-distortion ratio (dB), left / right ear
        - sisdr: the maximum (best ear) of sisdr_L and sisdr_R
        - mbstoi_B: modified binaural short-time objective intelligibility
    """

    # Load audio and resample if needed
    meas, fs = sf.read(row['meas_path'])
    target, fs_tar = sf.read(row['target_path'])

    if fs!= FS_DNN :
        meas = resample_poly(meas.astype(np.float32), FS_DNN, fs)

    if fs_tar != FS_DNN :
        target = resample_poly(target.astype(np.float32), FS_DNN, fs_tar)

    # ---------- Compute clean reference signal for further methods ----------
    # Synchronize signals
    meas, target, lag =hlp.synch_sigs(meas,target)
    print( "Processing row " +str(idx)+ ". lag is : "+str(lag))

    # Store the SPLs of the synchronized measurement before any rescaling below;
    # they are used to restore the measured level for HASQI/HASPI.
    meas_spl_l = get_spl(meas[:,0])
    meas_spl_r = get_spl(meas[:,1])

    # Match length and signal energy (over both channels jointly), then scale
    # both signals by the same factor so that neither exceeds |1| (no clipping).
    crop = int(np.min([len(meas), len(target)]))
    meas = meas[:crop, :]
    target = target[:crop, :]
    target *= np.sqrt((meas ** 2).sum() /((target ** 2).sum()))
    norm_fac = np.max((np.max(np.abs(meas)), np.max(np.abs(target))))
    meas /= norm_fac
    target /= norm_fac

    # ------------- Objective measure : SI-SDR -------------
    sisdr_L = SISDR(target[:,0], meas[:,0]).item()
    sisdr_R = SISDR(target[:,1], meas[:,1]).item()
    sisdr = max(sisdr_L, sisdr_R)

    # ------------- Objective measure : MBSTOI -------------
    mbstoi_B = mbstoi(
        left_ear_clean=target[:,0],
        right_ear_clean=target[:,1],
        left_ear_noisy=meas[:,0],
        right_ear_noisy=meas[:,1],
        fs_signal=FS_DNN,  # signal sample rate
        sample_rate=9000,  # operating sample rate
        fft_size_in_samples=64,
        n_third_octave_bands=5,
        centre_freq_first_third_octave_hz=500,
        dyn_range=60,
    )
    # ------------- Objective measures : HASQI and HASPI -------------
    # Six zero thresholds passed as the hearing loss (no hearing loss modelled).
    hearing_loss = np.array([0, 0, 0, 0, 0, 0])
    equalisation_mode=1
    level1=70 # the calibrated level of the reference signal

    # Level of each measured channel relative to the reference level (dB)
    meas_spl_l -= level1
    meas_spl_r -= level1
    gain_l = 10**(meas_spl_l / 20) # spl gain compared to the 70dB reference
    gain_r = 10**(meas_spl_r / 20)

    # normalize to RMS==1
    meas_l = norm_rms(meas[:,0])
    meas_r = norm_rms(meas[:,1])
    target_l = norm_rms(target[:, 0])
    target_r = norm_rms(target[:,1])

    # Apply gain: the reference keeps RMS 1 (presented at level1), while each
    # measured channel's RMS becomes its gain, preserving its SPL offset from level1.
    meas_l *= gain_l
    meas_r *= gain_r

    hasqi_L, _, _, _ = hasqi_v2(target_l, FS_DNN, meas_l, FS_DNN, hearing_loss, equalisation_mode, level1)
    hasqi_R, _, _, _ = hasqi_v2(target_r, FS_DNN, meas_r, FS_DNN, hearing_loss, equalisation_mode, level1)
    hasqi = np.max((hasqi_L, hasqi_R))

    haspi_L, _ = haspi_v2(target_l, FS_DNN, meas_l, FS_DNN, hearing_loss, level1)
    haspi_R, _ = haspi_v2(target_r, FS_DNN, meas_r, FS_DNN, hearing_loss, level1)
    haspi = np.max((haspi_L, haspi_R))

    return idx, hasqi_L, hasqi_R, hasqi, haspi_L, haspi_R, haspi, sisdr_L, sisdr_R, sisdr, mbstoi_B

In [None]:
if __name__ == '__main__':
    NUM_OF_WORKERS = 8
    # Order of the measures returned by compute_obj_measures (after idx).
    result_fields = ('hasqi_L', 'hasqi_R', 'hasqi', 'haspi_L', 'haspi_R', 'haspi',
                     'sisdr_L', 'sisdr_R', 'sisdr', 'mbstoi')
    # Order in which the columns are added to df, which fixes the column layout of results.csv.
    column_order = ('haspi_L', 'haspi_R', 'haspi', 'hasqi_L', 'hasqi_R', 'hasqi',
                    'sisdr_L', 'sisdr_R', 'sisdr', 'mbstoi')
    t0 = time.time()
    with Pool(NUM_OF_WORKERS) as pool:
        results = [pool.apply_async(compute_obj_measures, [idx, row]) for idx, row in df.iterrows()]
        for result in results:
            # Unpack into a dict instead of bare global names: a global called
            # `mbstoi` would shadow the imported clarity `mbstoi` function, and any
            # Pool forked afterwards (e.g. when re-running this cell) would call a float.
            idx, *values = result.get()
            measures = dict(zip(result_fields, values))
            for column in column_order:
                df.loc[idx, column] = measures[column]

    # Inside the guard: t0 only exists when the block above ran.
    df.to_csv('results.csv')
    print('Took ' + str(time.time() - t0) + ' seconds.')

In [None]:
# Reload the results; index_col=0 restores the index written by df.to_csv
# instead of adding a spurious 'Unnamed: 0' column.
df = pd.read_csv('results.csv', index_col=0)

# Keep bypass, HA and the "normal" DNN variants only. .copy() so the column
# assignments below act on an independent frame, not a slice of df.
aux = df[~df['processing'].isin(['dnn_mild_noncausal', 'dnn_mild_causal'])].copy()

# Device group per hearing-aid family (applied in this order, as before).
device_groups = {'gn3': 'd1', 'gn5': 'd2', 'si3': 'd3', 'ph5': 'd4', 'ph4': 'd5'}
aux['dev_group'] = aux['ha_model']
for family, group in device_groups.items():
    aux.loc[aux['dev_group'].str.contains(family), 'dev_group'] = group

# Reference frame: the bypass results repeated once per compared processing.
bypass = pd.concat([aux[aux["processing"] == "bypass"]] * 3, ignore_index=True)

# Processed results, stacked in the same order as the three bypass copies.
ha = aux[aux["processing"] == "ha"]
dnn1 = aux[aux["processing"] == "dnn_normal_causal"]
dnn2 = aux[aux["processing"] == "dnn_normal_noncausal"]
processed = pd.concat([ha, dnn1, dnn2], ignore_index=True)

# The deltas below are row-wise differences, so both frames must list the
# same conditions in the same order.
for key in ['scene', 'ha_model', 'degree ']:
    assert bypass[key].equals(processed[key]), f'bypass and processed rows are misaligned on {key!r}'

# Benefit of each processing relative to bypass. Copy so that `processed`
# keeps the raw scores instead of being overwritten with the deltas.
aux_delta = processed.copy()
for measure in ['haspi', 'hasqi', 'mbstoi', 'sisdr']:
    aux_delta[measure] = processed[measure] - bypass[measure]

# Display labels; replace() without regex only matches whole cell values.
aux_delta = aux_delta.replace({'ha': 'HA',
                               'dnn_normal_noncausal': 'DNN',
                               'dnn_normal_causal': 'DNN-C'})

In [None]:
print(len(aux_delta))
sns.set(rc={'axes.facecolor': 'lightgrey', 'figure.facecolor': 'none'})
fig, axes = plt.subplots(2, 2, figsize=(8, 5))
# NOTE(review): this second sns.set() call happens after the axes exist; the
# ordering is kept exactly as in the original figure.
sns.set(font_scale=0.7)

# (axis, measure column, title). Raw strings: '\D' is an invalid escape
# sequence in a normal string literal.
panels = [
    (axes[0, 0], 'haspi', r'$\Delta$HASPI'),
    (axes[0, 1], 'hasqi', r'$\Delta$HASQI'),
    (axes[1, 0], 'sisdr', r'$\Delta$SISDR'),
    (axes[1, 1], 'mbstoi', r'$\Delta$MBSTOI'),
]
for ax, measure, title in panels:
    sns.violinplot(ax=ax, data=aux_delta, y=measure, x='processing', palette="colorblind")
    ax.set_title(title, fontsize=13)
    ax.set_xlabel('')
    ax.set_ylabel('')

plt.tight_layout()
os.makedirs('figures', exist_ok=True)  # savefig does not create directories
plt.savefig(os.path.join('figures', 'basic_plot.pdf'))
plt.show()

In [None]:
# Anonymised device names per device group, and display labels per processing.
# replace() without regex only matches whole cell values, and no new label is
# also an old value, so applying each mapping at once equals replacing one by one.
device_labels = {'d1': 'device 5',
                 'd2': 'device 4',
                 'd3': 'device 1',
                 'd4': 'device 3',
                 'd5': 'device 2'}
processing_labels = {'bypass': 'Bypass',
                     'ha': 'HA',
                     'dnn_normal_noncausal': 'DNN',
                     'dnn_normal_causal': 'DNN-C'}

aux = aux.replace(device_labels)
aux = aux.replace(processing_labels)

In [None]:
sns.set(rc={'axes.facecolor': 'lightgrey', 'figure.facecolor': 'none'})
fig, axes = plt.subplots(4, 1, figsize=(5, 8))

# (measure column, y-label, y-limits, y-ticks) per panel, top to bottom.
# NOTE(review): matplotlib's set_yticks expands the view limits so every given
# tick is visible, so the HASPI / HASQI upper limits effectively become 1 and
# 0.25 rather than 0.9 and 0.23 — confirm which is intended.
bar_panels = [
    ('sisdr', 'SISDR', [-15, 0], [-12, -9, -6, -3, 0]),
    ('haspi', 'HASPI', [0.2, 0.9], [0.2, 0.4, 0.6, 0.8, 1]),
    ('hasqi', 'HASQI', [0.05, 0.23], [0.05, 0.1, 0.15, 0.2, 0.25]),
    ('mbstoi', 'MBSTOI', [0.3, 0.75], [0.3, 0.4, 0.5, 0.6, 0.7]),
]
for ax, (measure, label, ylim, yticks) in zip(axes, bar_panels):
    sns.barplot(ax=ax, data=aux, y=measure, x='processing', palette="colorblind", hue='dev_group')
    ax.set_ylabel(label)
    ax.set_xlabel('')
    # Per-panel legends are dropped; one shared legend is drawn on the top panel below.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_ylim(ylim)
    ax.set_yticks(yticks)

# Shared device-group legend above the top panel
axes[0].legend(bbox_to_anchor=(0.50, 1.5), loc="upper center", ncol=3)

plt.tight_layout()
os.makedirs('figures', exist_ok=True)  # savefig does not create directories
plt.savefig(os.path.join('figures', 'bar_plot.pdf'))
