In [None]:

import mne
import numpy as np
from scipy.signal import hilbert
from mne_connectivity import spectral_connectivity_time
from scipy.io import loadmat
import pandas as pd
import os

In [None]:
# Folder holding the per-recording MATLAB (.mat) source files; each one is
# read with scipy.io.loadmat in the processing loop.
data_folder = r"W:\Projects\2019-04 M1M1PAS Project\analysis\source_Paolo\MNE_light"

# Keep only .mat files: os.listdir also returns unrelated entries (OS metadata
# files, sub-folders, ...) that would make the file-name parsing or loadmat
# fail. Sorting makes the processing order reproducible across runs.
file_list = sorted(
    f for f in os.listdir(data_folder)
    if f.lower().endswith('.mat')
)

# Containers kept for compatibility with the rest of the notebook
# (datapd is overwritten with a DataFrame inside the processing loop).
datapd = []
data = []

# Loop through each file and import the DataFrame
# Settings shared by every recording (loop-invariant, so defined once).
fmin, fmax = 8., 13.  # alpha band limits in Hz
freqs = np.linspace(fmin, fmax, int((fmax - fmin) * 1 + 1))  # 1 Hz steps from fmin to fmax
output_folder = r"W:\Projects\2019-04 M1M1PAS Project\analysis\source_Paolo\fc_source\fc_source"

# For each recording: estimate per-window ciPLV connectivity, keep the windows
# whose connectivity and alpha envelope are above their 5th percentile, and
# save the sample-wise right-minus-left phase lag of those windows to CSV.
for file in file_list:
    # The first three underscore-separated fields of the file name are read as
    # sub-<subject>, task-<task> and run-<run>.
    name_parts = file.split('_')
    subject = name_parts[0].replace('sub-', '')
    task = name_parts[1].replace('task-', '')
    run = name_parts[2].replace('run-', '')

    file_path = os.path.join(data_folder, file)

    matlabdata = loadmat(file_path)

    leftSM = matlabdata['leftSM']
    rightSM = matlabdata['rightSM']

    # Wrap both signals in a two-channel MNE RawArray.
    # NOTE(review): the 1000 Hz sampling rate is hard-coded, not read from the file — confirm.
    info = mne.create_info(ch_names=['leftSM', 'rightSM'], sfreq=1000, ch_types='eeg')

    # One row per channel (assumes each signal is stored as a (1, n_samples) array).
    source_data = np.vstack([leftSM, rightSM])

    source_raw = mne.io.RawArray(source_data, info)

    source_raw.resample(sfreq=200)

    # Cut the broadband signal into 1 s windows
    epochs_raw = mne.make_fixed_length_epochs(source_raw, duration=1, preload=True)

    sfreq = epochs_raw.info['sfreq']

    # Per-window connectivity in the alpha band
    con = spectral_connectivity_time(
        epochs_raw,
        method=['ciplv'],
        sfreq=sfreq,
        freqs=freqs,
        fmin=fmin,
        fmax=fmax,
        faverage=True,  # Average connectivity across frequencies within the band
        mode='cwt_morlet',  # Morlet wavelet time-frequency decomposition
        n_cycles=5,
        average=False,  # Keep one value per window instead of averaging over windows
        n_jobs=1  # Number of parallel jobs to run (set according to your CPU)
    )
    # NOTE(review): assumes get_data() returns (n_windows, n_nodes**2, n_freqs);
    # index 2 would then be the (rightSM, leftSM) pair and index 0 the single
    # band-averaged frequency — confirm against the mne-connectivity docs.
    fc = con.get_data()[:,2,0]

    # Connectivity threshold: 5th percentile across all windows
    threshold_fc = np.percentile(fc, 5)

    # Keep windows whose connectivity exceeds that threshold
    high_fc_windows = fc > threshold_fc

    # Band-pass filter between 8 and 13 Hz (alpha band)
    raw_filtered = source_raw.copy().filter(l_freq=fmin, h_freq=fmax)

    # Analytic signal (Hilbert transform) of the band-passed data
    raw_analytic = raw_filtered.copy().apply_hilbert()

    # Same 1 s windowing as above, so the windows line up with `fc`
    epochs_analytic = mne.make_fixed_length_epochs(raw_analytic, duration=1, preload=True)

    # Drop the singleton channel axis explicitly. np.squeeze would also drop the
    # window axis when there is only one window and break the axis=1 reductions.
    rightSM_analytic = epochs_analytic.get_data(picks='rightSM')[:, 0, :]
    leftSM_analytic = epochs_analytic.get_data(picks='leftSM')[:, 0, :]

    # Signal envelopes are the magnitude of the analytic signal
    rightSM_envelope = np.abs(rightSM_analytic)
    leftSM_envelope = np.abs(leftSM_analytic)

    # Median envelope within each window
    right_median_win = np.median(rightSM_envelope, axis=1)
    left_median_win = np.median(leftSM_envelope, axis=1)

    # 5th percentile of the per-window medians, per channel
    right_threshold = np.percentile(right_median_win, 5)
    left_threshold = np.percentile(left_median_win, 5)

    # Use the lower of the two channel thresholds for both channels
    threshold_env = min(right_threshold, left_threshold)

    # Keep windows where both channels' envelope medians exceed the threshold
    high_env_windows = np.logical_and(left_median_win > threshold_env, right_median_win > threshold_env)

    selected_windows = np.logical_and(high_env_windows, high_fc_windows)

    right_analytic_selected = rightSM_analytic[selected_windows, :]
    left_analytic_selected = leftSM_analytic[selected_windows, :]

    # Join the selected windows end to end. reshape(-1) gives the same result as
    # concatenating the rows, but also works when no window was selected.
    right_analytic_selected_cat = right_analytic_selected.reshape(-1)
    left_analytic_selected_cat = left_analytic_selected.reshape(-1)

    # Extract instantaneous phases
    right_phase = np.angle(right_analytic_selected_cat)
    left_phase = np.angle(left_analytic_selected_cat)

    # Phase lag of rightSM relative to leftSM. The selected windows are not
    # necessarily contiguous in time, so the unwrapped values are only
    # meaningful modulo 2*pi.
    phase_lags = np.unwrap(right_phase - left_phase)

    if phase_lags.size == 0:
        # Nothing passed the thresholds (e.g. a single window or constant
        # values): report it and carry on with the next recording.
        print(f"No windows selected for {file}; skipping output.")
        continue

    # # Collect the data
    # data.extend([(subject, task, run, phase_lag) for phase_lag in final_phase_lags])

    # # Convert the list to a dataframe
    # datapd = pd.DataFrame(data, columns=['subject','task', 'run', 'phase_lag'])
    datapd = pd.DataFrame(phase_lags, columns=['phase_lag'])

    # Save the phase lags of this recording as a CSV file
    new_filename = f"{file.replace('_MNE.mat', '')}_phlags.csv"
    new_file_path = os.path.join(output_folder, new_filename)
    datapd.to_csv(new_file_path)


