In [1]:
import os
import numpy as np
import mtrf
import librosa
import matplotlib.pyplot as plt

from IPython.display import Audio, display
from sklearn.cross_decomposition import CCA

from scipy.io import loadmat
from scipy import linalg
from scipy import stats
from scipy.signal import hilbert, resample
from scipy.stats import zscore, pearsonr

from mtrf.model import TRF
from sklearn.cross_decomposition import CCA

In [2]:
def lag_generator_new(r, lags):
    '''
    Args:
      r: [time, neurons]
      
    Return
      out: [time, neuron*lags]
    
    '''
    lags = list(range(lags[0], lags[1]+1))
    out = np.zeros([r.shape[0], r.shape[1]*len(lags)])
    r = np.pad(r, ((0,len(lags)),(0,0)), 'constant')

    r_lag_list = []
    
    for lag in lags:
        t1 = np.roll(r, lag, axis=0)
        if lag < 0:
            t1[lag-1:, :] = 0
        else:
            t1[:lag, :] = 0
            
        r_lag_list.append(t1[:out.shape[0], :])
        
    out = np.concatenate(r_lag_list, axis=1)
    
    return out


In [3]:
#trials = os.listdir('../../../Data/Cindy/Preprocessed/preprocessed_mixed_new')
#trials.remove('cindy_mixed_pp_record.csv')

folder_name = '../../../Data/Samet/Preprocessed/preprocessed_mixed_01_15Hz'
trials = os.listdir(folder_name)
trials = [item for item in trials if item not in ['multi1_pp_record.csv','multi2_pp_record.csv','multi3_pp_record.csv','multi4_pp_record.csv','.ipynb_checkpoints.mat','.ipynb_checkpoints']]

fs_eeg = 128

lags_neuro = [-40, 10]
lags_stim = [-10, 10]

In [4]:
speech_eeg_all = []
speech_att_env_all = []
speech_unatt_env_all = []

for trial in trials:
    print(trial)
    data = loadmat(os.path.join(folder_name,trial))
    
    if data['stim_attended'][0] == 'Speech':
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
    elif data['stim_attended'][0] == 'Music':
        continue
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        
    
    if data['stim_attended_pos'][0] == 'FirstHalfAttend':
        att_stim = att_stim[:int(len(att_stim)/2)]
        unatt_stim = unatt_stim[:int(len(unatt_stim)/2)]
    elif data['stim_attended_pos'][0] == 'SecondHalfAttend':
        att_stim = att_stim[int(len(att_stim)/2):]
        unatt_stim = unatt_stim[int(len(unatt_stim)/2):]
    
    # display(Audio(att_stim,rate=fs_audio))
    # display(Audio(unatt_stim,rate=fs_audio))
    
    att_env = np.abs(hilbert(att_stim))
    unatt_env = np.abs(hilbert(unatt_stim))
    
    duration_sec = len(att_env) / fs_audio
    n_target_samples = int(duration_sec * fs_eeg)
    att_env = np.expand_dims(resample(att_env, n_target_samples),axis=0)
    
    duration_sec = len(unatt_env) / fs_audio
    n_target_samples = int(duration_sec * fs_eeg)
    unatt_env = np.expand_dims(resample(unatt_env, n_target_samples),axis=0)

    speech_eeg_all.append(data['eeg_data'])
    speech_att_env_all.append(att_env)
    speech_unatt_env_all.append(unatt_env)

