In [13]:
import os
from scipy.io import wavfile
import pandas as pd
import numpy as np
import scipy.fft
from scipy.signal import resample_poly
import matplotlib.pyplot as plt
from IPython.display import Audio

import tqdm
import time
import seaborn as sns
# import my modules (helpers.py where I stored all the functions):
import helpers as hlp
import importlib 
importlib.reload(hlp)
from clarity.evaluator.hasqi import hasqi_v2
from clarity.evaluator.haspi import haspi_v2 
from clarity.evaluator.mbstoi import mbstoi 
import torch
import sys
from multiprocessing import Pool
import copy
import soundfile as sf

In [14]:
def norm_rms(x):
    """Scale a signal to unit root-mean-square (RMS) amplitude.

    Parameters
    ----------
    x : array_like
        Signal samples. The RMS is taken over all elements, so a multi-channel
        array is normalised jointly rather than per channel.

    Returns
    -------
    numpy.ndarray
        A new array ``x / rms(x)``; the input is never modified. An all-zero
        (silent) input is returned as a zero-valued float copy instead of the
        NaNs a 0/0 division would produce.
    """
    x = np.asarray(x)
    rms = np.sqrt(np.mean(x ** 2))
    if rms == 0:
        # Nothing to normalise; avoid 0/0 -> NaN propagating into HASQI/HASPI.
        return x * 0.0
    # The division allocates a fresh array, so no defensive copy is needed.
    return x / rms

def SISDR(s, s_hat):
    """Compute the Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) in dB, as in [1]_.

    Both inputs are flattened, so for multi-channel arrays a single joint
    SI-SDR over all channels is returned.

    References
    ----------
    .. [1] Le Roux, Jonathan, et al. "SDR – half-baked or well done?" ICASSP 2019 -
       2019 IEEE International Conference on Acoustics, Speech and Signal
       Processing (ICASSP). IEEE, 2019.

    Parameters
    ----------
    s : numpy.ndarray or torch.Tensor
        Target (reference) signal of any shape and memory layout.
    s_hat : numpy.ndarray or torch.Tensor
        Estimate with the same number of elements as ``s``.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the SI-SDR in dB (call ``.item()`` for a float).
    """
    s = _as_flat_tensor(s)
    s_hat = _as_flat_tensor(s_hat)
    # torch.dot requires identical dtypes; promote float32/float64 mixes
    # (e.g. when only one of the signals was resampled as float32).
    dtype = torch.promote_types(s.dtype, s_hat.dtype)
    s = s.to(dtype)
    s_hat = s_hat.to(dtype)
    eps = torch.finfo(dtype).eps
    # Projection of the estimate onto the target, i.e. the optimally scaled target.
    a = (torch.dot(s_hat, s) * s) / ((s ** 2).sum() + eps)
    # Distortion: the part of the estimate not explained by the scaled target.
    b = a - s_hat
    return 10 * torch.log10((a * a).sum() / ((b * b).sum() + eps))


def _as_flat_tensor(x):
    """Return ``x`` as a 1-D torch tensor, accepting numpy arrays of any memory layout."""
    if isinstance(x, np.ndarray):
        # torch.from_numpy rejects negative strides; ascontiguousarray is a
        # no-op for input that is already C-contiguous.
        x = torch.from_numpy(np.ascontiguousarray(x))
    else:
        x = torch.as_tensor(x)
    # reshape (unlike view) also works for non-contiguous tensors.
    return x.reshape(-1)

def SNRLoss(target, input):
    """Mean-removed signal-to-noise ratio (dB) of ``input`` w.r.t. ``target``.

    Both tensors are centred along their last axis. For every leading index the
    value ``10 * log10(||t||^2 / (||input - t||^2 + eps) + eps)`` is computed,
    and the mean over all leading indices is returned. Despite the name, this is
    an SNR, so higher means closer to the target.

    Parameters
    ----------
    target : torch.Tensor
        Reference signal(s); time runs along the last dimension.
    input : torch.Tensor
        Estimated signal(s), broadcast-compatible with ``target``.

    Returns
    -------
    torch.Tensor
        0-dim tensor with the averaged SNR in dB.
    """
    eps = torch.finfo(target.dtype).eps
    centred_target = target - torch.mean(target, dim=-1, keepdim=True)
    centred_input = input - torch.mean(input, dim=-1, keepdim=True)
    residual = centred_input - centred_target
    signal_power = (centred_target ** 2).sum(-1)
    noise_power = (residual ** 2).sum(-1)
    per_item_snr = 10 * torch.log10(signal_power / (noise_power + eps) + eps)
    return per_item_snr.mean()

