# PAC MAIN

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import mne
import os
from utils import check_paths
import pandas as pd
from scipy.io import loadmat
import joblib
import matplotlib.gridspec as gridspec

from pactools import Comodulogram, REFERENCES, raw_to_mask

from mne.channels.layout import find_layout
from functools import partial
from mne.defaults import _handle_default

from mne.viz.topo import _erfimage_imshow_unified, _plot_topo

from mne.viz.utils import (
    _setup_vmin_vmax,
    add_background_image
)
from collections import namedtuple

%matplotlib qt

In [4]:
def _erfimage_imshow(
    ax,
    ch_idx,
    tmin,
    tmax,
    vmin,
    vmax,
    low_fq_range,
    high_fq_range,
    ylim=None,
    data=None,
    epochs=None,
    sigma=None,
    order=None,
    scalings=None,
    vline=None,
    x_label=None,
    y_label=None,
    colorbar=False,
    cmap="RdBu_r",
    vlim_array=None,
    n_ticks=6,
):
    """Draw a single channel's comodulogram as an image on ``ax``.

    Used as the click callback of the sensor-topography plot: it renders the
    2-D array ``data[:, ch_idx, :]`` with rows on the y axis (labelled with
    ``high_fq_range``) and columns on the x axis (labelled with
    ``low_fq_range``).

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into. The image, ticks, labels and colorbar all go to
        this axes (never to ``plt.gca()``).
    ch_idx : int
        Channel index into the second axis of ``data``.
    tmin, tmax, vmin, vmax, ylim, sigma, scalings, vline, vlim_array
        Accepted so the signature matches the callback MNE's topo plot
        invokes; they are not used. The colour scale is always
        ``[0, 0.95 * max of this channel]``.
    low_fq_range : array-like
        Frequencies of the image columns (x axis). Tick labels are linearly
        interpolated between the first and last value, so an evenly spaced
        grid (e.g. ``np.linspace``) is assumed.
    high_fq_range : array-like
        Frequencies of the image rows (y axis); same assumption as above.
    data : ndarray
        3-D array; ``data[:, ch_idx, :]`` is the image for this channel.
    epochs : object, optional
        Only used when ``order`` is callable (its ``.times`` is passed on).
    order : None, index array or callable, optional
        Row reordering applied before drawing; a callable is invoked as
        ``order(epochs.times, channel_image)`` and must return the indices.
    x_label, y_label : str, optional
        Axis labels.
    colorbar : bool
        Whether to attach a colorbar to ``ax``.
    cmap : str
        Matplotlib colormap name.
    n_ticks : int
        Maximum number of ticks per axis (capped by the number of pixels).
    """
    this_data = data[:, ch_idx, :]

    if callable(order):
        order = order(epochs.times, this_data)
    if order is not None:
        this_data = this_data[order]

    # Per-channel colour scale so weak channels remain readable when zoomed.
    img = ax.imshow(
        this_data,
        aspect="auto",
        origin="lower",
        vmin=0,
        vmax=0.95 * this_data.max(),
        picker=True,
        cmap=cmap,
        interpolation="nearest",
    )

    # Rows run along y and columns along x. Ticks sit on pixel centres:
    # pixel k is centred on fq_range[k], so the first and last ticks carry
    # exactly the first and last frequency.
    n_rows, n_cols = this_data.shape
    for set_ticks, set_labels, n_pixels, fq_range in (
        (ax.set_xticks, ax.set_xticklabels, n_cols, low_fq_range),
        (ax.set_yticks, ax.set_yticklabels, n_rows, high_fq_range),
    ):
        n_axis_ticks = min(n_ticks, n_pixels)
        set_ticks(np.linspace(0, n_pixels - 1, n_axis_ticks))
        set_labels(
            [f'{fq:.1f}' for fq in np.linspace(fq_range[0], fq_range[-1], n_axis_ticks)]
        )

    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if colorbar:
        ax.figure.colorbar(img, ax=ax)


