## PAC

## SOURCE-LEVEL PAC AND STATS

**This script:**
1. Creates theta-gamma PAC comodulograms for each condition (all vertices) for each subject and saves the PAC data as a NumPy array
2. Runs a cluster-based permutation test on the PAC data

**OUTCOME: PAC estimates for all vertices for each subject and statistical assessment of the results**

In [82]:
import numpy as np
import matplotlib.pyplot as plt
import mne
import os
from utils import check_paths
import pandas as pd
from scipy.io import loadmat
import joblib
import matplotlib.gridspec as gridspec

from pactools import Comodulogram, REFERENCES, raw_to_mask

from mne.channels.layout import find_layout
from functools import partial
from mne.defaults import _handle_default

from mne.viz.topo import _erfimage_imshow_unified, _plot_topo

from mne.viz.utils import (
    _setup_vmin_vmax,
    add_background_image
)
from collections import namedtuple

%matplotlib qt

1. PAC analysis per condition per subject

In [83]:
# Root folder of the EEG data; subject folders are read from
# <eeg_data_dir>/<group>/.
eeg_data_dir = 'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set'
group = 'Y'
group_dir = os.path.join(eeg_data_dir, group)
# os.listdir() returns entries in arbitrary order and also lists plain files;
# keep only sub-directories and sort them so that subjects are processed in a
# reproducible order and stray files are never treated as subjects.
subs = sorted(
    entry for entry in os.listdir(group_dir)
    if os.path.isdir(os.path.join(group_dir, entry))
)
tasks = ['_MAIN'] # ['_BL', '_MAIN']
task_stages = ['_plan', '_go'] # ['_plan', '_go']
block_names = ['_baseline', '_adaptation'] # ['_baseline', '_adaptation']

# Frequency grids passed to the comodulogram estimator below.
theta_range = np.linspace(4, 8, 20)  # Phase: 4-8 Hz
gamma_range = np.linspace(30, 80, 20)  # Amplitude: 30-80 Hz


In [None]:
# Source-level PAC: for every subject / task stage / block, load the morphed
# source estimates, compute a theta-gamma comodulogram (Tort MI) per vertex and
# save the result as one .npy file per condition.
for sub_name in subs:
    print(f"Processing subject: {sub_name}")
    analysis_dir = os.path.join(eeg_data_dir, group, sub_name, 'preproc', 'analysis')
    pac_save_path = os.path.join(analysis_dir, 'source', 'PAC')
    check_paths(pac_save_path)

    for task in tasks:
        # Only the '_MAIN' task is processed here; any other task is skipped.
        if task != '_MAIN':
            continue

        for task_stage in task_stages:
            for block_name in block_names:
                stcs_path = os.path.join(analysis_dir, 'source', 'morphed_stcs', task, task_stage, block_name)

                stcs = []
                stcs_data = []

                # Sort the listing: os.listdir() order is arbitrary, and the
                # epoch order determines how the epochs are concatenated below.
                for stc_file in sorted(os.listdir(stcs_path)):
                    if stc_file.endswith('-rh.stc'): # MNE will load both hemispheres anyway
                        stc_path = os.path.join(stcs_path, stc_file)
                        stc = mne.read_source_estimate(stc_path, subject=sub_name)
                        stc.crop(tmin=0.0, tmax=0.495)
                        stcs.append(stc)
                        stcs_data.append(stc.data)

                # np.stack() raises on an empty list; skip conditions without data.
                if not stcs_data:
                    print(f"No source estimates found in {stcs_path}, skipping")
                    continue

                source_array = np.stack(stcs_data, axis=0)
                print(source_array.shape)  # (epochs x vertices x time)

                times = stcs[0].times
                n_vertices = source_array.shape[1]

                # Estimate PAC
                estimator = Comodulogram(
                    fs=stcs[0].sfreq,
                    low_fq_range=theta_range,  # Phase frequencies (theta)
                    high_fq_range=gamma_range, # Amplitude frequencies (gamma)
                    method='tort',
                    progress_bar=True
                    )

                # One (theta x gamma) comodulogram per vertex, stored as
                # (theta, vertex, gamma).
                pac_results = np.empty(
                    (len(theta_range), n_vertices, len(gamma_range))
                )

                for i in range(n_vertices):
                    print(f"Processing source {i+1}/{n_vertices}")

                    # Concatenate all epochs of this vertex into a single
                    # (1, n_epochs * n_times) signal.
                    # NOTE(review): concatenation introduces discontinuities at
                    # epoch boundaries — confirm this is intended.
                    data_flat = source_array[:, i].reshape(1, -1)
                    # The same signal provides both the phase and the amplitude.
                    pac = estimator.fit(
                            data_flat,
                            data_flat,
                        )
                    pac_results[:, i] = pac.comod_

                    if i < 10:
                        # Show the comodulograms of the first 10 sources as a
                        # quick visual check (displayed only, not written to disk).
                        fig = pac.plot(tight_layout=False, cmap='magma')
                        plt.title(f"PAC MI {sub_name[-5:]} - source={i}: {task}{task_stage}{block_name}")
                        plt.show()

                np.save(os.path.join(pac_save_path, f"PAC_MI_SOURCE_{sub_name[-5:]}{task}{task_stage}{block_name}.npy"), pac_results)