In [15]:
# Root directory with the reference recordings and the DNN-processed outputs.
data_ref_dir = '/home/ubuntu/Data/ha_listening_situations/'

# Map each recording folder to a lookup key: the folder name with its first
# "_"-separated token removed. Later cells index this dict with keys such as
# "<ha_model>_enabled_ku100" or "<ha_model>_bypass_ku100".
_ku_dir_names = os.listdir(data_ref_dir + 'recordings/ku_recordings/')
recnames = ['_'.join(name.split('_')[1:]) for name in _ku_dir_names]
folders = dict(zip(recnames, _ku_dir_names))

# Sample rate (Hz) to which all signals are resampled before computing metrics.
FS_DNN = 16000

# Hearing-device identifiers: <device>_<setting>.
# Phase-inversion (plus/minus) quality per device:
#   GN3, GN5 -> bad, unusable;  PH5, PH4, SI3 -> good.
# GN3/GN5 are still included here; drop their entries to exclude them.
ha_models = [device + '_' + level
             for device, levels in (('si3', ('s', 'm', 'p')),
                                    ('ph4', ('s', 'm', 'p')),
                                    ('ph5', ('s', 'm', 'p')),
                                    ('gn5', ('l', 'm', 'h')),
                                    ('gn3', ('l', 'm', 'h')))
             for level in levels]

# Listening scenes, target angles (degrees, as strings used in file names)
# and processing conditions evaluated.
scenes = ['party', 'restaurant', 'meeting']
degrees = ['0', '30']
processings = ['bypass', 'ha', 'dnn_normal_noncausal', 'dnn_normal_causal',
               'dnn_mild_noncausal', 'dnn_mild_causal']

# Output folder (below data_ref_dir) of each DNN variant.
process_folders = dict(
    dnn_normal_noncausal='processed_m1_alldata_normal',
    dnn_normal_causal='processed_m4_alldata_normal_causal',
    dnn_mild_noncausal='processed_m3_alldata_mild',
    dnn_mild_causal='processed_m5_alldata_mild_causal',
)


In [16]:
def _bypass_folder(ha_model):
    """Return the folder name of ``ha_model``'s bypass recordings.

    Device folders use either the suffix ``_bypass_ku100`` or
    ``_bypassed_ku100``; the first one present in ``folders`` is returned.
    A KeyError propagates if neither exists.
    """
    try:
        return folders[ha_model + '_bypass_ku100']
    except KeyError:
        return folders[ha_model + '_bypassed_ku100']


def _plus_recording_path(base_dir, folder, scene, degree):
    """Path of the "plus" recording for ``scene``/``degree`` inside ``base_dir + folder``."""
    return base_dir + folder + '/' + folder + '_plus_' + scene + '_' + degree + 'deg.wav'


# One record per (device, scene, angle, processing) with the measured and
# clean-target file paths.
records = []
for ha_model in ha_models:
    for scene in scenes:
        for degree in degrees:
            # Clean target signal: depends only on scene and angle.
            target_path = (data_ref_dir + 'SH_versions/normal/'
                           + 'sharvard-target1-' + degree + 'deg/'
                           + scene + '_bin_deg' + degree + '_snr5_ane_48000hz.wav')
            for processing in processings:
                if processing == 'ha':
                    # Plus recording from the device with its processing "enabled".
                    meas_path = _plus_recording_path(
                        data_ref_dir + 'recordings/ku_recordings/',
                        folders[ha_model + '_enabled_ku100'], scene, degree)
                elif processing == 'bypass':
                    # Plus recording from the same device in bypass mode.
                    meas_path = _plus_recording_path(
                        data_ref_dir + 'recordings/ku_recordings/',
                        _bypass_folder(ha_model), scene, degree)
                else:
                    # The device's bypass recording after processing by this DNN variant.
                    meas_path = _plus_recording_path(
                        data_ref_dir + process_folders[processing] + '/ku_processed/',
                        _bypass_folder(ha_model), scene, degree)

                records.append({'ha_model': ha_model,
                                'scene': scene,
                                # NOTE: the key has a trailing space; later cells and
                                # saved CSVs use 'degree ' verbatim, so it is kept.
                                'degree ': degree,
                                'processing': processing,
                                'meas_path': meas_path,
                                'target_path': target_path})

