In [None]:
import os
import numpy as np
import scipy.io
import mne
from mne.time_frequency import tfr_multitaper
from sklearn.utils import resample
import matplotlib.pyplot as plt

In [None]:
# Set directories.
# The data directory can be overridden with the EEG_DATA_DIR environment variable so the
# notebook is not tied to one machine. A raw string keeps the Windows backslashes literal
# (the plain string relied on invalid escape sequences such as '\R', which warn on Python 3.12+).
data_directory = os.environ.get(
    'EEG_DATA_DIR', r'E:\Research\EEG_Data\preprocessed_EEG_data\pre_iso_ano'
)
# os.listdir returns entries in arbitrary order; sort so subjects are always processed
# (and appended to data_list) in the same order.
files = sorted(
    f for f in os.listdir(data_directory)
    if f.startswith('data_final') and f.endswith('.mat')
)
nsubj = len(files)
if nsubj == 0:
    raise FileNotFoundError(f"No 'data_final*.mat' files found in {data_directory!r}")
data_list = []  # filled by the per-subject loop: one tuple of per-condition ITC arrays per subject

def load_data(file_path, var_name='data_final'):
    """Load one subject's MATLAB struct variable from a ``.mat`` file.

    Parameters
    ----------
    file_path : str or path-like
        Path to the ``.mat`` file.
    var_name : str, optional
        Name of the MATLAB variable to read (default ``'data_final'``).

    Returns
    -------
    numpy.void or numpy.ndarray
        For a scalar MATLAB struct, its single record, whose fields are read by
        name (e.g. ``record['trialinfo']``). ``scipy.io.loadmat`` wraps such a
        struct in a (1, 1) structured array; indexing a field on that wrapper
        gives a (1, 1) object array instead of the field's contents, so the
        record is unwrapped here. Any other variable is returned unchanged.

    Raises
    ------
    KeyError
        If the file contains no variable named ``var_name``.
    """
    mat = scipy.io.loadmat(file_path)
    if var_name not in mat:
        available = sorted(key for key in mat if not key.startswith('__'))
        raise KeyError(f'{var_name!r} not found in {file_path!r}; variables present: {available}')
    value = mat[var_name]
    if value.dtype.names is not None and value.size == 1:
        return value.flat[0]
    return value

# Function to compute ITC
def compute_itc(tfr_data):
    """Compute ITPCz from complex per-trial time-frequency coefficients.

    ITC = |sum over trials of F / |F|| / n_trials, and the returned value is
    ITPCz = n_trials * ITC**2.

    Parameters
    ----------
    tfr_data : numpy.ndarray or object with a ``.data`` array
        Complex per-trial coefficients with trials on axis 0, e.g. shape
        (n_epochs, n_channels, n_freqs, n_times).

    Returns
    -------
    numpy.ndarray
        ITPCz, with the trial axis removed.

    Raises
    ------
    TypeError
        If the coefficients are real-valued (e.g. power or log-ratio power).
        Real data carry no phase, so ``F / |F|`` collapses to +/-1 and the
        result would be meaningless.
    ValueError
        If there are no trials.
    """
    coeffs = np.asarray(tfr_data if isinstance(tfr_data, np.ndarray) else tfr_data.data)
    if not np.iscomplexobj(coeffs):
        raise TypeError(
            'compute_itc needs complex per-trial coefficients; got real-valued data (e.g. power)'
        )
    if coeffs.ndim == 0 or coeffs.shape[0] == 0:
        raise ValueError('compute_itc needs at least one trial along axis 0')
    n_trials = coeffs.shape[0]
    magnitude = np.abs(coeffs)
    # A zero-magnitude coefficient has no defined phase; count it as a zero vector
    # rather than letting 0/0 turn the whole result into NaN.
    unit_phase = np.divide(coeffs, magnitude, out=np.zeros_like(coeffs), where=magnitude > 0)
    itc = np.abs(unit_phase.sum(axis=0)) / n_trials
    return itc ** 2 * n_trials