def plot_comodulogram(
        data: np.ndarray,
        low_fq_range: np.ndarray,
        high_fq_range: np.ndarray,
        info: mne.Info,
        ch_type: str = 'eeg',
        scalings: dict[str, float] = None,
        title: str = None,
        colorbar: bool = True,
        vmin: float = None,
        vmax: float = None,
        cmap: str = 'viridis',
        ax: plt.Axes = None,
        layout_scale: float = 1,
        fig_facecolor: str = 'w',
        font_color: str = 'k',
        fig_background: plt.Figure = None
) -> plt.Figure:
    """Plot one comodulogram per sensor, arranged on the sensor layout.

    Each sensor tile shows that channel's comodulogram. ``_erfimage_imshow``
    is registered as the click callback, which draws a larger labelled
    version of the clicked channel.

    Parameters
    ----------
    data : ndarray
        3-D PAC array. It is transposed with axes ``(2, 1, 0)`` so that
        channels stay on the middle axis; the notebook cells build it as
        ``(n_low_fq, n_channels, n_high_fq)``.
    low_fq_range, high_fq_range : ndarray
        Phase and amplitude frequency grids used for the click-view tick labels.
    info : mne.Info
        Measurement info used to find the sensor layout.
    ch_type : str
        Channel type whose entry in ``scalings`` is forwarded to the image helpers.
    scalings : dict, optional
        Scaling factors, completed with MNE defaults.
    title : str, optional
        Figure title.
    colorbar : bool
        Whether ``_plot_topo`` draws a colorbar.
    vmin, vmax : float, optional
        Colour limits. ``vmin`` defaults to 0; ``vmax`` is derived from the
        data by ``_setup_vmin_vmax`` when not given.
    cmap : str
        Matplotlib colormap name.
    ax : matplotlib Axes, optional
        Axes to draw the topography into, forwarded to ``_plot_topo``.
    layout_scale, fig_facecolor, font_color
        Forwarded to ``_plot_topo``.
    fig_background : optional
        Passed to ``add_background_image``.

    Returns
    -------
    fig : matplotlib Figure
        The figure returned by ``_plot_topo``.
    """
    # (a, ch, b) -> (b, ch, a): data[:, ch, :] becomes the per-channel image.
    data = np.transpose(data, (2, 1, 0))
    layout = find_layout(info)
    scalings = _handle_default("scalings", scalings)
    scale_coeffs = [scalings.get(ch_type, 1)]

    # Anchor the colour scale at 0 unless the caller explicitly chose a vmin
    # (previously a user-supplied vmin was silently replaced by 0).
    vmin_requested = vmin
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
    if vmin_requested is None:
        vmin = 0

    # Minimal stand-in for the Epochs object MNE's image helpers expect;
    # only an ``events`` attribute (one entry) is provided.
    epochs_plug = namedtuple('EpochsPlug', ['events'])([None])

    show_func = partial(
        _erfimage_imshow_unified,
        scalings=scale_coeffs,
        data=data,
        epochs=epochs_plug,
        sigma=0,
        cmap=cmap
    )
    click_func = partial(
        _erfimage_imshow,
        low_fq_range=low_fq_range,
        high_fq_range=high_fq_range,
        vmin=None, vmax=None,
        scalings=scale_coeffs,
        data=data,
        epochs=epochs_plug,
        sigma=0,
        cmap=cmap,
        vlim_array=None,
        colorbar=True,
    )

    fig = _plot_topo(
        info=info,
        times=[0, 1],  # dummy time axis; the images are not time-resolved
        show_func=show_func,
        click_func=click_func,
        layout=layout,
        colorbar=colorbar,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        layout_scale=layout_scale,
        title=title,
        fig_facecolor=fig_facecolor,
        font_color=font_color,
        unified=True,
        img=True,
        axes=ax
    )
    add_background_image(fig, fig_background)

    return fig


def cosine_similarity_matrix(A: np.ndarray, B: np.ndarray) -> float:
    """Return the cosine similarity of two arrays treated as flat vectors.

    Both inputs are flattened (any shape, equal number of elements) and
    compared as ``<a, b> / (|a| * |b|)``. If either vector has zero norm,
    0 is returned instead of dividing by zero.
    """
    a_vec, b_vec = A.ravel(), B.ravel()
    norm_product = np.linalg.norm(a_vec) * np.linalg.norm(b_vec)
    if norm_product == 0:
        return 0
    return np.dot(a_vec, b_vec) / norm_product