df = pd.DataFrame(records)

In [17]:
def compute_obj_measures(idx, row):
    """Compute objective quality/intelligibility measures for one data-frame row.

    The measured recording (``row['meas_path']``) and the clean target
    (``row['target_path']``) are read, resampled to ``FS_DNN``, time-aligned
    with ``hlp.synch_sigs``, cropped to a common length, level-matched and
    jointly peak-normalised before the metrics are computed. A progress line
    with the alignment lag is printed.

    Parameters
    ----------
    idx : hashable
        Index of ``row`` in the original data frame; echoed back so results
        from asynchronous workers can be written to the right row.
    row : pandas.Series
        Must provide ``'meas_path'`` and ``'target_path'``: audio files with at
        least two channels (column 0 = left ear, column 1 = right ear).

    Returns
    -------
    tuple
        ``(idx, hasqi_L, hasqi_R, hasqi, haspi_L, haspi_R, haspi,
        sisdr_L, sisdr_R, sisdr, snrloss_L, snrloss_R, snrloss, mbstoi_B)``.
        ``*_L``/``*_R`` are per-ear values; ``hasqi``, ``haspi`` and
        ``snrloss`` are the mean of both ears; ``sisdr`` is computed jointly
        over both channels; ``mbstoi_B`` is the binaural MBSTOI score.
    """
    meas, fs = sf.read(row['meas_path'])
    target, fs_tar = sf.read(row['target_path'])

    if fs != FS_DNN:
        meas = resample_poly(meas.astype(np.float32), FS_DNN, fs)
    if fs_tar != FS_DNN:
        target = resample_poly(target.astype(np.float32), FS_DNN, fs_tar)

    # ---------- Align, crop and level-match the two signals ----------
    meas, target, lag = hlp.synch_sigs(meas, target)
    print("Processing row " + str(idx) + ". lag is : " + str(lag))

    crop = int(min(len(meas), len(target)))
    meas = meas[:crop, :]
    target = target[:crop, :]
    # Only one signal may have gone through the float32 resampling path; use a
    # common dtype so torch.dot inside SISDR accepts the pair. This is a no-op
    # when both already share a dtype.
    common_dtype = np.result_type(meas, target)
    meas = meas.astype(common_dtype, copy=False)
    target = target.astype(common_dtype, copy=False)

    # Scale the target to the measurement's energy, then divide both by their
    # joint peak so that neither exceeds +-1.
    target *= np.sqrt((meas ** 2).sum() / ((target ** 2).sum()))
    norm_fac = np.max((np.max(np.abs(meas)), np.max(np.abs(target))))
    meas /= norm_fac
    target /= norm_fac

    # ------------- Objective measure : SI-SDR -------------
    sisdr_L = SISDR(target[:, 0], meas[:, 0]).item()
    sisdr_R = SISDR(target[:, 1], meas[:, 1]).item()
    # Joint SI-SDR over both channels (not the mean/max of the per-ear values).
    sisdr = SISDR(target, meas).item()

    # ------------- Objective measure : SNR loss (mean of both ears) -------------
    snrloss_L = SNRLoss(torch.from_numpy(target[:, 0]), torch.from_numpy(meas[:, 0])).item()
    snrloss_R = SNRLoss(torch.from_numpy(target[:, 1]), torch.from_numpy(meas[:, 1])).item()
    snrloss = np.mean((snrloss_L, snrloss_R))

    # ------------- Objective measure : MBSTOI -------------
    mbstoi_B = mbstoi(
        left_ear_clean=target[:, 0],
        right_ear_clean=target[:, 1],
        left_ear_noisy=meas[:, 0],
        right_ear_noisy=meas[:, 1],
        fs_signal=FS_DNN,  # signal sample rate
        sample_rate=9000,  # operating sample rate
        fft_size_in_samples=64,
        n_third_octave_bands=5,
        centre_freq_first_third_octave_hz=500,
        dyn_range=60,
    )

    # ------------- Objective measures : HASQI and HASPI (mean of both ears) -------------
    # All-zero thresholds; presumably a normal-hearing audiogram (dB HL) - confirm
    # against the clarity hasqi_v2/haspi_v2 docs.
    hearing_loss = np.array([0, 0, 0, 0, 0, 0])
    equalisation_mode = 1
    level1 = 70  # presumably the reference presentation level in dB SPL - confirm
    # Only the reference is RMS-normalised ("normRef" variant); meas is passed
    # at its peak-normalised level.
    hasqi_L, _, _, _ = hasqi_v2(norm_rms(target[:, 0]), FS_DNN, meas[:, 0], FS_DNN,
                                hearing_loss, equalisation_mode, level1)
    hasqi_R, _, _, _ = hasqi_v2(norm_rms(target[:, 1]), FS_DNN, meas[:, 1], FS_DNN,
                                hearing_loss, equalisation_mode, level1)
    hasqi = np.mean((hasqi_L, hasqi_R))

    haspi_L, _ = haspi_v2(norm_rms(target[:, 0]), FS_DNN, meas[:, 0], FS_DNN, hearing_loss, level1)
    haspi_R, _ = haspi_v2(norm_rms(target[:, 1]), FS_DNN, meas[:, 1], FS_DNN, hearing_loss, level1)
    haspi = np.mean((haspi_L, haspi_R))

    return (idx, hasqi_L, hasqi_R, hasqi, haspi_L, haspi_R, haspi,
            sisdr_L, sisdr_R, sisdr, snrloss_L, snrloss_R, snrloss, mbstoi_B)

