In [1]:
import os
import numpy as np
import mtrf
import librosa
import matplotlib.pyplot as plt
import pandas as pd

from IPython.display import Audio, display
from sklearn.cross_decomposition import CCA

from scipy.io import loadmat
from scipy import linalg
from scipy import stats
from scipy.signal import hilbert, resample, correlate
from scipy.stats import zscore, pearsonr

from mtrf.model import TRF
from sklearn.cross_decomposition import CCA

In [2]:
def lag_generator_new(r, lags):
    '''
    Args:
      r: [time, neurons]
      
    Return
      out: [time, neuron*lags]
    
    '''
    lags = list(range(lags[0], lags[1]+1))
    out = np.zeros([r.shape[0], r.shape[1]*len(lags)])
    r = np.pad(r, ((0,len(lags)),(0,0)), 'constant')

    r_lag_list = []
    
    for lag in lags:
        t1 = np.roll(r, lag, axis=0)
        if lag < 0:
            t1[lag-1:, :] = 0
        else:
            t1[:lag, :] = 0
            
        r_lag_list.append(t1[:out.shape[0], :])
        
    out = np.concatenate(r_lag_list, axis=1)
    
    return out


In [3]:
#trials = os.listdir('../../../Data/Cindy/Preprocessed/preprocessed_mixed_new')
#trials.remove('cindy_mixed_pp_record.csv')

folder_name = '../../../Data/Cindy/Preprocessed/preprocessed_mixed_01_30Hz'
trials = os.listdir(folder_name)
trials = [item for item in trials if item not in ['cindy_mixed_pp_record.csv','.ipynb_checkpoints.mat','.ipynb_checkpoints']]


# folder_name = '../../../Data/Samet/Preprocessed/preprocessed_mixed_01_15Hz'
# trials = os.listdir(folder_name)
# trials = [item for item in trials if item not in ['multi1_pp_record.csv','multi2_pp_record.csv','multi3_pp_record.csv','multi4_pp_record.csv','.ipynb_checkpoints.mat','.ipynb_checkpoints']]

fs_eeg = 128

lags_neuro = [-40, 10]
lags_stim = [-10, 10]

In [4]:
speech_eeg_all = []
# speech_att_env_all = []
# speech_unatt_env_all = []

speech_unatt_surprisal_all = []

# Load long audio
long_audio, long_sr = librosa.load(f'../../../Stimuli/Cindy/piano_4.wav', sr=None)

surprisal_feature = pd.read_csv('../../../Stimuli/Cindy/Surprisal/piano_4.csv')
surprisal_feature = np.array(surprisal_feature['surprise'].to_list())

for trial in trials:
    print(trial)
    data = loadmat(os.path.join(folder_name,trial))
    
    if data['stim_attended'][0] == 'Speech':
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
    elif data['stim_attended'][0] == 'Music':
        continue
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        
    
    if data['stim_attended_pos'][0] == 'FirstHalfAttend':
        att_stim = att_stim[:int(len(att_stim)/2)]
        unatt_stim = unatt_stim[:int(len(unatt_stim)/2)]
    elif data['stim_attended_pos'][0] == 'SecondHalfAttend':
        att_stim = att_stim[int(len(att_stim)/2):]
        unatt_stim = unatt_stim[int(len(unatt_stim)/2):]
    
    # Load short audio at its native rate (22050 Hz)
    short_audio = np.squeeze(unatt_stim)
    short_sr = 22050
    
    # Resample short audio to match long audio's sampling rate
    short_audio_resampled = librosa.resample(short_audio, orig_sr=short_sr, target_sr=long_sr)
    
    # Normalize both signals
    long_audio = (long_audio - np.mean(long_audio)) / np.std(long_audio)
    short_audio_resampled = (short_audio_resampled - np.mean(short_audio_resampled)) / np.std(short_audio_resampled)
    
    # Cross-correlation to find best match
    correlation = correlate(long_audio, short_audio_resampled, mode='valid')
    best_match_index = np.argmax(correlation)
    end_index = best_match_index + len(short_audio_resampled)

    print(f"Best match found at index range: {best_match_index} to {end_index} (Corr: {np.max(correlation)})")

    unatt_surprisal = np.expand_dims(surprisal_feature[round((best_match_index/long_sr*128)):round((end_index/long_sr*128))],axis=0)
    
    # att_env = np.abs(hilbert(att_stim))
    # unatt_env = np.abs(hilbert(unatt_stim))
    
    # duration_sec = len(att_env) / fs_audio
    # n_target_samples = int(duration_sec * fs_eeg)
    # att_env = np.expand_dims(resample(att_env, n_target_samples),axis=0)
    
    # duration_sec = len(unatt_env) / fs_audio
    # n_target_samples = int(duration_sec * fs_eeg)
    # unatt_env = np.expand_dims(resample(unatt_env, n_target_samples),axis=0)

    speech_eeg_all.append(data['eeg_data'])
    speech_unatt_surprisal_all.append(unatt_surprisal)
    # speech_att_env_all.append(att_env)
    # speech_unatt_env_all.append(unatt_env)