In [None]:
# Process each subject's data.
# (Kept in one cell: the loop body previously spanned three cells, and a cell that
# starts with indented code fails on Restart & Run All.)
#
# Analysis parameters are shared by every subject and condition, so they are defined once.
sfreq = 1000  # Assuming a sampling frequency of 1000 Hz, as before - TODO read it from the file if stored there
epoch_tmin = -3.0  # s; ASSUMPTION: epochs start at -3 s, the earliest point the baseline below needs - TODO confirm
freqs = np.arange(0.25, 8.25, 0.25)  # 0.25-8 Hz in 0.25 Hz steps
# Three cycles per frequency. The previous `3 / freqs` meant 12 cycles (a 48 s window) at
# 0.25 Hz and fewer than one cycle above 3 Hz. It presumably mistranslated FieldTrip's
# `t_ftimwin = 3./foi` (window length in seconds), which corresponds to n_cycles = 3 in MNE.
# NOTE: 3 cycles at 0.25 Hz still span 12 s, so epochs must be at least that long.
n_cycles = 3.0
time_bandwidth = 2.0
baseline = (-3, 0)

# Exclusive (lower, upper) bounds on the trial code in trialinfo[:, 0] for each condition.
condition_code_bounds = {
    'uni_iso': (-np.inf, 31),
    'non_iso': (30, 61),
    'uni_ana': (60, 91),
    'non_ana': (90, 121),
}


def _unwrap_struct(struct):
    """Return the single record of a scalar MATLAB struct loaded by ``scipy.io.loadmat``.

    ``loadmat`` wraps a scalar struct in a (1, 1) structured array, on which field
    indexing yields (1, 1) object arrays. Anything else is returned unchanged.
    """
    if isinstance(struct, np.ndarray) and struct.dtype.names is not None and struct.size == 1:
        return struct.flat[0]
    return struct


def _condition_epochs(subject_data, trial_idx, info, description):
    """Build an ``mne.EpochsArray`` holding the selected trials of one subject.

    Trials are epochs with a pre-stimulus baseline, not a continuous recording,
    so they need an Epochs object with a negative ``tmin``; ``mne.io.RawArray``
    cannot represent them and cannot be baseline-corrected over (-3, 0).
    """
    if trial_idx.size == 0:
        raise ValueError(f'No trials found for {description}')
    # ASSUMPTION: 'data' is laid out as (n_channels, n_trials, n_times), which is what
    # indexing trials on axis 1 implies - TODO confirm against the .mat files.
    selected = np.asarray(subject_data['data'])[:, trial_idx]
    if selected.ndim != 3:
        raise ValueError(
            f"Expected 'data' to be (n_channels, n_trials, n_times) for {description}; "
            f'selecting trials gave shape {selected.shape}'
        )
    return mne.EpochsArray(np.moveaxis(selected, 1, 0), info, tmin=epoch_tmin, verbose=False)


def _condition_tfr(epochs):
    """Return (logratio-baselined power TFR, ITPCz array) for one condition.

    ITC depends on per-trial phase, so it comes from ``tfr_multitaper(return_itc=True)``.
    It cannot be derived from trial-averaged, baseline-corrected power.
    ITPCz = n_trials * ITC**2.
    """
    power, itc = tfr_multitaper(
        epochs, freqs=freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth,
        return_itc=True, verbose=False,
    )
    power.apply_baseline(baseline, mode='logratio')
    return power, len(epochs) * itc.data ** 2


