In [None]:
import os
from scipy.io import wavfile
import pandas as pd
import numpy as np
import scipy.fft
from scipy.signal import resample_poly
import matplotlib.pyplot as plt
from IPython.display import Audio


# import my modules (helpers.py where I stored all the functions):
import helpers as hlp
import importlib 
importlib.reload(hlp)
from clarity.evaluator.hasqi import hasqi_v2
from clarity.evaluator.haspi import haspi_v2 
from clarity.evaluator.mbstoi import mbstoi 
import torch
import sys
from multiprocessing import Pool


In [None]:
# All directories in which we store recordings or dnn-processed recordings.
# The position of a directory in this list becomes its integer `dnn_applied` code
# (index 0 is the folder with the unprocessed recordings).
data_source_dirs=[
    '/home/ubuntu/Data/ha_listening_situations/recordings/ku_recordings/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m1_alldata_normal/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m3_alldata_mild/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m4_alldata_normal_causal/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m5_alldata_mild_causal/ku_processed/'
]

# Directory where all recordings used as reference are stored
data_ref_dir='/home/ubuntu/Data/ha_listening_situations/recordings/ku_recordings/'

# Names of hearing devices
ha_models=["si3_s","si3_m","si3_p","ph4_s","ph4_m","ph4_p","ph5_s","ph5_m","ph5_p",
           "gn5_l","gn5_m","gn5_h","gn3_l","gn3_m","gn3_h"]

# Different listening scenes:
scenenames=['party_0deg','party_30deg','restaurant_0deg','restaurant_30deg','meeting_0deg','meeting_30deg']

# Each row of the dataframe represents one pair of plus and minus recordings
# together with the matching reference pair; objective measures are added later.
MEASURE_COLUMNS=['device','dnn_applied','plus_file','minus_file', 'plus_ref_file','minus_ref_file', 'scene']

records=[]
for i, data_source_dir in enumerate(data_source_dirs):
    # List the directory once instead of once per recording (listing order is kept
    # for the bypass lookup, sorted order is used for the main iteration).
    subdirs=[name for name in os.listdir(data_source_dir) if os.path.isdir(data_source_dir+name)]
    for recording_name in sorted(subdirs):
        # identify the current device name:
        device_id=[model for model in ha_models if model in recording_name]
        if len(device_id)>0:
            # find the corresponding bypass recording name:
            bypass_dirs=[dirname for dirname in subdirs
                         if device_id[0] in dirname and "bypass" in dirname]
            if not bypass_dirs:
                # Previously this surfaced as a bare IndexError without any context.
                raise FileNotFoundError(
                    f"No bypass recording for device '{device_id[0]}' found in {data_source_dir} "
                    f"(needed as reference for '{recording_name}')")
            ref_recording_name=bypass_dirs[0]
        else:
            # no known hearing device in the name: the recording is its own reference
            ref_recording_name=recording_name

        for scene in scenenames:
            records.append({
                'device': recording_name,
                'dnn_applied': i,
                # paths for recorded plus and minus signals (data_source_dir + recording_name)
                'plus_file': data_source_dir+recording_name+'/'+recording_name+'_plus_'+scene+'.wav',
                'minus_file': data_source_dir+recording_name+'/'+recording_name+'_minus_'+scene+'.wav',
                # paths for reference plus and minus signals (data_ref_dir + ref_recording_name)
                'plus_ref_file': data_ref_dir+ref_recording_name+'/'+ref_recording_name+'_plus_'+scene+'.wav',
                'minus_ref_file': data_ref_dir+ref_recording_name+'/'+ref_recording_name+'_minus_'+scene+'.wav',
                'scene': scene,
            })

measures=pd.DataFrame(records, columns=MEASURE_COLUMNS)
# The fractional labels 0.1/0.2/0.3 are written into this column below, so make it
# float up front instead of relying on an implicit int -> float upcast during .loc assignment.
measures['dnn_applied']=measures['dnn_applied'].astype(float)


# Make sure to remove the rows where an "enabled"/"fulldenoising" recording or the ku100 recording
# was dnn-processed (we are only interested in dnn-processed bypass recordings):
print(f'before: {len(measures)=}')
for excluded_name in ("enabled", "001_ku100", "fulldenoising"):
    measures=measures[~((measures["device"].str.contains(excluded_name)) & (measures["dnn_applied"]>0))]
print(f'after: {len(measures)=}')