Processing subject: s1_pac_sub01
(50, 5124, 249)
Processing source 1/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 2/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 3/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 4/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 5/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 6/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 7/5124
[........................................] 100% | 0.17 sec | comodulogram: tort 
Processing source 8/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 9/5124
[........................................] 100% | 0.15 sec | comodulogram: tor

  amplitude_dist * np.log(amplitude_dist * n_bins))
  amplitude_dist * np.log(amplitude_dist * n_bins))


[........................................] 100% | 0.17 sec | comodulogram: tort 
Processing source 2434/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 2435/5124
[........................................] 100% | 0.13 sec | comodulogram: tort 
Processing source 2436/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 2437/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 2438/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 2439/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 2440/5124
[........................................] 100% | 0.14 sec | comodulogram: tort 
Processing source 2441/5124
[........................................] 100% | 0.15 sec | comodulogram: tort 
Processing source 2442/5124
[..................

_________________

In [44]:
# Inspect the vertex indices of the most recently loaded source estimate
# (`stc` is left over from the loading loop above); the cell output shows a
# list of two index arrays — presumably one per hemisphere.
stc.vertices

[array([   0,    1,    2, ..., 2559, 2560, 2561]),
 array([   0,    1,    2, ..., 2559, 2560, 2561])]

In [None]:
# Plot the time courses of the first 10 vertices of the first epoch.
# Relies on `times` and `source_array` left over from the most recent
# iteration of the source-level PAC loop above.
plt.figure(figsize=(10, 6))

# Guard against source spaces with fewer than 10 vertices.
for i in range(min(10, source_array.shape[1])):
    plt.plot(times, source_array[0, i], label=f'Vertex {i}')

plt.xlabel('Time (s)')
plt.ylabel('Source amplitude')
# The per-line labels are only visible once a legend is drawn.
plt.legend()
plt.show()


____________

In [None]:
# Sensor-level PAC: for every subject / task stage / block, compute a
# theta-gamma comodulogram (Tort MI) per EEG channel, save one figure per
# channel, the stacked results as .npy, and a summary figure combining the
# channel-by-channel similarity matrix with a topographic comodulogram plot.
for sub_name in subs:
    epochs_path = os.path.join(eeg_data_dir, group, sub_name, 'preproc', 'analysis')
    pac_save_path = os.path.join(eeg_data_dir, group, sub_name, 'pac_results')
    check_paths(pac_save_path)

    for task in tasks:
        # Only the '_MAIN' task is processed here; any other task is skipped.
        if task != '_MAIN':
            continue

        for task_stage in task_stages:
            for block_name in block_names:
                epochs = mne.read_epochs(os.path.join(epochs_path, f"{sub_name}{task}_epochs{task_stage}{block_name}-epo.fif"), preload=True)
                # Keep EEG channels only; picking in place avoids copying the
                # preloaded data just to obtain the channel names.
                epochs.pick("eeg")
                eeg_channel_names = epochs.ch_names
                # Crop epochs to the time window of interest
                epo_tmin = 0.0
                epo_tmax = 0.495 if task_stage == '_plan' else 0.695
                epochs.crop(tmin=epo_tmin, tmax=epo_tmax)
                # epochs.pick(choi)
                times = epochs.times

                # Estimate PAC
                estimator = Comodulogram(
                    fs=epochs.info['sfreq'],
                    low_fq_range=theta_range,  # Phase frequencies (theta)
                    high_fq_range=gamma_range, # Amplitude frequencies (gamma)
                    method='tort',
                    progress_bar=True
                    )

                all_channels_data = epochs.get_data()
                # One (theta x gamma) comodulogram per channel, stored as
                # (theta, channel, gamma).
                pac_results = np.empty(
                    (len(theta_range), all_channels_data.shape[1], len(gamma_range))
                )

                for i, chan in enumerate(eeg_channel_names): # enumerate(choi)
                    # Concatenate all epochs of this channel into a single
                    # (1, n_epochs * n_times) signal.
                    # NOTE(review): concatenation introduces discontinuities at
                    # epoch boundaries — confirm this is intended.
                    data_flat = all_channels_data[:, i].reshape(1, -1)
                    # The same signal provides both the phase and the amplitude.
                    pac = estimator.fit(
                            data_flat,
                            data_flat,
                        )
                    pac_results[:, i] = pac.comod_

                    # Plot the comodulogram of this channel
                    fig = pac.plot(tight_layout=False, cmap='magma')
                    plt.title(f"PAC MI {sub_name[-5:]} {chan}: {task}{task_stage}{block_name}")

                    # Save the plot, then close it so figures do not pile up
                    plt.savefig(os.path.join(pac_save_path, f"pac_mi_{sub_name[-5:]}{task}{task_stage}{block_name}_{chan}.png"), dpi=300, bbox_inches="tight")
                    plt.show()
                    plt.close()

                np.save(os.path.join(pac_save_path, f"pac_mi_TOPO_{sub_name[-5:]}{task}{task_stage}{block_name}.npy"), pac_results)

                # create TOPOPLOT and CORRELATION MATRIX
                # Pairwise similarity between the comodulograms of all channels.
                n_channels = pac_results.shape[1]
                sim = np.empty((n_channels, n_channels))
                for i in range(n_channels):
                    for j in range(n_channels):
                        sim[i, j] = cosine_similarity_matrix(pac_results[:, i], pac_results[:, j])

                fig = plt.figure(figsize=(20, 15))
                gs = gridspec.GridSpec(8, 8, figure=fig)

                # Create axes with custom positions on the 8x8 grid
                ax1 = fig.add_subplot(gs[2:, :2])  # similarity matrix: columns 0-1
                # NOTE(review): ax2 starts at column 1 and therefore overlaps
                # ax1 — confirm this is intended (gs[2:, 2:] would not overlap).
                ax2 = fig.add_subplot(gs[2:, 1:])  # topographic plot: columns 1-7


                ax1.imshow(sim, cmap='magma')
                ax1.set_xticks(np.arange(len(sim)))
                ax1.set_yticks(np.arange(len(sim)))
                ax1.set_xticklabels(
                    epochs.ch_names,
                    rotation=45, ha="right",
                    fontdict={'fontsize': 6}
                    )  # Rotate for readability
                ax1.set_yticklabels(epochs.ch_names, fontdict={'fontsize': 8})

                fig = plot_comodulogram(
                    pac_results,
                    theta_range,
                    gamma_range,
                    epochs.info,
                    layout_scale=1,
                    ax=ax2,
                    cmap='magma'
                )

                plt.savefig(os.path.join(pac_save_path, f"pac_mi_TOPO_{sub_name[-5:]}{task}{task_stage}{block_name}.png"), dpi=300, bbox_inches="tight")
                plt.show()
                # Close the summary figure as well (it has been saved above) so
                # that figures do not accumulate across conditions and subjects.
                plt.close(ax1.figure)