In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import time
import gc
from scipy.stats import zscore
from scipy.ndimage import gaussian_filter1d
import mne
from mne_connectivity import spectral_connectivity_epochs
#truncate mne print
mne.set_log_level('WARNING')
import copy

In [2]:
def smooth_with_gaussian(data, sigma=3):
    """Smooth ``data`` along axis 1 with a one-dimensional Gaussian kernel.

    Parameters
    ----------
    data : array_like
        Input array; must have at least two dimensions because the filter
        runs along axis 1.
    sigma : scalar, optional
        Standard deviation of the Gaussian kernel, in array elements
        (default 3).

    Returns
    -------
    numpy.ndarray
        Filtered array with the same shape as ``data``.
    """
    smoothed = gaussian_filter1d(data, sigma, axis=1)
    return smoothed

def preprocess(data):
    """Z-score every column of ``data`` across its rows (axis 0).

    Columns whose standard deviation across rows is at most 1e-6 cannot be
    z-scored; each of them is filled with its own column mean instead.

    Parameters
    ----------
    data : array_like, 2-D
        Values to normalise. Non-floating input is converted to float first so
        that the z-scores are not truncated on assignment.

    Returns
    -------
    numpy.ndarray
        Floating-point array with the same shape as ``data``.

    Raises
    ------
    ValueError
        If the normalised output contains NaN (for example because the input
        already contained NaN).
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        # An integer output buffer would silently round the z-scores toward zero.
        data = data.astype(float)

    stds = np.std(data, axis=0)
    non_constant_cols = stds > 1e-6    # columns that can be z-scored
    # Taking the complement (rather than ``stds <= 1e-6``) also captures columns
    # whose std is NaN, so NaN input reaches the check below instead of being
    # silently replaced by zeros.
    const_cols = ~non_constant_cols

    z = np.zeros_like(data)
    z[:, non_constant_cols] = zscore(data[:, non_constant_cols], axis=0)
    z[:, const_cols] = np.mean(data[:, const_cols], axis=0)

    if np.isnan(z).any():
        raise ValueError("Data contains NaN values after normalization.")

    return z

# Wall-clock reference taken when this cell starts executing.
run_time = time.time()

# Identifiers of the dataset files that get loaded.
file_numbers = [1, 2, 4, 8, 14, 15, 20, 23]

# Class labels of the left and right input.
classes_left = np.arange(20)
classes_right = np.arange(20)

# Template container: one empty list per (left class, right class) pair.
agg_structure = {
    cl: {cr: [] for cr in classes_right}
    for cl in classes_left
}

# Independent copies of the template, named <l|r|a>_<al|ar>_data.
# (The loading loop fills l/r/a from the left-input, right-input and attention
#  traces, and al/ar from the trials with attend == 0 / attend == 1.)
l_al_data, r_al_data, a_al_data = (copy.deepcopy(agg_structure) for _ in range(3))
l_ar_data, r_ar_data, a_ar_data = (copy.deepcopy(agg_structure) for _ in range(3))

# Per class pair, one list for each attention condition.
n_values = {
    cl: {
        cr: {condition: [] for condition in ('attleft', 'attright')}
        for cr in classes_right
    }
    for cl in classes_left
}

#loading the data by class combinations
# NOTE(review): machine-specific absolute path; point this at your own copy of the data.
data_dir = 'C:/Users/joshu/PartIIIProject'
sigma = 2                        # std of the Gaussian smoothing kernel, in samples along axis 1
num_neurons_attention = 80       # number of attention neurons that get normalised
trial_window = slice(100, 350)   # samples of each trial kept for the later analyses


def _normalise_by_spike_count(smoothed, spikes, n_neurons):
    """Divide each smoothed (trial, neuron) trace by its spike count, in place.

    ``smoothed`` and ``spikes`` are indexed as (trial, sample, neuron). A spike
    is an entry of ``spikes`` equal to 1. Only the first ``n_neurons`` neurons
    are touched, and traces without any spike are left unchanged.
    """
    counts = np.count_nonzero(spikes[:, :, :n_neurons] == 1, axis=1)
    # Dividing by 1 leaves spike-free traces exactly as they were.
    divisors = np.where(counts > 0, counts, 1).astype(smoothed.dtype)
    smoothed[:, :, :n_neurons] /= divisors[:, np.newaxis, :]


total_time = time.time()
total_load_time = time.time()
for file_number in file_numbers:   #will first load the file and extract the data
    file_path = f'{data_dir}/RSNNdale_attention_{file_number}_attention_test'
    load_data_start_time = time.time()
    # NOTE: unpickling can execute arbitrary code - only load files you trust.
    with open(file_path, 'rb') as data_file:   # context manager closes the handle
        data = pickle.load(data_file)
    elapsed_time = time.time() - load_data_start_time
    print(f"Dataset {file_number} loaded in {elapsed_time:.2f} seconds")
    file_process_time = time.time()

    attention_labels = data['label_attend'][0]
    label_left = data['label_left'][0]
    label_right = data['label_right'][0]
    attend_01 = data['attend'][0]
    omitted = data['omit'][0]
    relevant = np.where(omitted == 0)[0]   # trials whose 'omit' flag is 0

    # Spike arrays of the kept trials, indexed as (trial, sample, neuron).
    left_input_SP = data['SP'][0][0][relevant]
    right_input_SP = data['SP'][0][1][relevant]
    attention_SP = data['SP'][0][2][relevant]

    left_sm = smooth_with_gaussian(left_input_SP, sigma=sigma)
    right_sm = smooth_with_gaussian(right_input_SP, sigma=sigma)
    att_sm = smooth_with_gaussian(attention_SP, sigma=sigma)

    num_trials, num_samples, num_neurons = left_input_SP.shape

    # Vectorised replacement for the former trials x neurons Python loop.
    _normalise_by_spike_count(left_sm, left_input_SP, num_neurons)
    _normalise_by_spike_count(right_sm, right_input_SP, num_neurons)
    _normalise_by_spike_count(att_sm, attention_SP, num_neurons_attention)

    # Sum over neurons: one trace per trial for each of the three populations.
    left_input_SP = np.sum(left_sm, axis=2)
    right_input_SP = np.sum(right_sm, axis=2)
    attention_SP = np.sum(att_sm, axis=2)

    #preprocess here now that we have traces of all of the relevant trials
    left_input_SP = preprocess(left_input_SP)
    right_input_SP = preprocess(right_input_SP)
    attention_SP = preprocess(attention_SP)

    #so now for each dataset we have a preprocessed set of LFP like signals
    for cl in classes_left:
        for cr in classes_right:
            # Trial numbers in the full file for this class pair and condition ...
            left_indices_agg = np.where((omitted == 0) & (attend_01 == 0) & (label_left == cl) & (label_right == cr))[0]
            right_indices_agg = np.where((omitted == 0) & (attend_01 == 1) & (label_left == cl) & (label_right == cr))[0]
            # ... translated to row positions within the `relevant` subset.
            left_indices = np.where(np.isin(relevant, left_indices_agg))[0]
            right_indices = np.where(np.isin(relevant, right_indices_agg))[0]

            # Append (rather than overwrite) so the counts of every file are kept.
            n_values[cl][cr]['attleft'].append(len(left_indices))
            n_values[cl][cr]['attright'].append(len(right_indices))

            if len(left_indices) >= 1:
                l_al = left_input_SP[left_indices, trial_window]
                r_al = right_input_SP[left_indices, trial_window]
                a_al = attention_SP[left_indices, trial_window]

                l_al_data[cl][cr].append(l_al)
                r_al_data[cl][cr].append(r_al)
                a_al_data[cl][cr].append(a_al)

            if len(right_indices) >= 1:
                l_ar = left_input_SP[right_indices, trial_window]
                r_ar = right_input_SP[right_indices, trial_window]
                a_ar = attention_SP[right_indices, trial_window]

                l_ar_data[cl][cr].append(l_ar)
                r_ar_data[cl][cr].append(r_ar)
                a_ar_data[cl][cr].append(a_ar)

    # Drop the raw file contents before loading the next one.
    del data
    gc.collect()
    elapsed_time = time.time() - file_process_time
    print(f"Dataset {file_number} processed in {elapsed_time:.2f} seconds")


#dont want to process here - want to process when its in the big array
print(f'Total load time = {time.time() - total_load_time:.2f} seconds')




Dataset 1 loaded in 13.29 seconds
Dataset 1 processed in 8.91 seconds
Dataset 2 loaded in 16.55 seconds
Dataset 2 processed in 9.95 seconds
Dataset 4 loaded in 29.44 seconds
Dataset 4 processed in 16.95 seconds
Dataset 8 loaded in 21.64 seconds
Dataset 8 processed in 17.26 seconds
Dataset 14 loaded in 14.71 seconds


KeyboardInterrupt: 

In [None]:
def _concat_trials(store):
    """Merge the per-file arrays of every class pair in ``store`` into one array.

    ``store[cl][cr]`` starts out as a list with one array per file. Non-empty
    lists are replaced by their concatenation along axis 0. Empty lists are
    left as they are, and entries that are no longer lists (this cell has
    already been run) are skipped, so re-running the cell is harmless.
    """
    for cl in classes_left:
        for cr in classes_right:
            chunks = store[cl][cr]
            if isinstance(chunks, list) and len(chunks) > 0:
                store[cl][cr] = np.concatenate(chunks, axis=0)


# Each store is merged on its own: a class pair can have trials in one
# attention condition and none in the other, and np.concatenate raises on an
# empty list.
for store in (l_al_data, r_al_data, a_al_data,
              l_ar_data, r_ar_data, a_ar_data):
    _concat_trials(store)

for cl in classes_left:
    for cr in classes_right:
        if len(l_al_data[cl][cr]) == 0 and len(l_ar_data[cl][cr]) == 0:
            print(f"No data for class left {cl} and class right {cr}.")


#-----------------calculating the number of trials for each class pair----------------
# len() of a merged array is its number of rows (trials); an untouched empty
# list gives 0.
n_values_sum = {cl: {cr: {'attleft': 0, 'attright': 0} for cr in classes_right} for cl in classes_left}

for cl in classes_left:
    for cr in classes_right:
        n_values_sum[cl][cr]['attleft'] = len(l_al_data[cl][cr])
        n_values_sum[cl][cr]['attright'] = len(l_ar_data[cl][cr])
        

In [None]:
bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']
sfreq = 500.0
ch_names = ['left_input', 'right_input', 'attention_layer']
ch_types = ['eeg', 'eeg', 'eeg']

# NOTE(review): the connectivity call below is limited to 0.5-40 Hz, so the
# 'gamma' band only ever contains bins below 40 Hz - confirm this is intended.
freq_ranges = [(0, 4), (4, 8), (8, 13), (13, 30), (30, 80)]

lr_split_structure = {cl: {cr: {band:  {
    'left_attleft': [], 'right_attleft': [], 'left_attright': [], 'right_attright': []
    } for band in bands} for cr in classes_right} for cl in classes_left}

# Band-averaged Granger causality per class pair, band and condition.
gr_means = copy.deepcopy(lr_split_structure)

# Full Granger spectra per class pair, one store per attention condition.
granger_plots_left = {cl: {cr: [] for cr in classes_right} for cl in classes_left}
granger_plots_right = {cl: {cr: [] for cr in classes_right} for cl in classes_left}

# Identical for every class pair, so built once.
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)

# Two connections: channel 0 -> channel 2 and channel 1 -> channel 2,
# i.e. left_input -> attention_layer and right_input -> attention_layer.
seeds = np.array([[0], [1]])
targets = np.array([[2], [2]])
indices = (seeds, targets)


def _granger_to_attention(left, right, attention):
    """Granger causality from both inputs to the attention layer.

    Parameters
    ----------
    left, right, attention : numpy.ndarray
        Arrays of shape (trials, samples); they are stacked into
        (trials, 3 channels, samples) in the order of ``ch_names``.

    Returns
    -------
    gc_data : numpy.ndarray
        Output of ``get_data()`` on the connectivity result. It is indexed
        below as row 0 = first connection in ``indices`` and row 1 = second
        connection (assumes rows follow the order of ``indices`` - confirm
        against the mne-connectivity docs).
    freqs : numpy.ndarray
        Frequencies belonging to the second axis of ``gc_data``.
    """
    epoch_data = np.stack([left, right, attention], axis=1)
    epochs = mne.EpochsArray(epoch_data, info, tmin=0, baseline=None)
    n_times = epoch_data.shape[-1]
    # Deliberately not named ``gc``: that would shadow the imported gc module
    # and break gc.collect() if the loading cell is run again.
    connectivity = spectral_connectivity_epochs(
        epochs, method='gc', indices=indices, sfreq=sfreq,
        fmin=0.5, fmax=40.0, tmin=0.0, tmax=(n_times - 1) / sfreq, gc_n_lags=5
    )
    return connectivity.get_data(), np.asarray(connectivity.freqs)


# (condition key, left-input store, right-input store, attention store, spectra store)
granger_conditions = (
    ('attleft', l_al_data, r_al_data, a_al_data, granger_plots_left),
    ('attright', l_ar_data, r_ar_data, a_ar_data, granger_plots_right),
)

for condition, left_store, right_store, att_store, spectra_store in granger_conditions:
    for cl in classes_left:
        for cr in classes_right:

            # only calculate for class pairs that have trials in this condition
            if n_values_sum[cl][cr][condition] == 0:
                continue

            gc_data, freqs = _granger_to_attention(
                left_store[cl][cr],
                right_store[cl][cr],
                att_store[cl][cr],
            )
            spectra_store[cl][cr].append(gc_data)

            left_to_att = gc_data[0]    # left_input -> attention_layer
            right_to_att = gc_data[1]   # right_input -> attention_layer

            for band, (f_min, f_max) in zip(bands, freq_ranges):
                band_idx = (freqs >= f_min) & (freqs < f_max)
                # mean over the frequency bins that fall inside the band
                gr_means[cl][cr][band][f'left_{condition}'].append(
                    np.mean(left_to_att[band_idx]))
                gr_means[cl][cr][band][f'right_{condition}'].append(
                    np.mean(right_to_att[band_idx]))
    

  


In [None]:
#-------------coherence calculation---------------- this is for all the data available (different numbers of trials)-----------
dt = 0.002
bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']   #for coherence the mean value within each of these bands is used
freq_ranges = [(0.5, 4), (4, 8), (8, 13), (13, 30), (30, 80)]  # Actual frequency ranges


lr_split_structure = {cl: {cr: {band:  {
    'left_attleft': [], 'right_attleft': [], 'left_attright': [], 'right_attright': []
    } for band in bands} for cr in classes_right} for cl in classes_left}


coh_means = copy.deepcopy(lr_split_structure)
coh_stdevs = copy.deepcopy(lr_split_structure)


def _append_band_stats(cl, cr, key, coherence, freq):
    """Append per-band mean and sample std of ``coherence`` for one class pair.

    For every band, the entries of ``coherence`` whose frequency lies in
    [f_min, f_max) are selected; their mean goes to ``coh_means`` and their
    standard deviation (ddof=1) to ``coh_stdevs`` under ``key``. Both
    statistics are taken across frequency bins, not across trials. A band
    holding a single bin yields NaN for the standard deviation.
    """
    for band, (f_min, f_max) in zip(bands, freq_ranges):
        in_band = coherence[(freq >= f_min) & (freq < f_max)]
        coh_means[cl][cr][band][key].append(np.mean(in_band))
        coh_stdevs[cl][cr][band][key].append(np.std(in_band, ddof=1))


# (condition key, left-input store, right-input store, attention store)
coherence_conditions = (
    ('attleft', l_al_data, r_al_data, a_al_data),
    ('attright', l_ar_data, r_ar_data, a_ar_data),
)

# NOTE(review): field_field_coherence is not defined in the cells shown here;
# it has to exist before this cell runs. It is used as
# f(signal_a, signal_b, dt) -> (coherence, freq).
start_time = time.time()
for cl in classes_left:
    coh_start = time.time()
    for cr in classes_right:
        for condition, left_store, right_store, att_store in coherence_conditions:

            # only calculate for class pairs that have trials in this condition
            if n_values_sum[cl][cr][condition] == 0:
                print(f"[SKIP] {cl}-{cr}: no {condition} trials")
                continue

            left_in_coh, freq = field_field_coherence(
                left_store[cl][cr],
                att_store[cl][cr],
                dt
            )
            right_in_coh, freq = field_field_coherence(
                right_store[cl][cr],
                att_store[cl][cr],
                dt
            )

            # Check the arrays that were just computed for this condition.
            if left_in_coh.size == 0 or right_in_coh.size == 0:
                print(f"[WARN] empty coherence array for {cl}-{cr} under {condition}")

            _append_band_stats(cl, cr, f'left_{condition}', left_in_coh, freq)
            _append_band_stats(cl, cr, f'right_{condition}', right_in_coh, freq)

    # Timed per left class: covers every right class and both conditions.
    print(f"Left class {cl} processed in {time.time() - coh_start:.2f} seconds")

run_time = time.time() - start_time
print(f"Total run time: {run_time:.2f} seconds")