for file in files:
    data_final = _unwrap_struct(load_data(os.path.join(data_directory, file)))

    # Trial codes are in the first trialinfo column (tolerate a squeezed 1-D trialinfo).
    trialinfo = np.asarray(data_final['trialinfo'])
    trial_codes = trialinfo[:, 0] if trialinfo.ndim > 1 else trialinfo

    # Define channel names and create the MNE Info structure. Labels presumably come
    # from a MATLAB cell array; raveling accepts either row or column orientation.
    ch_names = [str(np.ravel(ch)[0]) for ch in np.ravel(data_final['label'])]
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')

    condition_results = {}
    for condition, (low, high) in condition_code_bounds.items():
        trial_idx = np.flatnonzero((trial_codes > low) & (trial_codes < high))
        epochs = _condition_epochs(data_final, trial_idx, info, f'condition {condition!r} in {file}')
        condition_results[condition] = _condition_tfr(epochs)

    uni_iso_tfr, uni_iso_itc = condition_results['uni_iso']
    non_iso_tfr, non_iso_itc = condition_results['non_iso']
    uni_ana_tfr, uni_ana_itc = condition_results['uni_ana']
    non_ana_tfr, non_ana_itc = condition_results['non_ana']

    # Save ITPCz for each condition, in the order (uni_iso, non_iso, uni_ana, non_ana).
    data_list.append((uni_iso_itc, non_iso_itc, uni_ana_itc, non_ana_itc))

# Compute the grand average ITC
def grand_average(itc_list):
    """Average per-subject arrays element-wise across subjects.

    Parameters
    ----------
    itc_list : iterable of array_like
        One array per subject; all must have the same shape.

    Returns
    -------
    numpy.ndarray
        Element-wise mean over subjects, shaped like a single input array.

    Raises
    ------
    ValueError
        If ``itc_list`` is empty (the mean would otherwise silently be NaN),
        or if the arrays differ in shape.
    """
    subject_arrays = [np.asarray(itc) for itc in itc_list]
    if not subject_arrays:
        raise ValueError('grand_average needs at least one subject array')
    # np.stack raises a clear error on mismatched shapes instead of building a ragged array.
    return np.mean(np.stack(subject_arrays), axis=0)

# Grand-average each condition across subjects; every data_list entry is ordered
# (uni_iso, non_iso, uni_ana, non_ana).
uni_iso_itc_avg, non_iso_itc_avg, uni_ana_itc_avg, non_ana_itc_avg = [
    grand_average([subject_itcs[cond_idx] for subject_itcs in data_list])
    for cond_idx in range(4)
]

In [None]:
# Statistical analysis (permutation test)
def permutation_test(data1, data2, n_permutations=2000, alpha=0.05, paired=False, random_state=None):
    """Two-sided permutation test for a difference in means between two conditions.

    Parameters
    ----------
    data1, data2 : array_like
        Observations for each condition, with observations (e.g. subjects) along
        axis 0. Each group's mean is taken over all of its elements.
    n_permutations : int, optional
        Number of random permutations (default 2000).
    alpha : float, optional
        Significance level. Not used in the computation; kept for backward
        compatibility - compare the returned p-value against it.
    paired : bool, optional
        If True, ``data1[i]`` and ``data2[i]`` are the same observation (e.g. one
        subject) in two conditions, and the null distribution flips the sign of each
        pair's difference at random. If False (default), observations are shuffled
        between the two groups.
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator, for a reproducible p-value.

    Returns
    -------
    observed_diff : float
        ``mean(data1) - mean(data2)``.
    p_value : float
        ``(k + 1) / (n_permutations + 1)``, where ``k`` counts permutations whose
        absolute difference is at least the observed one. Counting the observed
        labelling keeps the estimate from ever being exactly 0.

    Raises
    ------
    ValueError
        If either sample is empty, ``n_permutations < 1``, or a paired test gets
        samples of different shapes.
    """
    sample1 = np.atleast_1d(np.asarray(data1, dtype=float))
    sample2 = np.atleast_1d(np.asarray(data2, dtype=float))
    if sample1.size == 0 or sample2.size == 0:
        raise ValueError('permutation_test needs at least one observation per condition')
    if n_permutations < 1:
        raise ValueError('n_permutations must be at least 1')
    rng = np.random.default_rng(random_state)
    observed_diff = sample1.mean() - sample2.mean()

    if paired:
        if sample1.shape != sample2.shape:
            raise ValueError(
                f'paired test needs equally shaped samples; got {sample1.shape} and {sample2.shape}'
            )
        # With equal shapes, the difference of overall means equals the mean of the
        # per-observation mean differences, so each permutation needs one sign per observation.
        pair_diffs = (sample1 - sample2).reshape(len(sample1), -1).mean(axis=1)
        signs = rng.choice(np.array([-1.0, 1.0]), size=(n_permutations, len(pair_diffs)))
        null_diffs = signs @ pair_diffs / len(pair_diffs)
    else:
        combined = np.concatenate((sample1, sample2))
        n_first = len(sample1)
        null_diffs = np.empty(n_permutations)
        for i in range(n_permutations):
            shuffled = combined[rng.permutation(len(combined))]
            null_diffs[i] = shuffled[:n_first].mean() - shuffled[n_first:].mean()

    # Small tolerance so permutations that tie with the observed statistic are not lost to rounding.
    tolerance = 1e-12 * max(abs(observed_diff), 1.0)
    n_extreme = np.count_nonzero(np.abs(null_diffs) >= abs(observed_diff) - tolerance)
    p_value = (n_extreme + 1) / (n_permutations + 1)
    return float(observed_diff), float(p_value)

