In [None]:
import os
from scipy.io import wavfile
import pandas as pd
import numpy as np
import scipy.fft
import scipy
from scipy import signal
from scipy.signal import resample_poly
import matplotlib.pyplot as plt
from IPython.display import Audio

import tqdm
# import my modules (helpers.py where I stored all the functions):
import helpers as hlp
import importlib 
importlib.reload(hlp)
from clarity.evaluator.hasqi import hasqi_v2
from clarity.evaluator.haspi import haspi_v2 
from clarity.evaluator.mbstoi import mbstoi 
import torch
import torch.nn as nn
import sys
from multiprocessing import Pool

from noresqa import NORESQA

In [None]:
# Selects which NORESQA variant the rest of the notebook runs:
#   0 -> NORESQA: STFT features in, (preference, quantification) scores out
#   1 -> NORESQA-MOS: raw waveforms in, a single score out
metric_type = 1 #NORESQA-MOS

In [None]:
# Load the pretrained checkpoint onto the CPU and keep only its 'state_dict' entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model_checkpoint_path = 'noresqa_models/model_noresqa_mos.pth'
state = torch.load(model_checkpoint_path,map_location="cpu")['state_dict']

# Build the network for the variant selected by `metric_type`.
# NOTE(review): the meaning of output/output2 = 40 is not visible in this notebook —
# presumably the number of output bins; confirm against noresqa.NORESQA.
model = NORESQA(output=40, output2=40, metric_type = metric_type, config_path = 'noresqa_models/wav2vec_small.pt')


In [None]:
# Rename checkpoint entries: every key that contains 'module' gets all of its
# 'module.' substrings removed; all other keys are kept unchanged.
pretrained_dict = {
    (name.replace('module.','') if 'module' in name else name): weights
    for name, weights in state.items()
}
model_dict = model.state_dict()
model_dict.update(pretrained_dict)
# Copy the renamed checkpoint weights into the model
model.load_state_dict(pretrained_dict)

In [None]:
# change device as needed
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the selected device and switch it to evaluation mode
model.to(device)
model.eval();

In [None]:
# Softmax over dimension 1; used by model_prediction_noresqa() on the model outputs
sfmax = nn.Softmax(dim=1)
# STFT feature extraction
def extract_stft(audio, sampling_rate = 16000):
    """Return the STFT magnitude and phase of `audio` stacked on a last axis.

    The STFT uses a 512-sample Hann window, 256 samples of overlap and a
    512-point FFT. Only the first 256 frequency bins are kept. For a 1-D
    input the result has shape (256, n_frames, 2), where [..., 0] is the
    magnitude and [..., 1] is the phase angle.
    """
    _, _, spectrum = signal.stft(audio, sampling_rate, window='hann',nperseg=512,noverlap=256,nfft=512)
    spectrum = spectrum[:256,:]
    magnitude = np.abs(spectrum)
    phase = np.angle(spectrum)
    return np.stack((magnitude, phase), axis=2)

# noresqa and noresqa-mos prediction calls
def model_prediction_noresqa(test_feat, nmr_feat):
    """Run the global `model` on a test / reference feature pair (NORESQA variant).

    Both inputs are permuted with (0, 3, 2, 1) before being passed to the
    model, test features first. The model is expected to return three
    outputs; the third one is ignored.

    Returns:
        (pout, qout): `pout` is element [0][0] of the softmaxed first output
        averaged over dimension 2 (preference task); `qout` is the sum of the
        softmaxed second output (averaged over dimension 2) weighted by the
        values 0.5, 1.5, ..., 39.5 (quantification task).
    """
    sdr_bins = np.arange(0.5,40,1)

    with torch.no_grad():
        outputs = model(test_feat.permute(0,3,2,1),nmr_feat.permute(0,3,2,1))
        ranking_frame, sdr_frame, snr_frame = outputs
        # preference task prediction
        ranking_probs = sfmax(ranking_frame).mean(2).detach().cpu().numpy()
        # quantification task
        sdr_probs = sfmax(sdr_frame).mean(2).detach().cpu().numpy()

    pout = ranking_probs[0][0]
    qout = (sdr_bins * sdr_probs).sum()
    return pout, qout

def model_prediction_noresqa_mos(test_feat, nmr_feat):
    """Run the global `model` (NORESQA-MOS variant) and return its first output entry.

    Note the argument order of the model call: the reference (`nmr_feat`)
    comes first and the test input second. Gradients are disabled and the
    result is moved to the CPU and converted to NumPy before indexing.
    """
    with torch.no_grad():
        prediction = model(nmr_feat,test_feat)

    return prediction.detach().cpu().numpy()[0]