# Within the group that is not processed by any dnn model we have to distinguish 3 categories and give them labels:
# ---> unprocessed reference: ku100 recordings without a hearing aid
measures.loc[measures["device"].str.contains("001_ku100"), "dnn_applied"]=0.1
# ---> unprocessed bypass: recordings with hearing aids in bypass
measures.loc[((measures["device"].str.contains("bypass")) & (measures["dnn_applied"]==0)), "dnn_applied"]=0.2
# ---> unprocessed enabled: recordings with hearing aids enabled (incl. "fulldenoising" recordings)
measures.loc[((measures["device"].str.contains("enabled")) & (measures["dnn_applied"]==0)), "dnn_applied"]=0.3
measures.loc[((measures["device"].str.contains("fulldenoising")) & (measures["dnn_applied"]==0)), "dnn_applied"]=0.3

measures.to_csv('meas.csv')


In [None]:
def SISDR(s, s_hat):
    """Return the NEGATIVE scale-invariant SDR (in dB) between targets and estimates.

    The value follows the loss convention: it is ``-10*log10(|a|^2 / |b|^2)`` where
    ``a`` is the projection of the estimate onto the target and ``b = a - s_hat`` is
    the residual. Lower is better; negate the result to obtain the SI-SDR of [1]_.

    All tensors in each list are stacked and flattened, so multi-channel signals are
    scored as one long vector. Integer or mixed-dtype inputs are promoted to a common
    floating dtype (previously they raised inside ``torch.finfo`` / ``torch.dot``);
    inputs that already share a floating dtype are processed unchanged.

    References
    ----------
    .. [1] Le Roux, Jonathan, et al. "SDR–half-baked or well done?." ICASSP 2019-2019 IEEE International Conference on
    Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2019.

    Parameters:
        s: list of target tensors; all entries must have the same shape
        s_hat: list of corresponding estimate tensors, same total number of elements as ``s``

    Returns:
        0-dim tensor holding the negative SI-SDR in dB.
    """
    s = torch.stack(s).reshape(-1)
    s_hat = torch.stack(s_hat).reshape(-1)

    # Bring both signals to one floating dtype so finfo/dot are well defined.
    dtype = torch.promote_types(s.dtype, s_hat.dtype)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()
    s = s.to(dtype)
    s_hat = s_hat.to(dtype)

    EPS = torch.finfo(dtype).eps  # guards against division by zero / log of zero
    a = (torch.dot(s_hat, s) * s) / ((s ** 2).sum() + EPS)  # estimate projected onto target
    b = a - s_hat                                           # residual (distortion) part
    return -10*torch.log10(((a*a).sum()) / ((b*b).sum()+EPS))

def _read_wav_for_measures(path, fs_target):
    """Read one wav file and return its samples as floating point at ``fs_target`` Hz.

    Integer PCM data is converted to float32 so that later sums/differences of two
    recordings cannot wrap around; files at another sample rate are resampled.
    """
    fs, samples = wavfile.read(path)
    if fs != fs_target:
        samples = resample_poly(samples.astype(np.float32), int(fs_target), int(fs))
    elif not np.issubdtype(samples.dtype, np.floating):
        samples = samples.astype(np.float32)
    return samples