samet_Music_4.mat
samet_Music_27.mat
samet_Music_33.mat
samet_Music_32.mat
samet_Music_26.mat
samet_Music_5.mat
samet_Music_7.mat
samet_Music_18.mat
samet_Music_30.mat
samet_Music_24.mat
samet_Music_25.mat
samet_Music_31.mat
samet_Music_19.mat
samet_Music_6.mat
samet_Music_2.mat
samet_Music_35.mat
samet_Music_21.mat
samet_Music_20.mat
samet_Music_34.mat
samet_Music_3.mat
samet_Music_1.mat
samet_Music_22.mat
samet_Music_36.mat
samet_Music_37.mat
samet_Music_23.mat
samet_Music_0.mat
samet_Speech_34.mat
samet_Speech_20.mat
samet_Speech_0.mat
samet_Speech_1.mat
samet_Speech_21.mat
samet_Speech_35.mat
samet_Speech_23.mat
samet_Speech_37.mat
samet_Speech_3.mat
samet_Speech_2.mat
samet_Speech_36.mat
samet_Speech_22.mat
samet_Speech_26.mat
samet_Speech_32.mat
samet_Speech_6.mat
samet_Speech_7.mat
samet_Speech_33.mat
samet_Speech_27.mat
samet_Speech_31.mat
samet_Speech_25.mat
samet_Speech_19.mat
samet_Speech_5.mat
samet_Speech_4.mat
samet_Speech_18.mat
samet_Speech_24.mat
samet_Speech_30.mat
sa

In [5]:
music_eeg_all = []
music_att_env_all = []
music_unatt_env_all = []

for trial in trials:
    print(trial)
    data = loadmat(os.path.join(folder_name,trial))
    
    if data['stim_attended'][0] == 'Speech':
        continue
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
    elif data['stim_attended'][0] == 'Music':
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        
    
    if data['stim_attended_pos'][0] == 'FirstHalfAttend':
        att_stim = att_stim[:int(len(att_stim)/2)]
        unatt_stim = unatt_stim[:int(len(unatt_stim)/2)]
    elif data['stim_attended_pos'][0] == 'SecondHalfAttend':
        att_stim = att_stim[int(len(att_stim)/2):]
        unatt_stim = unatt_stim[int(len(unatt_stim)/2):]
    
    # display(Audio(att_stim,rate=fs_audio))
    # display(Audio(unatt_stim,rate=fs_audio))
    
    att_env = np.abs(hilbert(att_stim))
    unatt_env = np.abs(hilbert(unatt_stim))
    
    duration_sec = len(att_env) / fs_audio
    n_target_samples = int(duration_sec * fs_eeg)
    att_env = np.expand_dims(resample(att_env, n_target_samples),axis=0)
    
    duration_sec = len(unatt_env) / fs_audio
    n_target_samples = int(duration_sec * fs_eeg)
    unatt_env = np.expand_dims(resample(unatt_env, n_target_samples),axis=0)

    music_eeg_all.append(data['eeg_data'])
    music_att_env_all.append(att_env)
    music_unatt_env_all.append(unatt_env)

samet_Music_4.mat
samet_Music_27.mat
samet_Music_33.mat
samet_Music_32.mat
samet_Music_26.mat
samet_Music_5.mat
samet_Music_7.mat
samet_Music_18.mat
samet_Music_30.mat
samet_Music_24.mat
samet_Music_25.mat
samet_Music_31.mat
samet_Music_19.mat
samet_Music_6.mat
samet_Music_2.mat
samet_Music_35.mat
samet_Music_21.mat
samet_Music_20.mat
samet_Music_34.mat
samet_Music_3.mat
samet_Music_1.mat
samet_Music_22.mat
samet_Music_36.mat
samet_Music_37.mat
samet_Music_23.mat
samet_Music_0.mat
samet_Speech_34.mat
samet_Speech_20.mat
samet_Speech_0.mat
samet_Speech_1.mat
samet_Speech_21.mat
samet_Speech_35.mat
samet_Speech_23.mat
samet_Speech_37.mat
samet_Speech_3.mat
samet_Speech_2.mat
samet_Speech_36.mat
samet_Speech_22.mat
samet_Speech_26.mat
samet_Speech_32.mat
samet_Speech_6.mat
samet_Speech_7.mat
samet_Speech_33.mat
samet_Speech_27.mat
samet_Speech_31.mat
samet_Speech_25.mat
samet_Speech_19.mat
samet_Speech_5.mat
samet_Speech_4.mat
samet_Speech_18.mat
samet_Speech_24.mat
samet_Speech_30.mat
sa