# reading audio clips
def audio_loading(path,sampling_rate=16000):
    """Load an audio file as a mono signal sampled at `sampling_rate` Hz.

    Args:
        path: location of the audio file, passed straight to librosa.load.
        sampling_rate: target sampling rate in Hz (default 16000).

    Returns:
        The audio samples, resampled only when the file's native rate differs
        from `sampling_rate`.
    """
    # Imported locally because the notebook's import cell does not import librosa;
    # without this the function raises NameError on a fresh kernel.
    import librosa

    # sr=None keeps the file's native sampling rate
    audio, fs = librosa.load(path, sr=None)
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio)

    if fs != sampling_rate:
        # The rates must be passed by keyword: recent librosa releases reject
        # the positional form resample(audio, fs, sampling_rate).
        audio = librosa.resample(audio, orig_sr=fs, target_sr=sampling_rate)

    return audio


# function checking if the size of the inputs are same. If not, then the reference audio's size is adjusted
def check_size(audio_ref,audio_test):
    """Make the reference signal exactly as long as the test signal.

    A longer reference is truncated. A shorter reference is repeated
    cyclically from its start until it reaches the length of the test
    signal. The test signal is never modified.

    Args:
        audio_ref: reference samples (1-D array).
        audio_test: test samples (1-D array).

    Returns:
        (audio_ref, audio_test) with len(audio_ref) == len(audio_test).

    Raises:
        ValueError: if the reference is empty while the test signal is not,
            because an empty signal cannot be extended by repetition.
    """
    n_ref, n_test = len(audio_ref), len(audio_test)

    if n_ref > n_test:
        print('Durations dont match. Adjusting duration of reference.')
        audio_ref = audio_ref[:n_test]

    elif n_ref < n_test:
        if n_ref == 0:
            # Repeating an empty array never grows it; fail loudly instead of looping forever.
            raise ValueError('Reference audio is empty; cannot extend it to the length of the test audio.')
        print('Durations dont match. Adjusting duration of reference.')
        # np.resize fills the longer output with repeated copies of the input
        audio_ref = np.resize(audio_ref, n_test)

    return audio_ref, audio_test


# audio loading and feature extraction
def feats_loading(test_path, ref_path=None, noresqa_or_noresqaMOS = 0):
    """Load a test / reference pair and prepare it for the selected model variant.

    Both files are read with audio_loading() and the reference is length-
    matched to the test signal with check_size().

    Returns:
        (ref_feat, test_feat) STFT features when `noresqa_or_noresqaMOS` is 0,
        (audio_ref, audio_test) raw waveforms when it is 1,
        None for any other value.
    """
    mode = noresqa_or_noresqaMOS
    if mode != 0 and mode != 1:
        # Unsupported mode: nothing is loaded
        return None

    ref_audio = audio_loading(ref_path)
    test_audio = audio_loading(test_path)
    ref_audio, test_audio = check_size(ref_audio,test_audio)

    if mode != 0:
        # NORESQA-MOS works on the waveforms directly
        return ref_audio, test_audio

    ref_feat = extract_stft(ref_audio)
    test_feat = extract_stft(test_audio)
    return ref_feat,test_feat


In [None]:
# Directories with the raw recordings (first entry) and with the
# dnn-processed versions of those recordings (remaining entries).
data_source_dirs = [
    '/home/ubuntu/Data/ha_listening_situations/recordings/ku_recordings/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m1_alldata_normal/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m3_alldata_mild/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m4_alldata_normal_causal/ku_processed/',
    '/home/ubuntu/Data/ha_listening_situations/processed_m5_alldata_mild_causal/ku_processed/',
]

# Directory holding every recording that can serve as a reference
data_ref_dir = '/home/ubuntu/Data/ha_listening_situations/recordings/ku_recordings/'

# Hearing-device identifiers that are searched for inside recording folder names
ha_models = [
    "si3_s", "si3_m", "si3_p",
    "ph4_s", "ph4_m", "ph4_p",
    "ph5_s", "ph5_m", "ph5_p",
    "gn5_l", "gn5_m", "gn5_h",
    "gn3_l", "gn3_m", "gn3_h",
]

# Listening scenes; each one appears as a suffix of the wav file names
scenenames = [
    'party_0deg',
    'party_30deg',
    'restaurant_0deg',
    'restaurant_30deg',
    'meeting_0deg',
    'meeting_30deg',
]