cindy_mixed_Music_32.mat
cindy_mixed_Music_26.mat
cindy_mixed_Music_27.mat
cindy_mixed_Music_33.mat
cindy_mixed_Music_25.mat
cindy_mixed_Music_31.mat
cindy_mixed_Music_19.mat
cindy_mixed_Music_18.mat
cindy_mixed_Music_30.mat
cindy_mixed_Music_24.mat
cindy_mixed_Music_20.mat
cindy_mixed_Music_34.mat
cindy_mixed_Music_35.mat
cindy_mixed_Music_21.mat
cindy_mixed_Music_37.mat
cindy_mixed_Music_23.mat
cindy_mixed_Music_22.mat
cindy_mixed_Music_36.mat
cindy_mixed_Music_8.mat
cindy_mixed_Speech_1.mat
Best match found at index range: 3969000 to 5292000 (Corr: 1734870.75)
cindy_mixed_Speech_17.mat
Best match found at index range: 104004956 to 105327956 (Corr: 1941546.375)
cindy_mixed_Speech_16.mat
Best match found at index range: 100035956 to 101358956 (Corr: 1572415.875)
cindy_mixed_Speech_0.mat
Best match found at index range: 1323000 to 2646000 (Corr: 1772553.875)
cindy_mixed_Music_9.mat
cindy_mixed_Speech_2.mat
Best match found at index range: 5292000 to 6615000 (Corr: 1696999.75)
cindy_mix

In [5]:
music_eeg_all = []
# music_att_env_all = []
# music_unatt_env_all = []

music_att_surprisal_all = []

for trial in trials:
    print(trial)
    data = loadmat(os.path.join(folder_name,trial))
    
    if data['stim_attended'][0] == 'Speech':
        continue
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
    elif data['stim_attended'][0] == 'Music':
        att_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/piano_only_long_cropped_22khz",f"{data['stimuli_music'][0]}"+'.wav'))
        unatt_stim,fs_audio = librosa.load(os.path.join("../../../Stimuli/Cindy/speech_only_short_22khz",f"{data['stimuli_speech'][0]}"+'.wav'))
        
    
    if data['stim_attended_pos'][0] == 'FirstHalfAttend':
        att_stim = att_stim[:int(len(att_stim)/2)]
        unatt_stim = unatt_stim[:int(len(unatt_stim)/2)]
    elif data['stim_attended_pos'][0] == 'SecondHalfAttend':
        att_stim = att_stim[int(len(att_stim)/2):]
        unatt_stim = unatt_stim[int(len(unatt_stim)/2):]
    
    # display(Audio(att_stim,rate=fs_audio))
    # display(Audio(unatt_stim,rate=fs_audio))

    # Load short audio at its native rate (22050 Hz)
    short_audio = np.squeeze(att_stim)
    short_sr = 22050
    
    # Resample short audio to match long audio's sampling rate
    short_audio_resampled = librosa.resample(short_audio, orig_sr=short_sr, target_sr=long_sr)
    
    # Normalize both signals
    long_audio = (long_audio - np.mean(long_audio)) / np.std(long_audio)
    short_audio_resampled = (short_audio_resampled - np.mean(short_audio_resampled)) / np.std(short_audio_resampled)
    
    # Cross-correlation to find best match
    correlation = correlate(long_audio, short_audio_resampled, mode='valid')
    best_match_index = np.argmax(correlation)
    end_index = best_match_index + len(short_audio_resampled)

    print(f"Best match found at index range: {best_match_index} to {end_index} (Corr: {np.max(correlation)})")

    att_surprisal = np.expand_dims(surprisal_feature[round((best_match_index/long_sr*128)):round((end_index/long_sr*128))],axis=0)

    
    # att_env = np.abs(hilbert(att_stim))
    # unatt_env = np.abs(hilbert(unatt_stim))
    
    # duration_sec = len(att_env) / fs_audio
    # n_target_samples = int(duration_sec * fs_eeg)
    # att_env = np.expand_dims(resample(att_env, n_target_samples),axis=0)
    
    # duration_sec = len(unatt_env) / fs_audio
    # n_target_samples = int(duration_sec * fs_eeg)
    # unatt_env = np.expand_dims(resample(unatt_env, n_target_samples),axis=0)

    music_eeg_all.append(data['eeg_data'])
    music_att_surprisal_all.append(att_surprisal)
    # music_att_env_all.append(att_env)
    # music_unatt_env_all.append(unatt_env)

