In [None]:
# # Calculate phase lags between source signals
# Things to check before running:
# path to the preprocessed EEG data, original srate, downsampled srate, FOI,
# epoch length (in both Part 1 and Part 2), FC metric, number of cycles per wavelet, FC threshold, second downsampled srate,
# power threshold, name and path of the results file, and whether to save the results as one file or many files

In [None]:
import mne
import numpy as np
from scipy.signal import hilbert
from mne_connectivity import spectral_connectivity_time
from scipy.io import loadmat
import pandas as pd
import os
SEED = 42  # fixed seed so any stochastic step is reproducible
rng = np.random.default_rng(seed=SEED)  # NOTE(review): not referenced by any of the cells shown here

In [None]:
# Folder holding the beamformer source time courses (MATLAB .mat files)
data_folder = r"W:\Projects\2019-04 M1M1PAS Project\analysis\source_signals\Beamformer"

# Keep only MATLAB files and sort them: os.listdir returns entries in arbitrary order
# (so the output row order would not be reproducible) and may include sub-folders or
# other non-.mat entries that loadmat cannot read.
file_list = sorted(f for f in os.listdir(data_folder) if f.lower().endswith('.mat'))

# Accumulates one (subject, task, run, phase_lag) tuple per retained sample
data = []

# Loop through each file in the folder
# Analysis parameters used below
ORIG_SFREQ = 1000.     # sampling rate of the exported source time courses (Hz)
FC_SFREQ = 200.        # rate used for the wPLI (FC) estimate in Part 1 (Hz)
ENV_SFREQ = 50.        # rate used for the Hilbert envelope/phase in Parts 2-3 (Hz)
EPOCH_DURATION = 2.    # window length shared by Part 1 and Part 2 (s)
N_CYCLES = 5           # cycles per Morlet wavelet
fmin, fmax = 8., 13.   # frequency band of interest (Hz)
freqs = np.linspace(fmin, fmax, int((fmax - fmin) * 1 + 1))  # 1-Hz steps: 8, 9, ..., 13