In [None]:
import matplotlib.pyplot as plt

# Plot the first 1000 samples of each raw signal (last recording processed).
# The loaded arrays are 2-D (they are stacked one row per channel with
# np.vstack), so flatten first and then slice; slicing the 2-D array with
# [0:1000] would select rows and plot the whole recording instead.
plt.plot(np.squeeze(leftSM)[0:1000])
plt.plot(np.squeeze(rightSM)[0:1000])
plt.title('Raw')

In [None]:
# First 200 samples of each channel after resampling (last recording processed)
left_resampled = source_raw.get_data(picks='leftSM')
right_resampled = source_raw.get_data(picks='rightSM')
for resampled_trace in (left_resampled, right_resampled):
    plt.plot(np.squeeze(resampled_trace[:, :200]))
plt.title('Downsampled')

In [None]:
# First 200 samples of each channel after the alpha band-pass filter
left_filtered = raw_filtered.get_data(picks='leftSM')
right_filtered = raw_filtered.get_data(picks='rightSM')
for filtered_trace in (left_filtered, right_filtered):
    plt.plot(np.squeeze(filtered_trace[:, :200]))
plt.title('BP-filtered')

In [None]:
# Real part of the analytic signal in the first window, one trace per channel
for analytic_windows in (leftSM_analytic, rightSM_analytic):
    first_window = analytic_windows[0, :]
    plt.plot(np.squeeze(np.real(first_window)))
plt.title('Hilberted')

In [None]:
# Wrap the phase lag into [0, 2*pi) in this cell so it runs on a fresh kernel
# (Restart & Run All) without relying on a variable created further down.
phase_lags_normalized = phase_lags % (2 * np.pi)

# First 200 samples of both phases and of their wrapped difference
plt.plot(np.squeeze(left_phase[0:200]))
plt.plot(np.squeeze(right_phase[0:200]))
plt.plot(np.squeeze(phase_lags_normalized[0:200]))
plt.title('Phases and phase lag')

In [None]:
# Histogram of the phase lags wrapped into [0, 2*pi)
full_circle = 2 * np.pi
phase_lags_normalized = np.mod(phase_lags, full_circle)
plt.hist(phase_lags_normalized)

In [None]:
# Polar histogram of the phase lags of the last recording processed
n_bins = 30

# Wrap the phase lags into [0, 2*pi)
phase_lag_normalized = np.mod(phase_lags, 2 * np.pi)

# Density-normalised histogram over the full circle
counts, bin_edges = np.histogram(
    phase_lag_normalized,
    bins=n_bins,
    range=(0, 2 * np.pi),
    density=True,
)

# Bars are drawn at the bin centres, each one bin wide
bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
width = np.diff(bin_edges)[0]

fig, axes = plt.subplots(1, 1, subplot_kw={'polar': True}, figsize=(12, 8))
ax = axes
ax.bar(bin_centers, counts, width=width, align='center')

# Title identifies the recording; 0 rad is placed at east (right-hand side)
# and angles increase clockwise
ax.set_title(f'Sub: {subject}, Task: {task}, Run: {run}', va='bottom')
ax.set_theta_zero_location('E')
ax.set_theta_direction(-1)

# Optional overlay of the circular mean (disabled; `stats` is not among the
# imports visible in this notebook, so it would need `from scipy import stats`):
# circular_mean = stats.circmean(phase_lag_normalized, high=2*np.pi, low=0)
# ax.plot([circular_mean, circular_mean], [0, max(counts)], color='red', linewidth=1)