cindy_mixed_Music_32.mat
Best match found at index range: 92097956 to 93420956 (Corr: 1169271.0)
cindy_mixed_Music_26.mat
Best match found at index range: 19845001 to 21168001 (Corr: 1842736.625)
cindy_mixed_Music_27.mat
Best match found at index range: 22491001 to 23814001 (Corr: 1612999.5)
cindy_mixed_Music_33.mat
Best match found at index range: 96066956 to 97389956 (Corr: 1682389.5)
cindy_mixed_Music_25.mat
Best match found at index range: 15876001 to 17199001 (Corr: 1915513.75)
cindy_mixed_Music_31.mat
Best match found at index range: 90774956 to 92097956 (Corr: 1774320.75)
cindy_mixed_Music_19.mat
Best match found at index range: 1323000 to 2646000 (Corr: 1772553.875)
cindy_mixed_Music_18.mat
Best match found at index range: 105327956 to 106650956 (Corr: 1565162.125)
cindy_mixed_Music_30.mat
Best match found at index range: 88128956 to 89451956 (Corr: 1596560.125)
cindy_mixed_Music_24.mat
Best match found at index range: 13230000 to 14553000 (Corr: 1098550.0)
cindy_mixed_Music_20

In [6]:
speech_eeg = np.concatenate(speech_eeg_all,axis=1).T
# speech_stim_att = np.concatenate(speech_att_env_all,axis=1).T
# speech_stim_unatt = np.concatenate(speech_unatt_env_all,axis=1).T
speech_stim_unatt = np.concatenate(speech_unatt_surprisal_all,axis=1).T
speech_eeg = zscore(speech_eeg, axis=0)
# speech_stim_att = zscore(speech_stim_att, axis=0)
speech_stim_unatt = zscore(speech_stim_unatt, axis=0)

music_eeg = np.concatenate(music_eeg_all,axis=1).T
# music_stim_att = np.concatenate(music_att_env_all,axis=1).T
# music_stim_unatt = np.concatenate(music_unatt_env_all,axis=1).T
music_stim_att = np.concatenate(music_att_surprisal_all,axis=1).T
music_eeg = zscore(music_eeg, axis=0)
music_stim_att = zscore(music_stim_att, axis=0)
# music_stim_unatt = zscore(music_stim_unatt, axis=0)

In [7]:
train_corrs = []
# speech_att_test_corrs = []
speech_unatt_test_corrs = []
music_att_test_corrs = []
# music_unatt_test_corrs = []


speech_sample_len = speech_eeg.shape[0]
music_sample_len = music_eeg.shape[0]