def compute_obj_measures(idx, row):
    """Compute all objective measures for one row of the ``measures`` data frame.

    Parameters
    ----------
    idx : index label of the row in the original data frame (returned unchanged so the
        caller can match asynchronous results to rows)
    row : mapping/Series with the keys 'plus_file', 'minus_file', 'plus_ref_file' and
        'minus_ref_file' (paths to wav files; the code indexes channels 0 and 1, so
        two-channel files are expected)

    Returns
    -------
    tuple ``(idx, snr_left_val, snr_right_val, snr_val, mbstoi_val, hasqi_left_val,
    haspi_left_val, sisdr_val)``:
    - snr_left_val / snr_right_val / snr_val: SNR in dB estimated with the phase
      inversion technique for channel 0, channel 1 and both channels together
    - mbstoi_val: binaural speech intelligibility model (MBSTOI)
    - hasqi_left_val: hearing aid speech quality index for channel 0
    - haspi_left_val: hearing aid speech perception index for channel 0
    - sisdr_val: scale-invariant SDR in dB (higher is better)
    """
    FS_TARGET=16e3
    # Every file is checked against the target rate on its own (previously only the
    # rate of the last file read in each pair was inspected).
    plus = _read_wav_for_measures(row['plus_file'], FS_TARGET)
    minus = _read_wav_for_measures(row['minus_file'], FS_TARGET)
    refplus = _read_wav_for_measures(row['plus_ref_file'], FS_TARGET)
    refminus = _read_wav_for_measures(row['minus_ref_file'], FS_TARGET)
    # ---------- Compute clean reference signal ----------
    ref_s=0.5*(refplus+refminus)
    # ------------- Objective measure 1: SNR -------------
    # plus = s + n and minus = s - n, so half the sum / difference separates the two parts.
    s=0.5*(plus+minus)
    n=0.5*(plus-minus)
    snr_left_val = 10 * np.log10(hlp.power(s[:,0]) / hlp.power(n[:,0]))
    snr_right_val = 10 * np.log10(hlp.power(s[:,1]) / hlp.power(n[:,1]))
    snr_val =10 * np.log10(hlp.power(s) / hlp.power(n))
    # ------------- Objective measure 2: MBSTOI -------------
    mbstoi_val = mbstoi(
        left_ear_clean=ref_s[:,0],
        right_ear_clean=ref_s[:,1],
        left_ear_noisy=plus[:,0],
        right_ear_noisy=plus[:,1],
        fs_signal=FS_TARGET,  # signal sample rate
        sample_rate=9000,  # operating sample rate
        fft_size_in_samples=64,
        n_third_octave_bands=5,
        centre_freq_first_third_octave_hz=500,
        dyn_range=60,
    )
    # ------------- Objective measure 3 & 4: HASQI and HASPI -------------
    hearing_loss = np.array([0, 0, 0, 0, 0, 0])  # presumably a flat 0 dB audiogram - TODO confirm
    equalisation_mode=1
    level1=65
    hasqi_left_val, _, _, _ = hasqi_v2(ref_s[:,0], FS_TARGET, plus[:,0], FS_TARGET, hearing_loss, equalisation_mode, level1)
    # NOTE(review): the processed signal handed to HASPI is plus + reference, whereas HASQI
    # gets plus alone - confirm that adding the reference here is intended.
    haspi_left_val, _ = haspi_v2(ref_s[:,0], FS_TARGET, plus[:,0] + ref_s[:,0], FS_TARGET, hearing_loss, level1)

    # ------------- Objective measure 5: SI-SDR -------------
    # SISDR() follows the loss convention and returns the NEGATIVE SI-SDR, so it is negated
    # here to store the actual measure. Both inputs are brought to a common dtype first
    # because torch.dot rejects mixed dtypes.
    common_dtype = np.result_type(ref_s.dtype, plus.dtype)
    sisdr_val=-SISDR([torch.from_numpy(ref_s.astype(common_dtype))],
                     [torch.from_numpy(plus.astype(common_dtype))]).item()
    print(str(idx)+ "row")
    return idx, snr_left_val,snr_right_val,snr_val, mbstoi_val, hasqi_left_val, haspi_left_val,sisdr_val 


# The pool has to be started from the same cell that defines the worker function
if __name__ == '__main__':
    NUM_OF_WORKERS = 8
    with Pool(NUM_OF_WORKERS) as pool:
        # Submit one asynchronous job per row, then collect the results in submission order.
        pending = [pool.apply_async(compute_obj_measures, [idx, row]) for idx, row in measures.iterrows()]
        for job in pending:
            row_idx, snr_l, snr_r, snr_both, stoi_score, hasqi_score, haspi_score, sdr_score = job.get()
            new_values = {
                'snr_left': snr_l,
                'snr_right': snr_r,
                'snr': snr_both,
                'mbstoi': stoi_score,
                'hasqi_left': hasqi_score,
                'haspi_left': haspi_score,
                'sisdr': sdr_score,
            }
            # Write each measure into the row it belongs to (columns are created on first use).
            for column, value in new_values.items():
                measures.loc[row_idx, column] = value

measures.to_csv('objective_measures2.csv')



In [None]:
# Some important edits before plotting:

# read dataframe
measures_computed=pd.read_csv('objective_measures.csv')
# Fractional category labels (0.1/0.2/0.3) are written into this column below, so it must be float.
measures_computed["dnn_applied"]=measures_computed["dnn_applied"].astype(float)

# Make sure to remove the rows where an "enabled"/"fulldenoising" recording or the ku100 recording
# was dnn-processed (we are only interested in dnn-processed bypass recordings).
# DNN codes are the integers 1, 2, ...; unprocessed rows are 0 or, if the file was already
# labelled, 0.1/0.2/0.3. Testing ">= 1" / "< 1" (instead of "> 0" / "== 0") gives the same result
# for integer codes and keeps this cell idempotent for files that already carry the labels.
print(f'before: {len(measures_computed)=}')
for excluded_name in ("enabled", "001_ku100", "fulldenoising"):
    measures_computed=measures_computed[~((measures_computed["device"].str.contains(excluded_name)) & (measures_computed["dnn_applied"]>=1))]
print(f'after: {len(measures_computed)=}')