# Initialize dataframe, in which each row will represent one pair of plus and minus recording
# and all the info and objective measures associated with that pair.
# NOTE(review): 'plus_ref_file' and 'minus_ref_file' are declared here but never filled
# in this cell; the reference paths go into the *_ref_bypass / *_ref_ku columns instead.
recordings_df=pd.DataFrame(columns=['device','dnn_applied','plus_file','minus_file', 'plus_ref_file','minus_ref_file', 'scene'])

# initialize lists to be filled inside the loop (one entry per recording folder and scene)
dnntypes=[]
pluses=[]
minuses=[]
bypass_pluses=[]
bypass_minuses=[]
rawku_pluses=[]
rawku_minuses=[]
scenes=[]
devices=[]

# i is the position of the directory in data_source_dirs and is stored as 'dnn_applied'
for i, data_source_dir in enumerate(data_source_dirs):
    # sorted() makes the row order independent of the order os.listdir happens to return
    for item in sorted(os.listdir(data_source_dir)):
        # only sub-directories are treated as recordings
        if os.path.isdir(data_source_dir+item):
            recording_name=item
            # this line identifies the current device name (all ha_models entries contained in the folder name):
            device_id=[model for model in ha_models if model in recording_name]
            if len(device_id)>0:
                # this line finds the corresponding bypass recording name:
                # first folder in the same source directory whose name contains both the device id
                # and "bypass"; raises IndexError if there is no such folder
                ref_recording_name = [dirname for dirname in os.listdir(data_source_dir)
                if os.path.isdir(data_source_dir+dirname) and device_id[0] in dirname and "bypass" in dirname][0]
            else:
                # no known device in the name: the recording is its own reference
                ref_recording_name=recording_name
            
            for scene in scenenames:
                # paths for recorded plus and minus signals (data_source_dir + recording_name)
                plusfilepath=data_source_dir+recording_name+'/'+recording_name+'_plus_'+scene+'.wav'
                minusfilepath=data_source_dir+recording_name+'/'+recording_name+'_minus_'+scene+'.wav'
                # paths for corresponding bypass plus and minus signals (data_ref_dir + ref_recording_name)
                bypass_plusrefpath=data_ref_dir+ref_recording_name+'/'+ref_recording_name+'_plus_'+scene+'.wav'
                bypass_minusrefpath=data_ref_dir+ref_recording_name+'/'+ref_recording_name+'_minus_'+scene+'.wav'
                # paths for corresponding ku100 plus and minus signals 
                ku_plusrefpath=data_ref_dir+'001_ku100/001_ku100_plus_'+scene+'.wav'
                ku_minusrefpath=data_ref_dir+'001_ku100/001_ku100_minus_'+scene+'.wav'
                # append all lists:
                applieddnntype=i
                devices.append(recording_name)
                pluses.append(plusfilepath)
                minuses.append(minusfilepath)
                bypass_pluses.append(bypass_plusrefpath)
                bypass_minuses.append(bypass_minusrefpath)
                rawku_pluses.append(ku_plusrefpath)
                rawku_minuses.append(ku_minusrefpath)
                scenes.append(scene)
                dnntypes.append(applieddnntype)

# fill the data frame with lists appended in the loop above
# (the paths are only assembled here; nothing checks that the wav files exist):
recordings_df['device']=devices            
recordings_df['dnn_applied']=dnntypes
recordings_df['plus_file']=pluses
recordings_df['minus_file']=minuses
recordings_df['plus_ref_bypass']=bypass_pluses
recordings_df['minus_ref_bypass']=bypass_minuses
recordings_df['plus_ref_ku']=rawku_pluses
recordings_df['minus_ref_ku']=rawku_minuses
recordings_df['scene']=scenes

# Make sure to remove the lines where the "enabled" recording is processed and lines where ku100 recording is processed 
# (we are only interested in dnn-processed bypass recordings):
print(f'before: {len(recordings_df)=}')
recordings_df=recordings_df[~((recordings_df["device"].str.contains("enabled")) & (recordings_df["dnn_applied"]>0))]
recordings_df=recordings_df[~((recordings_df["device"].str.contains("001_ku100")) & (recordings_df["dnn_applied"]>0))]
recordings_df=recordings_df[~(recordings_df["device"].str.contains("fulldenoising"))]
print(f'after: {len(recordings_df)=}')