for file in file_list:
    # Parse metadata from the file name, e.g. "sub-XX_task-YY_run-ZZ_....mat"
    # NOTE(review): assumes the first three '_'-separated fields are sub/task/run -- confirm naming
    name_fields = file.split('_')
    subject = name_fields[0].replace('sub-', '')
    task = name_fields[1].replace('task-', '')
    run = name_fields[2].replace('run-', '')

    # Load the left/right sensorimotor source time courses exported from MATLAB
    file_path = os.path.join(data_folder, file)
    matlabdata = loadmat(file_path)
    # ravel() makes each source 1-D whether MATLAB stored it as a row or a column
    # vector, so vstack below always yields a (2, n_times) array
    leftSM = np.ravel(matlabdata['leftSM'])
    rightSM = np.ravel(matlabdata['rightSM'])

    # Wrap the two sources in an MNE RawArray (channel order: leftSM, rightSM)
    info = mne.create_info(ch_names=['leftSM', 'rightSM'], sfreq=ORIG_SFREQ, ch_types='eeg')
    source_data = np.vstack([leftSM, rightSM])
    source_raw = mne.io.RawArray(source_data, info)

    # Downsample in place before computing connectivity
    source_raw.resample(sfreq=FC_SFREQ)

    # ---- Part 1: FC threshold ------------------------------------------------
    epochs_raw = mne.make_fixed_length_epochs(source_raw, duration=EPOCH_DURATION, preload=True)

    con = spectral_connectivity_time(
        epochs_raw,
        method=['wpli'],     # weighted phase-lag index
        freqs=freqs,
        faverage=True,       # average connectivity across the band -> a single frequency bin
        mode='cwt_morlet',   # Morlet wavelet decomposition
        n_cycles=N_CYCLES,
        average=False,       # keep one value per epoch
        n_jobs=1,
    )
    # With two channels and no explicit `indices`, connections appear to be stored as a
    # raveled 2 x 2 matrix, so index 2 is the (rightSM, leftSM) entry and [..., 0] the
    # band-averaged bin.
    # NOTE(review): relies on mne_connectivity's storage order -- re-check after upgrades
    fc = con.get_data()[:, 2, 0]

    # Keep windows whose wPLI exceeds the median across windows
    threshold_fc = np.percentile(fc, 50)
    high_fc_windows = fc > threshold_fc

    # ---- Part 2: power (envelope) threshold ---------------------------------
    # Band-pass a copy, then downsample it further for the Hilbert transform
    raw_filtered = source_raw.copy().filter(l_freq=fmin, h_freq=fmax)
    raw_filtered_resampled = raw_filtered.resample(sfreq=ENV_SFREQ)
    raw_analytic = raw_filtered_resampled.copy().apply_hilbert()

    epochs_analytic = mne.make_fixed_length_epochs(raw_analytic, duration=EPOCH_DURATION, preload=True)

    # Select the single channel with [:, 0, :] instead of np.squeeze so that a recording
    # yielding only one window still gives a 2-D (n_windows, n_times) array
    rightSM_analytic = epochs_analytic.get_data(picks='rightSM')[:, 0, :]
    leftSM_analytic = epochs_analytic.get_data(picks='leftSM')[:, 0, :]

    # Median envelope (magnitude of the analytic signal) within each window
    right_median_win = np.median(np.abs(rightSM_analytic), axis=1)
    left_median_win = np.median(np.abs(leftSM_analytic), axis=1)

    # Threshold = the lower of the two medians of window-medians
    threshold_env = min(np.percentile(right_median_win, 50), np.percentile(left_median_win, 50))

    # A window passes only if BOTH sources exceed the threshold
    high_env_windows = np.logical_and(left_median_win > threshold_env, right_median_win > threshold_env)

    # Both epochings tile the recording into consecutive windows starting at t = 0, but the
    # two resampled recordings may differ slightly in length, so one may gain an extra
    # trailing window. Compare only the shared windows so the masks line up instead of
    # failing to broadcast.
    n_windows = min(high_fc_windows.shape[0], high_env_windows.shape[0])
    selected_windows = np.logical_and(high_env_windows[:n_windows], high_fc_windows[:n_windows])
    right_analytic_selected = rightSM_analytic[:n_windows][selected_windows]
    left_analytic_selected = leftSM_analytic[:n_windows][selected_windows]

    # ---- Part 3: phase lags --------------------------------------------------
    # Instantaneous phase difference (right - left), shape (n_selected, n_times)
    phase_diff = np.angle(right_analytic_selected) - np.angle(left_analytic_selected)

    # Unwrap along time WITHIN each window. The retained windows are not contiguous in
    # time, so unwrapping them as one concatenated series would treat the jumps at their
    # boundaries as phase wraps and add arbitrary multiples of 2*pi to later windows.
    # ravel() then concatenates the windows in order and also copes with zero retained
    # windows (np.concatenate would raise on an empty selection).
    phase_lags = np.unwrap(phase_diff, axis=1).ravel()

    # One record per retained sample
    data.extend((subject, task, run, phase_lag) for phase_lag in phase_lags)
    
# Assemble the collected (subject, task, run, phase lag) tuples into one long-format table
result_columns = ['Subject', 'Intervention', 'Time', 'Phase_lag']
datapd = pd.DataFrame(data, columns=result_columns)

# Write all results to a single CSV; the pandas index is written as the first column
results_dir = r"W:\Projects\2019-04 M1M1PAS Project\analysis\source_signals\fc_source"
new_filename = "sourcePhLag.csv"
new_file_path = os.path.join(results_dir, new_filename)
datapd.to_csv(new_file_path)


In [None]:

# Preview the first rows instead of rendering the whole (one row per sample) table
datapd.head()