In [6]:
speech_eeg = np.concatenate(speech_eeg_all,axis=1).T
speech_stim_att = np.concatenate(speech_att_env_all,axis=1).T
speech_stim_unatt = np.concatenate(speech_unatt_env_all,axis=1).T
speech_eeg = zscore(speech_eeg, axis=0)
speech_stim_att = zscore(speech_stim_att, axis=0)
speech_stim_unatt = zscore(speech_stim_unatt, axis=0)

music_eeg = np.concatenate(music_eeg_all,axis=1).T
music_stim_att = np.concatenate(music_att_env_all,axis=1).T
music_stim_unatt = np.concatenate(music_unatt_env_all,axis=1).T
music_eeg = zscore(music_eeg, axis=0)
music_stim_att = zscore(music_stim_att, axis=0)
music_stim_unatt = zscore(music_stim_unatt, axis=0)

In [7]:
speech_model_train_corrs = []
speech_model_speech_att_test_corrs = []
speech_model_speech_unatt_test_corrs = []
speech_model_music_att_test_corrs = []
speech_model_music_unatt_test_corrs = []

music_model_train_corrs = []
music_model_speech_att_test_corrs = []
music_model_speech_unatt_test_corrs = []
music_model_music_att_test_corrs = []
music_model_music_unatt_test_corrs = []

speech_sample_len = speech_eeg.shape[0]
music_sample_len = music_eeg.shape[0]