# INDIVIDUAL SUBS

In [11]:
eeg_data_dir = 'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set'
group = 'O'
# One sub-directory per subject. Sort for a reproducible processing order and
# skip stray files (which would otherwise break the per-subject loop when it
# looks for '<subject>/preproc/analysis').
subs = sorted(
    name for name in os.listdir(os.path.join(eeg_data_dir, group))
    if os.path.isdir(os.path.join(eeg_data_dir, group, name))
)
tasks = ['_MAIN']  # options: ['_BL', '_MAIN']
task_stages = ['_plan', '_go']  # '_plan' and/or '_go'
block_names = ['_baseline', '_adaptation']  # '_baseline' and/or '_adaptation'


In [6]:
# Optional channel subset (currently disabled) - frontal/sensorimotor sensors:
# choi = ['Fp1', 'Fp2',
#         'F1', 'F2', 'F3', 'F4', 'F5', 'F6',
#         'AF3', 'AF4', 'AF7', 'AF8',
#         'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6',
#         'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
#         'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6']

# Comodulogram frequency grids: 20 evenly spaced bins per axis.
theta_range = np.linspace(start=4, stop=8, num=20)    # phase-providing band, 4-8 Hz
gamma_range = np.linspace(start=30, stop=80, num=20)  # amplitude-providing band, 30-80 Hz

In [None]:
for sub_name in subs:
    epochs_path = os.path.join(eeg_data_dir, group, sub_name, 'preproc', 'analysis')
    pac_save_path = os.path.join(eeg_data_dir, group, sub_name, 'pac_results')
    check_paths(pac_save_path)
    sub_id = sub_name[-5:]  # short subject ID used in plot titles and file names

    for task in tasks:
        if task != '_MAIN':
            # Only the MAIN task is analysed in this cell.
            continue

        for task_stage in task_stages:
            # Post-stimulus window (s) analysed for this stage.
            epo_tmin = 0.0
            epo_tmax = 0.495 if task_stage == '_plan' else 0.695

            for block_name in block_names:
                condition = f"{task}{task_stage}{block_name}"

                epochs = mne.read_epochs(
                    os.path.join(epochs_path, f"{sub_name}{task}_epochs{task_stage}{block_name}-epo.fif"),
                    preload=True,
                )
                #! Change baseline for prestim period
                epochs.apply_baseline(baseline=(-0.5, -0.001))
                # Keep EEG channels only; picking in place avoids copying the data.
                epochs.pick("eeg")
                eeg_channel_names = list(epochs.ch_names)
                # Crop epochs to the time window of interest
                epochs.crop(tmin=epo_tmin, tmax=epo_tmax)
                # epochs.pick(choi)
                times = epochs.times

                # Phase-amplitude coupling (Tort MI): theta phase x gamma amplitude.
                estimator = Comodulogram(
                    fs=epochs.info['sfreq'],
                    low_fq_range=theta_range,
                    high_fq_range=gamma_range,
                    method='tort',
                    progress_bar=True,
                )

                all_channels_data = epochs.get_data()
                # Layout: (n_phase_freqs, n_channels, n_amplitude_freqs).
                pac_results = np.empty(
                    (len(theta_range), all_channels_data.shape[1], len(gamma_range))
                )

                # Label channels from epochs.ch_names so names always match the
                # data columns, even if an extra channel pick is enabled above.
                for i, chan in enumerate(epochs.ch_names):
                    # All epochs of this channel are concatenated into one signal.
                    # NOTE(review): epoch boundaries become discontinuities in the
                    # filtered signal - confirm this is acceptable for the MI.
                    data_flat = all_channels_data[:, i].reshape(1, -1)
                    pac = estimator.fit(data_flat, data_flat)
                    pac_results[:, i] = pac.comod_

                    pac.plot(tight_layout=False, cmap='magma')
                    plt.title(f"PAC MI {sub_id} {chan}: {condition}")
                    plt.savefig(os.path.join(pac_save_path, f"pac_mi_{sub_id}{condition}_{chan}.png"), dpi=300, bbox_inches="tight")
                    plt.show()
                    plt.close()

                np.save(os.path.join(pac_save_path, f"pac_mi_TOPO_{sub_id}{condition}.npy"), pac_results)

                # Channel-by-channel similarity of the comodulograms.
                n_channels = pac_results.shape[1]
                sim = np.empty((n_channels, n_channels))
                for i in range(n_channels):
                    for j in range(n_channels):
                        sim[i, j] = cosine_similarity_matrix(pac_results[:, i], pac_results[:, j])

                fig = plt.figure(figsize=(20, 15))
                gs = gridspec.GridSpec(8, 8, figure=fig)
                # Similarity matrix in the two left-most columns, sensor topography
                # in the remaining six (non-overlapping); the top two rows stay empty.
                ax1 = fig.add_subplot(gs[2:, :2])
                ax2 = fig.add_subplot(gs[2:, 2:])

                ax1.imshow(sim, cmap='magma')
                ax1.set_xticks(np.arange(len(sim)))
                ax1.set_yticks(np.arange(len(sim)))
                ax1.set_xticklabels(
                    epochs.ch_names,
                    rotation=45, ha="right",  # rotate for readability
                    fontdict={'fontsize': 6}
                )
                ax1.set_yticklabels(epochs.ch_names, fontdict={'fontsize': 8})

                fig = plot_comodulogram(
                    pac_results,
                    theta_range,
                    gamma_range,
                    epochs.info,
                    layout_scale=1,
                    ax=ax2,
                    cmap='magma'
                )

                plt.savefig(os.path.join(pac_save_path, f"pac_mi_TOPO_{sub_id}{condition}.png"), dpi=300, bbox_inches="tight")
                plt.show()