In [18]:
# Order of the metric values returned by compute_obj_measures (after idx).
_RESULT_FIELDS = ('hasqi_L', 'hasqi_R', 'hasqi', 'haspi_L', 'haspi_R', 'haspi',
                  'sisdr_L', 'sisdr_R', 'sisdr', 'snrloss_L', 'snrloss_R', 'snrloss',
                  'mbstoi')
# Order in which the metric columns are added to df (and thus written to the CSV).
_RESULT_COLUMNS = ('haspi_L', 'haspi_R', 'haspi', 'hasqi_L', 'hasqi_R', 'hasqi',
                   'sisdr_L', 'sisdr_R', 'sisdr', 'snrloss_L', 'snrloss_R', 'snrloss',
                   'mbstoi')

if __name__ == '__main__':
    NUM_OF_WORKERS = 8
    t0 = time.time()
    with Pool(NUM_OF_WORKERS) as pool:
        results = [pool.apply_async(compute_obj_measures, [idx, row])
                   for idx, row in df.iterrows()]
        for result in results:
            # Collect into a dict rather than unpacking into globals. A global
            # named `mbstoi` would shadow the imported clarity mbstoi() function,
            # and workers forked by a later run of this cell would call a float.
            idx, *values = result.get()
            metrics = dict(zip(_RESULT_FIELDS, values))
            for column in _RESULT_COLUMNS:
                df.loc[idx, column] = metrics[column]

    df.to_csv('results_haspiFix_normRef.csv')
    print('Took ' + str(time.time() - t0) + ' seconds.')

Processing row 0. lag is : 3057
Processing row 5. lag is : 3042
Processing row 4. lag is : 3057
Processing row 1. lag is : 3057Processing row 3. lag is : 3058Processing row 2. lag is : 3057Processing row 7. lag is : 3059



Processing row 6. lag is : 3059


KeyboardInterrupt: 

In [None]:
# Load a saved results table for plotting. This replaces any `df` computed
# above. Other variants available (see the list below):
#   'results_haspiFix_normRef.csv', 'results_haspiFix_normBoth_be.csv',
#   'results_haspiFix_normBoth.csv', 'results_haspiFix.csv', 'results.csv'
RESULTS_CSV = 'results_original.csv'
df = pd.read_csv(RESULTS_CSV)

'''
1) results_original.csv -> the script as it is in the last commit. Does it match the plots in the paper?
2) results_haspiFix.csv -> using meas instead of meas+target as measurement in HASPI. How bad is it?
3) results_haspiFix_normRef -> same as (2) but normalizing TARGET (or reference) to amplitude 1 RMS
4) results_haspiFix_normBoth -> same as (3) but normalizing both TARGET and MEAS
5) results_haspiFix_normBoth_be -> same as (4) but using BEST EAR (max instead of mean) in all metrics
6) results_haspiFix_exclude -> same as (2) but excluding the BAD devices (the ones that don't plus-minus well)
    -> does SISDR on this last one match the paper?
''';

In [None]:
# Keep only the "normal" DNN variants. .copy() makes aux an independent frame,
# so adding 'dev_group' below does not trigger SettingWithCopyWarning.
aux = df[~df['processing'].isin(['dnn_mild_noncausal', 'dnn_mild_causal'])].copy()