# 'dnn_applied' holds integer directory indices so far, but the labels assigned below are
# fractional. Writing a float into an integer column is deprecated in pandas 2.1 and is an
# error in later versions, so the column is converted to float explicitly first.
recordings_df=recordings_df.astype({"dnn_applied": float})

# Within the group that is not processed by any dnn model we have to distinguish 3 categories and give them labels: 
# ---> unprocessed reference: ku100 recordings without a hearing aid
recordings_df.loc[recordings_df["device"].str.contains("001_ku100"), "dnn_applied"]=0.1
# ---> unprocessed bypass: recordings with hearing aids in bypass
recordings_df.loc[((recordings_df["device"].str.contains("bypass")) & (recordings_df["dnn_applied"]==0)), "dnn_applied"]=0.2
# ---> unprocessed enabled: recordings with hearing aids enabled
recordings_df.loc[((recordings_df["device"].str.contains("enabled")) & (recordings_df["dnn_applied"]==0)), "dnn_applied"]=0.3

# Check if numbers match (expected counts: 15 devices, 4 dnn models, 6 scenes): 
print(f'Number of dnn-processed recordings should be {15*4*6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]>=1)]))


print(f'Number of enabled recordings should be {15*6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]==0.3)])) 


print(f'Number of bypass recordings should be {15*6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]==0.2)])) 


print(f'Number of ku100 recordings should be {6}')
print(len(recordings_df.loc[(recordings_df["dnn_applied"]==0.1)])) 

# Keep only recordings processed with HA or DNN
# (labels above 0.2, i.e. "enabled" recordings (0.3) and dnn-processed recordings (>= 1))
recordings_only_processed=recordings_df[(recordings_df["dnn_applied"]>0.2)]
print(f'Number of all processed recordings should be {15*5*6}')
print(len(recordings_only_processed))

# The frame's index is written as the first CSV column
recordings_only_processed.to_csv('recordings_processed.csv')




In [None]:
def _read_wav_16k(path, fs_target=16000):
    """Read a wav file as float32 samples, resampled to `fs_target` Hz when needed."""
    fs, samples = wavfile.read(path)
    # Convert before any arithmetic: integer PCM data (e.g. int16) would wrap
    # around when two recordings are added together.
    samples = samples.astype(np.float32)
    if fs != fs_target:
        samples = resample_poly(samples, fs_target, fs)
    return samples


def compute_obj_measures(row, ref_method):
    """Score one recording with the loaded NORESQA model.

    The test signal is channel 0 of the row's 'plus' recording. The reference
    signal is channel 0 of 0.5 * (plus + minus) of the pair selected by
    `ref_method` (presumably the target signal under the phase-inversion
    recording scheme — confirm). The reference is length-matched to the test
    signal with check_size() before both are passed to the model.

    ----- Input: -----
    row - data-frame row (or mapping) with the wav paths 'plus_file',
          'minus_file', 'plus_ref_bypass', 'minus_ref_bypass', 'plus_ref_ku'
          and 'minus_ref_ku'; the files must hold multi-channel (2-D) data
    ref_method - which pair forms the reference:
          "it1": the recording itself ('plus_file' / 'minus_file')
          "it2": the bypass recording ('plus_ref_bypass' / 'minus_ref_bypass')
          "it3": the ku100 recording ('plus_ref_ku' / 'minus_ref_ku')
    ----- Output: -----
    metric_type == 1: the score returned by model_prediction_noresqa_mos()
    metric_type == 0: the (pout, qout) tuple of model_prediction_noresqa()
    ----- Raises: -----
    ValueError for an unknown `ref_method` or an unsupported global `metric_type`

    The SNR values that used to be computed here were never returned, so that
    dead computation has been removed.

    NOTE(review): samples keep the amplitude scale stored in the wav file
    (e.g. +-32768 for int16 PCM), whereas audio_loading() via librosa yields
    values in [-1, 1] — confirm which scale the model expects.
    """
    reference_columns = {
        "it1": ('plus_file', 'minus_file'),
        "it2": ('plus_ref_bypass', 'minus_ref_bypass'),
        "it3": ('plus_ref_ku', 'minus_ref_ku'),
    }
    if ref_method not in reference_columns:
        raise ValueError(f'Unknown ref_method {ref_method!r}; expected one of {sorted(reference_columns)}')
    if metric_type not in (0, 1):
        raise ValueError(f'Unsupported metric_type {metric_type!r}; expected 0 (NORESQA) or 1 (NORESQA-MOS)')

    plus = _read_wav_16k(row['plus_file'])

    # ---------- Compute clean reference signal ----------
    plus_col, minus_col = reference_columns[ref_method]
    # "it1" reuses the recording that was already loaded as the test signal
    refplus = plus if plus_col == 'plus_file' else _read_wav_16k(row[plus_col])
    refminus = _read_wav_16k(row[minus_col])
    audio_ref = 0.5*(refplus+refminus)

    # Only the first channel of each signal is scored
    audio_ref = audio_ref[:,0]
    audio_test = plus[:,0]
    audio_ref, audio_test = check_size(audio_ref,audio_test)

    if metric_type == 0:
        # NORESQA works on STFT magnitude/phase features
        nmr_feat = extract_stft(audio_ref)
        test_feat = extract_stft(audio_test)
    else:
        # NORESQA-MOS works on the raw waveforms
        nmr_feat = audio_ref
        test_feat = audio_test

    # Add a leading batch dimension and move the inputs to the model's device
    test_feat = torch.from_numpy(test_feat).float().to(device).unsqueeze(0)
    nmr_feat = torch.from_numpy(nmr_feat).float().to(device).unsqueeze(0)

    if metric_type == 0:
        return model_prediction_noresqa(test_feat, nmr_feat)
    return model_prediction_noresqa_mos(test_feat, nmr_feat)
# Test function compute_obj_measures():
# Reload the table of processed recordings from 'recordings_processed.csv';
# the rows get a fresh 0..n-1 index, which the scoring loops rely on (measures.loc[i]).
measures=pd.read_csv('recordings_processed.csv')

# NOTE(review): the two commented-out lines below are stale — they pass an extra leading
# index argument that compute_obj_measures(row, ref_method) does not take, and they print
# variables that the function does not return.
#compute_obj_measures(120,measures.loc[120],"it1")
#print(f'Measures for {measures.loc[idx,"plus_file"]}\n{snr_L=}\n{hasqi_L=}\n{haspi_L=}\n{sisdr_L=}\n{snrloss_L=}\n{mbstoi_B=}')

# Multiprocessing - compute measures
# The main script has to be in the same cell as the definition of the function
# if __name__ == '__main__':
#     NUM_OF_WORKERS = 8
#     with Pool(NUM_OF_WORKERS) as pool:
#         results = [pool.apply_async(compute_obj_measures, [idx, row, "it3"]) for idx, row in measures.iterrows()]
#         for result in results:
#             idx, snr_L, snr_R, hasqi_L, hasqi_R, haspi_L, haspi_R, sisdr_L, sisdr_R, snrloss_L, snrloss_R, mbstoi_B  = result.get()
#             measures.loc[idx, 'snr_L'] = snr_L
#             measures.loc[idx, 'snr_R'] = snr_R
#             measures.loc[idx, 'haspi_L'] = haspi_L
#             measures.loc[idx, 'haspi_R'] = haspi_R
#             measures.loc[idx, 'hasqi_L'] = hasqi_L
#             measures.loc[idx, 'hasqi_R'] = hasqi_R
#             measures.loc[idx, 'sisdr_L'] = sisdr_L
#             measures.loc[idx, 'sisdr_R'] = sisdr_R
#             measures.loc[idx, 'snrloss_L'] = snrloss_L
#             measures.loc[idx, 'snrloss_R'] = snrloss_R
#             measures.loc[idx, 'mbstoi_B'] = mbstoi_B

# measures.to_csv('objective_measures3_nosynch.csv')


In [None]:
# IT1
# Score every row of `measures` with reference method 'it1' (tqdm shows the progress)
mos_it1 = [
    compute_obj_measures(measures.loc[idx], 'it1')
    for idx in tqdm.tqdm(range(len(measures)))
]

In [None]:
# IT2
# Score every row of `measures` with reference method 'it2'
mos_it2 = [
    compute_obj_measures(measures.loc[idx], 'it2')
    for idx in tqdm.tqdm(range(len(measures)))
]

# IT3
# Score every row of `measures` with reference method 'it3'
mos_it3 = [
    compute_obj_measures(measures.loc[idx], 'it3')
    for idx in tqdm.tqdm(range(len(measures)))
]

In [None]:
# Load previously computed objective measures (presumably those for reference method 'it3').
# NOTE(review): assumes one row per row of `measures`, in the same order — confirm before attaching scores.
measures_computed=pd.read_csv('objective_measures3.csv')

In [None]:
# Attach the 'it3' scores as a new column; mos_it3 must have exactly one entry per row
measures_computed['noresqaMOS']=mos_it3

In [None]:
# Display the table of objective measures
measures_computed

In [None]:
# Load the measures stored in 'objective_measures2.csv', attach the 'it2' scores
# as the 'noresqaMOS' column, and display the resulting table
measures_computed = pd.read_csv('objective_measures2.csv').assign(noresqaMOS=mos_it2)
measures_computed

In [None]:
def synch_sigs(sig1,sig2):
    """Time-align two multi-channel signals using their first channels.

    The lag is taken at the maximum of the full cross-correlation of channel 0
    of `sig1` with channel 0 of `sig2`. Depending on its sign, leading samples
    are dropped from one signal and trailing samples from the other. Each
    aligned signal is then written to the start of a zero array with the shape
    of its input, so the outputs keep the input shapes and are zero-padded at
    the end.

    Returns:
        (sig1_out, sig2_out): the aligned, zero-padded signals.
    """
    first1, first2 = sig1[:,0], sig2[:,0]
    xcorr = signal.correlate(first1, first2, 'full')
    lags = signal.correlation_lags(len(first1), len(first2), mode='full')
    shift = lags[np.argmax(xcorr)]

    aligned1, aligned2 = sig1, sig2
    if shift > 0:
        # positive lag: trim the start of sig1 and the end of sig2
        aligned1, aligned2 = sig1[shift:, :], sig2[0:-shift, :]
    elif shift < 0:
        # negative lag: trim the end of sig1 and the start of sig2
        aligned1, aligned2 = sig1[0:shift, :], sig2[-shift:, :]

    sig1_out = np.zeros(sig1.shape)
    sig2_out = np.zeros(sig2.shape)
    sig1_out[:aligned1.shape[0],:] = aligned1
    sig2_out[:aligned2.shape[0],:] = aligned2
    return sig1_out,sig2_out

In [None]:
# Load the measures stored in 'objective_measures1.csv', attach the 'it1' scores
# as the 'noresqaMOS' column, and display the resulting table
measures_computed = pd.read_csv('objective_measures1.csv').assign(noresqaMOS=mos_it1)
measures_computed

In [None]:
# Disabled plotting code: the whole cell is a string literal that is evaluated and
# discarded (the trailing ';' hides it from the cell output), so nothing below runs.
# NOTE(review): if re-enabled, it overwrites the global `ha_models` with device-family
# names and silences all warnings.
'''
measures_computed=pd.read_csv('objective_measures3.csv')

# For a better visibility, we add one column which specifies the hearing aid model (without division into different receivers)
ha_models=["gn3","gn5","ph5","ph4","si3"]
for modelname in ha_models:
    measures_computed.loc[measures_computed['device'].str.contains(modelname), "device_group"]=modelname

import seaborn as sns
import warnings
  
# Settings the warnings to be ignored
warnings.filterwarnings('ignore')

def plot_1_measure(df,measure):
    fig, axes = plt.subplots(1, 3, figsize=(25, 5), sharey=True) 
    sns.set(font_scale=0.8)
    sns.stripplot(ax=axes[1],data=df,x="dnn_applied", y=measure, hue="device_group")
    sns.stripplot(ax=axes[2],data=df,x="dnn_applied", y=measure, hue="scene")
    df.boxplot(ax=axes[0],column=measure,by="dnn_applied")
    axes[0].set_xlabel('')
    axes[1].set_xlabel('')
    axes[2].set_xlabel('')
    axes[0].set_title('Standard boxplot')
    axes[1].set_title('Colored by devices')
    axes[2].set_title('Colored by scene')
    axes[0].set_xticklabels(['HA','DNN-normal','DNN-mild','DNN-normal_C','DNN-mild_C'], rotation=45)
    axes[1].set_xticklabels(['HA','DNN-normal','DNN-mild','DNN-normal_C','DNN-mild_C'], rotation=45)
    axes[2].set_xticklabels(['HA','DNN-normal','DNN-mild','DNN-normal_C','DNN-mild_C'], rotation=45)
    fig.suptitle('Measure: '+ measure)
    plt.show()

# Plot all measures:
plot_1_measure(measures_computed,"snr_L")
plot_1_measure(measures_computed,"sisdr_L")
plot_1_measure(measures_computed,"snrloss_L")
plot_1_measure(measures_computed,"hasqi_L")
plot_1_measure(measures_computed,"haspi_L")
plot_1_measure(measures_computed,"mbstoi_B")
''';