k_cv = 60
for i in range(k_cv):
    print(f'Split {i+1}')

    #Train Test Split
    
    speech_eeg_test = speech_eeg[i*(round(speech_sample_len/k_cv)):(i+1)*(round(speech_sample_len/k_cv)),:]
    speech_stim_att_test = speech_stim_att[i*(round(speech_sample_len/k_cv)):(i+1)*(round(speech_sample_len/k_cv)),:]
    speech_stim_unatt_test = speech_stim_unatt[i*(round(speech_sample_len/k_cv)):(i+1)*(round(speech_sample_len/k_cv)),:]

    music_eeg_test = music_eeg[i*(round(music_sample_len/k_cv)):(i+1)*(round(music_sample_len/k_cv)),:]
    music_stim_att_test = music_stim_att[i*(round(music_sample_len/k_cv)):(i+1)*(round(music_sample_len/k_cv)),:]
    music_stim_unatt_test = music_stim_unatt[i*(round(music_sample_len/k_cv)):(i+1)*(round(music_sample_len/k_cv)),:]

    speech_eeg_train = np.concatenate((speech_eeg[:i*(round(speech_sample_len/k_cv)),:],speech_eeg[(i+1)*(round(speech_sample_len/k_cv)):,:]),axis=0)
    speech_stim_train = np.concatenate((speech_stim_att[:i*(round(speech_sample_len/k_cv)),:],speech_stim_att[(i+1)*(round(speech_sample_len/k_cv)):,:]),axis=0)

    music_eeg_train = np.concatenate((music_eeg[:i*(round(music_sample_len/k_cv)),:],music_eeg[(i+1)*(round(music_sample_len/k_cv)):,:]),axis=0)
    music_stim_train = np.concatenate((music_stim_att[:i*(round(music_sample_len/k_cv)),:],music_stim_att[(i+1)*(round(music_sample_len/k_cv)):,:]),axis=0)
    
    #Lags
    
    speech_eeg_train = lag_generator_new(speech_eeg_train,lags_neuro)
    speech_stim_train = lag_generator_new(speech_stim_train,lags_stim)

    music_eeg_train = lag_generator_new(music_eeg_train,lags_neuro)
    music_stim_train = lag_generator_new(music_stim_train,lags_stim)
    
    speech_eeg_test = lag_generator_new(speech_eeg_test,lags_neuro)
    speech_stim_att_test = lag_generator_new(speech_stim_att_test,lags_stim)
    speech_stim_unatt_test = lag_generator_new(speech_stim_unatt_test,lags_stim)

    music_eeg_test = lag_generator_new(music_eeg_test,lags_neuro)
    music_stim_att_test = lag_generator_new(music_stim_att_test,lags_stim)
    music_stim_unatt_test = lag_generator_new(music_stim_unatt_test,lags_stim)

    #Training
    
    cca_speech_att = CCA(n_components=3)
    cca_speech_att = cca_speech_att.fit(speech_eeg_train, speech_stim_train)

    cca_music_att = CCA(n_components=3)
    cca_music_att = cca_music_att.fit(music_eeg_train, music_stim_train)

    #Evaluations

    print("Speech-Trained Model")
    
    X_c, Y_c = cca_speech_att.transform(speech_eeg_train, speech_stim_train)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Train: {r_fwd.round(3)}")
    speech_model_train_corrs.append(r_fwd)

    X_c, Y_c = cca_speech_att.transform(speech_eeg_test, speech_stim_att_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Attended Speech: {r_fwd.round(3)}")
    speech_model_speech_att_test_corrs.append(r_fwd)
    
    X_c, Y_c = cca_speech_att.transform(music_eeg_test, music_stim_unatt_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Unattended Speech: {r_fwd.round(3)}")
    speech_model_music_unatt_test_corrs.append(r_fwd)

    X_c, Y_c = cca_speech_att.transform(music_eeg_test, music_stim_att_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Attended Music: {r_fwd.round(3)}")
    speech_model_music_att_test_corrs.append(r_fwd)

    X_c, Y_c = cca_speech_att.transform(speech_eeg_test, speech_stim_unatt_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Unattended Music: {r_fwd.round(3)}")
    speech_model_speech_unatt_test_corrs.append(r_fwd)


    print("Music-Trained Model")
    
    X_c, Y_c = cca_music_att.transform(music_eeg_train, music_stim_train)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Train: {r_fwd.round(3)}")
    music_model_train_corrs.append(r_fwd)

    X_c, Y_c = cca_music_att.transform(speech_eeg_test, speech_stim_att_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Attended Speech: {r_fwd.round(3)}")
    music_model_speech_att_test_corrs.append(r_fwd)
    
    X_c, Y_c = cca_music_att.transform(music_eeg_test, music_stim_unatt_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Unattended Speech: {r_fwd.round(3)}")
    music_model_music_unatt_test_corrs.append(r_fwd)

    X_c, Y_c = cca_music_att.transform(music_eeg_test, music_stim_att_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Attended Music: {r_fwd.round(3)}")
    music_model_music_att_test_corrs.append(r_fwd)

    X_c, Y_c = cca_music_att.transform(speech_eeg_test, speech_stim_unatt_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Unattended Music: {r_fwd.round(3)}")
    music_model_speech_unatt_test_corrs.append(r_fwd)

print("Speech-Trained Model")
print(f'Average Training Correlation: {np.mean(speech_model_train_corrs)}')
print(f'Average Attended Speech Test Correlation: {np.mean(speech_model_speech_att_test_corrs)}')
print(f'Average Unttended Speech Test Correlation: {np.mean(speech_model_music_unatt_test_corrs)}')
print(f'Average Attended Music Test Correlation: {np.mean(speech_model_music_att_test_corrs)}')
print(f'Average Unttended Music Test Correlation: {np.mean(speech_model_speech_unatt_test_corrs)}')

print("Music-Trained Model")
print(f'Average Training Correlation: {np.mean(music_model_train_corrs)}')
print(f'Average Attended Speech Test Correlation: {np.mean(music_model_speech_att_test_corrs)}')
print(f'Average Unttended Speech Test Correlation: {np.mean(music_model_music_unatt_test_corrs)}')
print(f'Average Attended Music Test Correlation: {np.mean(music_model_music_att_test_corrs)}')
print(f'Average Unttended Music Test Correlation: {np.mean(music_model_speech_unatt_test_corrs)}')

Split 1
Speech-Trained Model
Train: 0.196
Attended Speech: 0.036
Unattended Speech: 0.111
Attended Music: 0.108
Unattended Music: 0.004
Music-Trained Model
Train: 0.14
Attended Speech: 0.045
Unattended Speech: 0.062
Attended Music: 0.057
Unattended Music: -0.016
Split 2
Speech-Trained Model
Train: 0.192
Attended Speech: 0.155
Unattended Speech: -0.058
Attended Music: 0.09
Unattended Music: -0.087
Music-Trained Model
Train: 0.136
Attended Speech: 0.006
Unattended Speech: 0.036
Attended Music: 0.17
Unattended Music: 0.003
Split 3
Speech-Trained Model
Train: 0.192
Attended Speech: 0.152
Unattended Speech: 0.11
Attended Music: -0.0
Unattended Music: 0.008
Music-Trained Model
Train: 0.144
Attended Speech: 0.021
Unattended Speech: -0.016
Attended Music: 0.013
Unattended Music: 0.052
Split 4
Speech-Trained Model
Train: 0.197
Attended Speech: -0.09
Unattended Speech: 0.14
Attended Music: -0.065
Unattended Music: 0.036
Music-Trained Model
Train: 0.139
Attended Speech: 0.004
Unattended Speech: -

In [15]:
print("Speech-Trained Model")
print(f"Speech-Attended Music-Unattended AAD Accuracy: {np.mean([True if speech_model_speech_att_test_corrs[i] > speech_model_speech_unatt_test_corrs[i] else False for i in range(len(speech_model_speech_att_test_corrs))])}")
print(f"Music-Attended Speech-Unattended AAD Accuracy: {np.mean([True if speech_model_music_att_test_corrs[i] > speech_model_music_unatt_test_corrs[i] else False for i in range(len(speech_model_music_att_test_corrs))])}")

print("Music-Trained Model")
print(f"Speech-Attended Music-Unattended AAD Accuracy: {np.mean([True if music_model_speech_att_test_corrs[i] > music_model_speech_unatt_test_corrs[i] else False for i in range(len(music_model_speech_att_test_corrs))])}")
print(f"Music-Attended Speech-Unattended AAD Accuracy: {np.mean([True if music_model_music_att_test_corrs[i] > music_model_music_unatt_test_corrs[i] else False for i in range(len(music_model_music_att_test_corrs))])}")

print("Average")
print(f"Speech-Attended Music-Unattended AAD Accuracy: {np.mean([True if speech_model_speech_att_test_corrs[i] + music_model_speech_att_test_corrs[i] > speech_model_speech_unatt_test_corrs[i] + music_model_speech_unatt_test_corrs[i] else False for i in range(len(music_model_speech_att_test_corrs))])}")
print(f"Music-Attended Speech-Unattended AAD Accuracy: {np.mean([True if speech_model_music_att_test_corrs[i] + music_model_music_att_test_corrs[i] > speech_model_music_unatt_test_corrs[i] + music_model_music_unatt_test_corrs[i] else False for i in range(len(music_model_music_att_test_corrs))])}")

print("Weighted Average")
weight = 3.5
print(f"Weight: {weight}")
print(f"Speech-Attended Music-Unattended AAD Accuracy: {np.mean([True if speech_model_speech_att_test_corrs[i] + music_model_speech_att_test_corrs[i] * weight > speech_model_speech_unatt_test_corrs[i] + music_model_speech_unatt_test_corrs[i] * weight else False for i in range(len(music_model_speech_att_test_corrs))])}")
print(f"Music-Attended Speech-Unattended AAD Accuracy: {np.mean([True if speech_model_music_att_test_corrs[i] + music_model_music_att_test_corrs[i] * weight > speech_model_music_unatt_test_corrs[i] + music_model_music_unatt_test_corrs[i] * weight else False for i in range(len(music_model_music_att_test_corrs))])}")

Speech-Trained Model
Speech-Attended Music-Unattended AAD Accuracy: 0.75
Music-Attended Speech-Unattended AAD Accuracy: 0.35
Music-Trained Model
Speech-Attended Music-Unattended AAD Accuracy: 0.48333333333333334
Music-Attended Speech-Unattended AAD Accuracy: 0.75
Average
Speech-Attended Music-Unattended AAD Accuracy: 0.75
Music-Attended Speech-Unattended AAD Accuracy: 0.43333333333333335
Weighted Average
Weight: 3.5
Speech-Attended Music-Unattended AAD Accuracy: 0.6833333333333333
Music-Attended Speech-Unattended AAD Accuracy: 0.6166666666666667