# Run statistical test.
# Subjects are the unit of observation: each contributes its mean ITC per condition, and
# since every subject has both conditions the test is paired. The grand-average arrays are
# not tested directly, because permuting their elements would treat channels and
# time-frequency points as independent samples.
uni_iso_subject_means = np.array([subject_itcs[0].mean() for subject_itcs in data_list])
non_iso_subject_means = np.array([subject_itcs[1].mean() for subject_itcs in data_list])
observed_diff, p_value = permutation_test(
    uni_iso_subject_means, non_iso_subject_means, paired=True, random_state=0  # fixed seed: reproducible p
)
print(f'Observed difference: {observed_diff:.4f}, p-value: {p_value:.4f}')

In [None]:
# Visualize ITC data
def plot_itc(itc_data, freqs, times, title, picks=None, cbar_label='ITC'):
    """Plot a frequency-by-time ITC map as a heat map.

    Parameters
    ----------
    itc_data : array_like
        A 2-D (n_freqs, n_times) map, or a 3-D array such as
        (n_channels, n_freqs, n_times) whose first axis is averaged before
        plotting, since ``imshow`` cannot draw such 3-D data directly.
    freqs, times : array_like
        Frequency (Hz) and time (s) coordinates; must match the map's two axes.
    title : str
        Figure title.
    picks : int or sequence of int, optional
        Indices along the first axis of 3-D input to average (default: all).
    cbar_label : str, optional
        Colour-bar label (default ``'ITC'``).

    Raises
    ------
    ValueError
        If the map is not 2-D after reduction, or its shape is not
        ``(len(freqs), len(times))``.
    """
    itc_map = np.asarray(itc_data)
    if itc_map.ndim == 3:
        selected = itc_map if picks is None else itc_map[np.atleast_1d(picks)]
        itc_map = selected.mean(axis=0)
    if itc_map.ndim != 2 or itc_map.shape != (len(freqs), len(times)):
        raise ValueError(
            f'itc_data must reduce to shape (len(freqs), len(times)) = '
            f'{(len(freqs), len(times))}; got {itc_map.shape}'
        )
    fig, ax = plt.subplots(figsize=(10, 5))
    image = ax.imshow(itc_map, aspect='auto', extent=[times[0], times[-1], freqs[0], freqs[-1]],
                      origin='lower', cmap='viridis')
    fig.colorbar(image, ax=ax, label=cbar_label)
    ax.set_title(title)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    plt.show()

# Plot the grand-average ITC map of each condition; the time axis is taken from the
# matching condition's TFR object left in scope by the per-subject loop.
plot_itc(itc_data=uni_iso_itc_avg, freqs=freqs, times=uni_iso_tfr.times,
         title='ITC for Uni Iso Trials')
plot_itc(itc_data=non_iso_itc_avg, freqs=freqs, times=non_iso_tfr.times,
         title='ITC for Non Iso Trials')
plot_itc(itc_data=uni_ana_itc_avg, freqs=freqs, times=uni_ana_tfr.times,
         title='ITC for Uni Ana Trials')
plot_itc(itc_data=non_ana_itc_avg, freqs=freqs, times=non_ana_tfr.times,
         title='ITC for Non Ana Trials')