____________________________

Scratch area (ad-hoc checks)

CHECKS

In [11]:
# Reload the last saved per-subject PAC array and inspect its shape.
saved_topo_file = os.path.join(
    pac_save_path,
    f"pac_mi_TOPO_{sub_name[-5:]}{task}{task_stage}{block_name}.npy",
)
a = np.load(saved_topo_file)
a.shape

(20, 30, 20)

In [None]:
# Display the PAC array of the last processed condition,
# laid out as (n_phase_freqs, n_channels, n_amplitude_freqs).
pac_results

array([[[2.32090225e-05, 2.13349209e-05, 3.23033359e-05, ...,
         4.42917478e-05, 4.21747327e-05, 4.31077835e-05],
        [3.45413623e-05, 3.21738261e-05, 2.65375498e-05, ...,
         8.49398789e-05, 7.12030111e-05, 6.26488199e-05],
        [4.04992213e-05, 5.62466212e-05, 6.58383975e-05, ...,
         5.49672044e-05, 5.25495663e-05, 5.24626994e-05],
        ...,
        [5.06771439e-05, 5.61914250e-05, 5.99135174e-05, ...,
         1.11445139e-04, 1.21047797e-04, 1.17070554e-04],
        [5.82659791e-05, 5.30875374e-05, 4.43093153e-05, ...,
         6.55055594e-05, 5.13592151e-05, 4.43226648e-05],
        [4.92264334e-05, 4.54157148e-05, 3.64349820e-05, ...,
         9.76887490e-05, 8.64158398e-05, 7.88152448e-05]],

       [[2.45604313e-05, 2.88280162e-05, 4.19383089e-05, ...,
         3.62134052e-05, 3.91387239e-05, 4.38062768e-05],
        [2.78919565e-05, 3.35903499e-05, 3.33803387e-05, ...,
         4.87981507e-05, 4.32171190e-05, 4.41916531e-05],
        [5.23699596e-05, 

: 

____________________
SKIP - not needed for the subsequent analysis

ALL EPOCHS FROM ALL SUBS

In [None]:
# ALL EPOCHS ALL SUBS
eeg_data_dir = 'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set'
group = 'O'
sub_name = 'ALL_subs'
task = '_MAIN'  # options: '_BL', '_MAIN'
task_stage = '_plan'  # '_plan' or '_go'
block_name = '_baseline'  # '_baseline' or '_adaptation'
condition = f"{task}{task_stage}{block_name}"

group_save_path = os.path.join(eeg_data_dir, f'{group} group')
epochs_all_subs = mne.read_epochs(os.path.join(group_save_path, f"{group}_{sub_name}{task}_epochs{task_stage}{block_name}_ALL-epo.fif"), preload=True)
pac_save_path = os.path.join(group_save_path, 'pac_results')
check_paths(pac_save_path)

# Names of the EEG channels, read from the channel types (no copy of the
# group-level data is needed just to list them).
eeg_channel_names = [
    name
    for name, kind in zip(epochs_all_subs.ch_names, epochs_all_subs.get_channel_types())
    if kind == 'eeg'
]
# print(eeg_channel_names)

# Select sensorimotor channels
choi = ['Fp1', 'Fp2',
        'F1', 'F2', 'F3', 'F4', 'F5', 'F6',
        'AF3', 'AF4', 'AF7', 'AF8',
        'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6',
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
        'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6']

epochs_all_subs.pick(choi)

times = epochs_all_subs.times

theta_range = np.linspace(4, 8, 20)  # Phase: 4-8 Hz
gamma_range = np.linspace(30, 80, 20)  # Amplitude: 30-80 Hz
estimator = Comodulogram(
    fs=epochs_all_subs.info['sfreq'],
    low_fq_range=theta_range,  # Phase frequencies (theta)
    high_fq_range=gamma_range,  # Amplitude frequencies (gamma)
    method='tort',
    progress_bar=True,
)

all_channels_data = epochs_all_subs.get_data()
# Layout: (n_phase_freqs, n_channels, n_amplitude_freqs).
pac_results = np.empty(
    (len(theta_range), all_channels_data.shape[1], len(gamma_range))
)

# Take labels from the picked object, not from `choi`, so each name is
# guaranteed to match the data column it is paired with.
for i, chan in enumerate(epochs_all_subs.ch_names):
    # All epochs of this channel are concatenated into one continuous signal.
    data_flat = all_channels_data[:, i].reshape(1, -1)
    pac = estimator.fit(data_flat, data_flat)
    pac_results[:, i] = pac.comod_

    pac.plot(tight_layout=False, cmap='magma')
    plt.title(f"PAC MI {chan}: {condition}")
    plt.savefig(os.path.join(pac_save_path, f"pac_mi_group{condition}_{chan}.png"), dpi=300, bbox_inches="tight")
    plt.show()
    # plt.close()

np.save(os.path.join(pac_save_path, f"pac_mi_group_TOPO{condition}.npy"), pac_results)
# sanity check: the saved file round-trips exactly
a = np.load(os.path.join(pac_save_path, f"pac_mi_group_TOPO{condition}.npy"))
assert np.array_equal(pac_results, a)

# Channel-by-channel similarity of the comodulograms.
n_channels = pac_results.shape[1]
sim = np.empty((n_channels, n_channels))
for i in range(n_channels):
    for j in range(n_channels):
        sim[i, j] = cosine_similarity_matrix(pac_results[:, i], pac_results[:, j])

fig = plt.figure(figsize=(20, 15))
gs = gridspec.GridSpec(8, 8, figure=fig)
# Similarity matrix in the two left-most columns, sensor topography in the
# remaining six (non-overlapping); the top two rows stay empty.
ax1 = fig.add_subplot(gs[2:, :2])
ax2 = fig.add_subplot(gs[2:, 2:])

ax1.imshow(sim, cmap='magma')
ax1.set_xticks(np.arange(len(sim)))
ax1.set_yticks(np.arange(len(sim)))
ax1.set_xticklabels(
    epochs_all_subs.ch_names,
    rotation=45, ha="right",  # rotate for readability
    fontdict={'fontsize': 6}
)
ax1.set_yticklabels(epochs_all_subs.ch_names, fontdict={'fontsize': 8})

fig = plot_comodulogram(
    pac_results,
    theta_range,
    gamma_range,
    epochs_all_subs.info,
    layout_scale=1,
    ax=ax2,
    cmap='magma'
)

plt.savefig(os.path.join(pac_save_path, f"pac_mi_group_TOPO{condition}.png"), dpi=300, bbox_inches="tight")
plt.show()