k_cv = 20
for i in range(k_cv):
    print(f'Split {i+1}')

    #Train Test Split
    
    speech_eeg_test = speech_eeg[i*(round(speech_sample_len/k_cv)):(i+1)*(round(speech_sample_len/k_cv)),:]
    # speech_stim_att_test = speech_stim_att[i*(round(speech_sample_len/k_cv)):(i+1)*(round(speech_sample_len/k_cv)),:]
    speech_stim_unatt_test = speech_stim_unatt[i*(round(speech_sample_len/k_cv)):(i+1)*(round(speech_sample_len/k_cv)),:]

    music_eeg_test = music_eeg[i*(round(music_sample_len/k_cv)):(i+1)*(round(music_sample_len/k_cv)),:]
    music_stim_att_test = music_stim_att[i*(round(music_sample_len/k_cv)):(i+1)*(round(music_sample_len/k_cv)),:]
    # music_stim_unatt_test = music_stim_unatt[i*(round(music_sample_len/k_cv)):(i+1)*(round(music_sample_len/k_cv)),:]

    # speech_eeg_train = np.concatenate((speech_eeg[:i*(round(speech_sample_len/k_cv)),:],speech_eeg[(i+1)*(round(speech_sample_len/k_cv)):,:]),axis=0)
    # speech_stim_train = np.concatenate((speech_stim_att[:i*(round(speech_sample_len/k_cv)),:],speech_stim_att[(i+1)*(round(speech_sample_len/k_cv)):,:]),axis=0)

    music_eeg_train = np.concatenate((music_eeg[:i*(round(music_sample_len/k_cv)),:],music_eeg[(i+1)*(round(music_sample_len/k_cv)):,:]),axis=0)
    music_stim_train = np.concatenate((music_stim_att[:i*(round(music_sample_len/k_cv)),:],music_stim_att[(i+1)*(round(music_sample_len/k_cv)):,:]),axis=0)

    eeg_train = music_eeg_train
    stim_train = music_stim_train
    
    #Lags
    
    eeg_train = lag_generator_new(eeg_train,lags_neuro)
    stim_train = lag_generator_new(stim_train,lags_stim)
    
    speech_eeg_test = lag_generator_new(speech_eeg_test,lags_neuro)
    # speech_stim_att_test = lag_generator_new(speech_stim_att_test,lags_stim)
    speech_stim_unatt_test = lag_generator_new(speech_stim_unatt_test,lags_stim)

    music_eeg_test = lag_generator_new(music_eeg_test,lags_neuro)
    music_stim_att_test = lag_generator_new(music_stim_att_test,lags_stim)
    # music_stim_unatt_test = lag_generator_new(music_stim_unatt_test,lags_stim)

    #Training
    
    cca_att = CCA(n_components=3)
    cca_att = cca_att.fit(eeg_train, stim_train)

    #Evaluations
    
    X_c, Y_c = cca_att.transform(eeg_train, stim_train)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Train: {r_fwd.round(3)}")
    train_corrs.append(r_fwd)

    # X_c, Y_c = cca_att.transform(speech_eeg_test, speech_stim_att_test)
    # r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    # print(f"Attended Speech: {r_fwd.round(3)}")
    # speech_att_test_corrs.append(r_fwd)
    
    # X_c, Y_c = cca_att.transform(music_eeg_test, music_stim_unatt_test)
    # r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    # print(f"Unattended Speech: {r_fwd.round(3)}")
    # music_unatt_test_corrs.append(r_fwd)

    X_c, Y_c = cca_att.transform(music_eeg_test, music_stim_att_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Attended Music: {r_fwd.round(3)}")
    music_att_test_corrs.append(r_fwd)

    X_c, Y_c = cca_att.transform(speech_eeg_test, speech_stim_unatt_test)
    r_fwd = pearsonr(np.squeeze(X_c.flatten()), np.squeeze(Y_c.flatten())).statistic
    print(f"Unattended Music: {r_fwd.round(3)}")
    speech_unatt_test_corrs.append(r_fwd)


print(f'Average Training Correlation: {np.mean(train_corrs)}')
# print(f'Average Attended Speech Test Correlation: {np.mean(speech_att_test_corrs)}')
# print(f'Average Unttended Speech Test Correlation: {np.mean(music_unatt_test_corrs)}')
print(f'Average Attended Music Test Correlation: {np.mean(music_att_test_corrs)}')
print(f'Average Unttended Music Test Correlation: {np.mean(speech_unatt_test_corrs)}')

Split 1
Train: 0.229
Attended Music: -0.106
Unattended Music: 0.018
Split 2
Train: 0.218
Attended Music: 0.155
Unattended Music: -0.384
Split 3
Train: 0.222
Attended Music: -0.007
Unattended Music: -0.034
Split 4
Train: 0.149
Attended Music: 0.163
Unattended Music: -0.044
Split 5
Train: 0.214
Attended Music: -0.033
Unattended Music: -0.153
Split 6
Train: 0.227
Attended Music: 0.046
Unattended Music: -0.026
Split 7
Train: 0.236
Attended Music: -0.035
Unattended Music: 0.072
Split 8
Train: 0.25
Attended Music: -0.08
Unattended Music: 0.136
Split 9
Train: 0.249
Attended Music: -0.114
Unattended Music: -0.125
Split 10
Train: 0.22
Attended Music: 0.189
Unattended Music: 0.077
Split 11
Train: 0.235
Attended Music: 0.114
Unattended Music: -0.134
Split 12
Train: 0.232
Attended Music: -0.051
Unattended Music: -0.208
Split 13
Train: 0.219
Attended Music: 0.046
Unattended Music: -0.004
Split 14
Train: 0.225
Attended Music: 0.066
Unattended Music: -0.091
Split 15
Train: 0.224
Attended Music: 0.041

In [8]:
# print(f"Speech-Attended Music-Unattended AAD Accuracy: {np.mean([True if speech_att_test_corrs[i] > speech_unatt_test_corrs[i] else False for i in range(len(speech_att_test_corrs))])}")
# print(f"Music-Attended Speech-Unattended AAD Accuracy: {np.mean([True if music_att_test_corrs[i] > music_unatt_test_corrs[i] else False for i in range(len(music_att_test_corrs))])}")

In [9]:
print(f'Average Attended Music Test Correlation: {np.mean(music_att_test_corrs)}')
print(f'Std Attended Music Test Correlation: {np.std(music_att_test_corrs)}')
print(f'Average Unattended Music Test Correlation: {np.mean(speech_unatt_test_corrs)}')
print(f'Std Unattended Music Test Correlation: {np.std(speech_unatt_test_corrs)}')

Average Attended Music Test Correlation: 0.025633235703509272
Std Attended Music Test Correlation: 0.0927111830522191
Average Unattended Music Test Correlation: -0.04245280189458196
Std Unattended Music Test Correlation: 0.11834964060506413


In [10]:
# save to disk
# np.save('Weights/CCA_Multi_Music_Train_Envelope_X_Weights.npy', cca_att.x_weights_)
# np.save('Weights/CCA_Multi_Music_Train_Envelope_Y_Weights.npy', cca_att.y_weights_)