# Device-group labels used in the figures: any model whose name contains the
# key is relabelled with the value (applied in this order).
DEVICE_GROUPS = {'gn3': 'd1', 'gn5': 'd2', 'si3': 'd3', 'ph5': 'd4', 'ph4': 'd5'}
aux['dev_group'] = aux['ha_model']
for model_tag, group in DEVICE_GROUPS.items():
    aux.loc[aux['dev_group'].str.contains(model_tag), 'dev_group'] = group

# Bypass results repeated once per compared condition (ha, dnn1, dnn2), so
# they can be row-aligned with the stacked processed results later on.
bypass = aux[aux['processing'] == 'bypass']
bypass = pd.concat([bypass] * 3)

# Results of each processing condition compared against bypass.
ha = aux[aux['processing'] == 'ha']
dnn1 = aux[aux['processing'] == 'dnn_normal_causal']
dnn2 = aux[aux['processing'] == 'dnn_normal_noncausal']

In [None]:

# Stack the processed conditions in the same order in which the bypass rows
# were repeated (ha, dnn1, dnn2), with a fresh 0..N-1 index.
processed = pd.concat([ha, dnn1, dnn2], ignore_index=True)

In [None]:
# Row i of `processed` must describe the same device/scene/angle as row i of
# `bypass`; otherwise the differences below compare unrelated recordings.
bypass = bypass.reset_index(drop=True)
processed = processed.reset_index(drop=True)
for key in ('scene', 'ha_model', 'degree '):
    assert bypass[key].equals(processed[key]), f"bypass/processed rows misaligned on {key!r}"

# Benefit of each condition over the bypass recording. Work on a copy:
# assigning into `processed` itself would make re-running this cell subtract
# the bypass values a second time.
aux_delta = processed.copy()
for metric in ('haspi', 'hasqi', 'mbstoi', 'sisdr'):
    aux_delta[metric] = processed[metric] - bypass[metric]
print(len(aux_delta))

sns.set(rc={'axes.facecolor':'lightgrey', 'figure.facecolor':'none'})
fig, axes = plt.subplots(2, 2, figsize=(8, 5))
sns.set(font_scale=0.7)
# x categories appear in the order of `processed`: ha, dnn_normal_causal,
# dnn_normal_noncausal. Raw strings keep "\D" from being treated as an
# (invalid) escape sequence.
for ax, metric, label in ((axes[0, 0], 'haspi', r'$\Delta$HASPI'),
                          (axes[0, 1], 'hasqi', r'$\Delta$HASQI'),
                          (axes[1, 0], 'sisdr', r'$\Delta$SISDR'),
                          (axes[1, 1], 'mbstoi', r'$\Delta$MBSTOI')):
    sns.violinplot(ax=ax, data=aux_delta, y=metric, x='processing', palette="colorblind")
    ax.set_ylabel(label)
    ax.set_xlabel('')
plt.tight_layout()
os.makedirs('figures', exist_ok=True)
plt.savefig(os.path.join('figures', 'basic_plot.pdf'))
plt.show()



In [None]:
sns.set(rc={'axes.facecolor':'lightgrey', 'figure.facecolor':'none'})
fig, axes = plt.subplots(2, 2, figsize=(12, 4))
sns.set(font_scale=0.7)

# Pin the x-axis category order so the tick labels below always match the
# bars; without `order`, seaborn uses the order of first appearance in `aux`.
PROCESSING_ORDER = ['bypass', 'ha', 'dnn_normal_noncausal', 'dnn_normal_causal']
PROCESSING_LABELS = ['Bypass', 'HA', 'DNN', 'DNN-C']
for ax, metric, label in ((axes[0, 0], 'haspi', 'HASPI'),
                          (axes[0, 1], 'hasqi', 'HASQI'),
                          (axes[1, 0], 'sisdr', 'SISDR'),
                          (axes[1, 1], 'mbstoi', 'MBSTOI')):
    sns.barplot(ax=ax, data=aux, y=metric, x='processing', order=PROCESSING_ORDER,
                palette="colorblind", hue='dev_group')
    ax.set_ylabel(label)
    ax.set_xlabel('')
    ax.set_xticklabels(PROCESSING_LABELS, rotation=0)
    # The per-panel device-group legend is removed from every panel.
    ax.get_legend().remove()
plt.tight_layout()
os.makedirs('figures', exist_ok=True)
plt.savefig(os.path.join('figures', 'bar_plot.pdf'))
plt.show()