# Within the group that is not processed by any dnn model we have to distinguish 3 categories and give them labels:
# ---> unprocessed reference: ku100 recordings without a hearing aid
measures_computed.loc[measures_computed["device"].str.contains("001_ku100"), "dnn_applied"]=0.1
# ---> unprocessed bypass: recordings with hearing aids in bypass
measures_computed.loc[((measures_computed["device"].str.contains("bypass")) & (measures_computed["dnn_applied"]<1)), "dnn_applied"]=0.2
# ---> unprocessed enabled: recordings with hearing aids enabled (incl. "fulldenoising" recordings)
measures_computed.loc[((measures_computed["device"].str.contains("enabled")) & (measures_computed["dnn_applied"]<1)), "dnn_applied"]=0.3
measures_computed.loc[((measures_computed["device"].str.contains("fulldenoising")) & (measures_computed["dnn_applied"]<1)), "dnn_applied"]=0.3

# For better visibility, we add one column which specifies the hearing aid model (without division into different receivers).
# A separate name is used on purpose: rebinding the global `ha_models` here would silently change
# the device matching of the table-building cell if that cell is re-run afterwards.
device_groups=["gn3","gn5","ph5","ph4","si3","fulldenoising","001_ku"]
for modelname in device_groups:
    measures_computed.loc[measures_computed['device'].str.contains(modelname), "device_group"]=modelname


In [None]:
import seaborn as sns
import warnings
  
# Install a blanket 'ignore' filter: every warning raised after this line is suppressed
warnings.filterwarnings('ignore')

def plot_1_measure(df,measure):
    """Show one objective measure per processing condition in three side-by-side panels.

    Panels (shared y axis): a standard boxplot, a swarm plot coloured by
    ``device_group`` and a swarm plot coloured by ``scene``.

    Parameters
    ----------
    df : DataFrame with the columns 'dnn_applied', 'device_group', 'scene' and ``measure``
    measure : name of the column to plot

    The x tick labels are derived from the ``dnn_applied`` values actually present in
    ``df``, so a frame holding only a subset of the conditions is labelled correctly
    (a fixed list of seven labels fails or mislabels in that case). Unknown codes are
    shown as their numeric value.
    """
    condition_names = {0.1: 'ref', 0.2: 'bypass', 0.3: 'HA',
                       1.0: 'DNN-normal', 2.0: 'DNN-mild', 3.0: 'DNN-normal_C', 4.0: 'DNN-mild_C'}
    # Sorted unique codes: the order used for the grouped boxplot and passed to the swarm plots.
    conditions = sorted(df["dnn_applied"].dropna().unique())
    tick_labels = [condition_names.get(round(float(code), 1), str(code)) for code in conditions]

    # Set the theme before the figure is created so the first figure is styled like later ones.
    sns.set(font_scale=0.8)
    fig, axes = plt.subplots(1, 3, figsize=(25, 5), sharey=True)
    sns.swarmplot(ax=axes[1],data=df,x="dnn_applied", y=measure, hue="device_group", order=conditions)
    sns.swarmplot(ax=axes[2],data=df,x="dnn_applied", y=measure, hue="scene", order=conditions)
    df.boxplot(ax=axes[0],column=measure,by="dnn_applied")

    panel_titles = ['Standard boxplot', 'Colored by devices', 'Colored by scene']
    for ax, panel_title in zip(axes, panel_titles):
        ax.set_xlabel('')
        ax.set_title(panel_title)
        ax.set_xticklabels(tick_labels, rotation=45)
    # Replaces the automatic "grouped by" super-title that DataFrame.boxplot adds.
    fig.suptitle('Measure: '+ measure)
    plt.show()

# Draw the three-panel figure for every measure, in a fixed order:
for measure_name in ("snr_left", "sisdr", "hasqi_left", "haspi_left", "mbstoi"):
    plot_1_measure(measures_computed, measure_name)


In [None]:
# Anova analysis...
import statsmodels.api as sm
from statsmodels.formula.api import ols

# Three-way ANOVA of snr_left on three categorical factors (processing type, scene, device group),
# including all two-way interactions and the three-way interaction.
ANOVA_FORMULA = """snr_left ~ C(dnn_applied) + C(scene) + C(device_group) +
               C(dnn_applied):C(scene) + C(dnn_applied):C(device_group) + C(scene):C(device_group) +
               C(dnn_applied):C(scene):C(device_group)"""
anova_spec = ols(ANOVA_FORMULA, data=measures_computed)
model = anova_spec.fit()

# Open questions:
# TODO: Why isn't dnn_applied significant???
# TODO: bypass (plus/minus) vs processed (plus)

# Type-2 ANOVA table; as the last expression of the cell it is rendered as the cell output.
sm.stats.anova_lm(model, typ=2)


