In [8]:
import numpy as np
import matplotlib.pyplot as plt
import mne
import os
from utils import check_paths
import pandas as pd
from scipy.io import loadmat
import joblib
import matplotlib.gridspec as gridspec

from pactools import Comodulogram, REFERENCES, raw_to_mask

from mne.channels.layout import find_layout
from functools import partial
from mne.defaults import _handle_default

from mne.viz.topo import _erfimage_imshow_unified, _plot_topo

from mne.viz.utils import (
    _setup_vmin_vmax,
    add_background_image
)
from collections import namedtuple


In [9]:
def _erfimage_imshow(
    ax,
    ch_idx,
    tmin,
    tmax,
    vmin,
    vmax,
    low_fq_range,
    high_fq_range,
    ylim=None,
    data=None,
    epochs=None,
    sigma=None,
    order=None,
    scalings=None,
    vline=None,
    x_label=None,
    y_label=None,
    colorbar=False,
    cmap="RdBu_r",
    vlim_array=None,
):
    """Draw a single channel's comodulogram on ``ax`` with frequency ticks.

    This function is passed as the ``click_func`` of ``_plot_topo`` in
    ``plot_comodulogram``; several parameters exist only so that the
    signature matches what that caller supplies.

    Parameters
    ----------
    ax : axes-like
        Target axes. Must provide ``imshow``, ``set_xticks``,
        ``set_xticklabels``, ``set_yticks``, ``set_yticklabels``,
        ``set_xlabel``, ``set_ylabel`` and (when ``colorbar`` is true) a
        ``figure`` attribute with a ``colorbar`` method.
    ch_idx : int
        Index along axis 1 of ``data`` selecting the channel to draw.
    tmin, tmax, ylim, sigma, scalings, vline
        Accepted for signature compatibility; not used.
    vmin, vmax, vlim_array
        Accepted for signature compatibility; not used. The colour scale
        always runs from 0 to 95 % of the displayed slice's maximum.
    low_fq_range, high_fq_range : sequence of float
        Their first and last entries label the two ends of the x axis
        (``low_fq_range``) and y axis (``high_fq_range``).
    data : ndarray
        3-D array indexed as ``data[:, ch_idx, :]``. Rows of the resulting
        2-D slice run along the y axis, columns along the x axis.
    epochs : object, optional
        Only read (``epochs.times``) when ``order`` is callable.
    order : None | index array | callable
        Optional row re-ordering. A callable is invoked as
        ``order(epochs.times, slice)`` and must return an index.
    x_label, y_label : str, optional
        Axis labels applied to ``ax``.
    colorbar : bool
        If true, attach a colorbar for the image to ``ax``.
    cmap : str
        Colormap forwarded to ``imshow``.
    """
    this_data = data[:, ch_idx, :]

    if callable(order):
        order = order(epochs.times, this_data)
    if order is not None:
        this_data = this_data[order]

    img = ax.imshow(
        this_data,
        aspect="auto",
        origin="lower",
        vmin=0,
        vmax=0.95 * this_data.max(),
        picker=True,
        cmap=cmap,
        interpolation="nearest",
    )

    n_rows, n_cols = this_data.shape
    n_intervals = 5  # each axis is divided into 5 intervals -> 6 tick marks

    # imshow places pixel k at [k - 0.5, k + 0.5], so the image spans
    # [-0.5, n_cols - 0.5] horizontally and [-0.5, n_rows - 0.5] vertically.
    # The x axis must therefore be sized by the number of columns and the
    # y axis by the number of rows (the two were previously mixed up, which
    # only happened to work for square comodulograms).
    xtick_positions = np.linspace(-0.5, n_cols - 0.5, n_intervals + 1)
    xtick_values = np.linspace(
        low_fq_range[0], low_fq_range[-1], len(xtick_positions)
    )
    ax.set_xticks(xtick_positions)
    ax.set_xticklabels([f"{value:.1f}" for value in xtick_values])

    ytick_positions = np.linspace(-0.5, n_rows - 0.5, n_intervals + 1)
    ytick_values = np.linspace(
        high_fq_range[0], high_fq_range[-1], len(ytick_positions)
    )
    ax.set_yticks(ytick_positions)
    ax.set_yticklabels([f"{value:.1f}" for value in ytick_values])

    # Decorate the axes we were given rather than whatever pyplot currently
    # considers "current".
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if colorbar:
        ax.figure.colorbar(img, ax=ax)


def plot_comodulogram(
        data: np.ndarray,
        low_fq_range: np.ndarray,
        high_fq_range: np.ndarray,
        info: mne.Info,
        ch_type: str = 'eeg',
        scalings: dict[str, float] = None,
        title: str = None,
        colorbar: bool = True,
        vmin: float = None,
        vmax: float = None,
        cmap: str = 'viridis',
        ax: plt.Axes = None,
        layout_scale: float = 1,
        fig_facecolor: str = 'w',
        font_color: str = 'k',
        fig_background: plt.Figure = None
) -> plt.Figure:
    """Plot one comodulogram per channel, arranged on the sensor layout.

    Each channel's comodulogram is drawn as a small image at the channel's
    position in the layout found for ``info``. Clicking a channel opens an
    enlarged view drawn by ``_erfimage_imshow`` with frequency-labelled axes
    and a colorbar.

    Parameters
    ----------
    data : ndarray
        3-D array whose axis 1 indexes channels. It is transposed with
        ``(2, 1, 0)`` before plotting, so an input laid out as
        ``(n_low_fq, n_channels, n_high_fq)`` is displayed with high
        frequencies along image rows and low frequencies along columns.
    low_fq_range, high_fq_range : ndarray
        Frequency grids; forwarded to the click view for tick labelling.
    info : mne.Info
        Measurement info used to look up the sensor layout.
    ch_type : str
        Key used to select the scaling coefficient from ``scalings``.
    scalings : dict, optional
        Per-channel-type scalings; ``None`` selects MNE's defaults.
    title : str, optional
        Figure title.
    colorbar : bool
        Whether the topographic figure gets a colorbar.
    vmin, vmax : float, optional
        Colour limits handed to ``_plot_topo``. ``vmin`` defaults to 0 when
        not given; ``vmax`` is derived from the data when not given.
    cmap : str
        Colormap name.
    ax : matplotlib Axes, optional
        Axes to draw into; forwarded as ``axes`` to ``_plot_topo``.
    layout_scale : float
        Scaling factor forwarded to ``_plot_topo``.
    fig_facecolor, font_color : str
        Figure face colour and font colour.
    fig_background : optional
        Forwarded unchanged to ``add_background_image``.
        NOTE(review): annotated as a Figure, but that helper appears to
        expect image data - confirm against the MNE documentation.

    Returns
    -------
    matplotlib Figure
        The figure returned by ``_plot_topo``.
    """
    data = np.transpose(data, (2, 1, 0))
    layout = find_layout(info)
    scalings = _handle_default("scalings", scalings)
    scale_coeffs = [scalings.get(ch_type, 1)]

    # Remember whether the caller asked for a lower limit: previously any
    # explicit ``vmin`` was silently replaced by 0 in the ``_plot_topo`` call.
    requested_vmin = vmin
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
    if requested_vmin is None:
        # Keep the historical default of anchoring the colour scale at 0.
        vmin = 0

    # ``_erfimage_imshow_unified`` only needs an object exposing ``events``;
    # a one-element stand-in avoids requiring a real Epochs instance.
    epochs_plug = namedtuple('EpochsPlug', ['events'])([None])

    shared_kwargs = dict(
        scalings=scale_coeffs,
        data=data,
        epochs=epochs_plug,
        sigma=0,
        cmap=cmap,
    )
    show_func = partial(_erfimage_imshow_unified, **shared_kwargs)
    click_func = partial(
        _erfimage_imshow,
        low_fq_range=low_fq_range,
        high_fq_range=high_fq_range,
        vmin=None, vmax=None,
        vlim_array=None,
        colorbar=True,
        **shared_kwargs,
    )

    fig = _plot_topo(
        info=info,
        times=[0, 1],
        show_func=show_func,
        click_func=click_func,
        layout=layout,
        colorbar=colorbar,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        layout_scale=layout_scale,
        title=title,
        fig_facecolor=fig_facecolor,
        font_color=font_color,
        unified=True,
        img=True,
        axes=ax
    )
    add_background_image(fig, fig_background)

    return fig


def cosine_similarity_matrix(A: np.ndarray, B: np.ndarray) -> float:
    """Return the cosine similarity of two arrays treated as flat vectors.

    Both inputs are flattened, so they must contain the same number of
    elements (``np.dot`` raises ``ValueError`` otherwise). Any array-like
    (e.g. nested lists) is accepted.

    Parameters
    ----------
    A, B : array-like
        Arrays to compare.

    Returns
    -------
    float
        ``dot(a, b) / (||a|| * ||b||)``, or ``0.0`` when either input has
        zero norm (including empty inputs).
    """
    a_flat = np.asarray(A).ravel()
    b_flat = np.asarray(B).ravel()

    denominator = np.linalg.norm(a_flat) * np.linalg.norm(b_flat)
    if denominator == 0:
        # Similarity is undefined for a zero vector; report 0.0 (a float, to
        # match the annotated return type) instead of dividing by zero.
        return 0.0

    return np.dot(a_flat, b_flat) / denominator

In [45]:
# --- Locations of the input data -------------------------------------------
eeg_data_dir = 'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set'
group_save_path = 'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set\\Y group'

# --- Which recording to analyse --------------------------------------------
sub_name = 'ALL_subs'
task = '_MAIN'               # one of: '_BL', '_MAIN'
block_name = '_adaptation'   # one of: '_baseline', '_adaptation'
task_stage = '_go'           # one of: '_plan', '_go'

# Read the pooled epochs file selected by the settings above, loading the
# data into memory straight away.
epochs_all_subs = mne.read_epochs(
    os.path.join(
        group_save_path,
        f"{sub_name}{task}_epochs{task_stage}{block_name}_ALL-epo.fif",
    ),
    preload=True,
)
epochs_all_subs

Reading D:\BonoKat\research project\# study 1\eeg_data\set\Y group\ALL_subs_MAIN_epochs_go_adaptation_ALL-epo.fif ...
    Read a total of 1 projection items:
        Average EEG reference (1 x 60) active
    Found the data of interest:
        t =    -500.00 ...     700.00 ms
        0 CTF compensation matrices available
Adding metadata with 16 columns
2835 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 1)
1 projection items activated


Unnamed: 0,General,General.1
,Filename(s),ALL_subs_MAIN_epochs_go_adaptation_ALL-epo.fif
,MNE object type,EpochsFIF
,Measurement date,Unknown
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,2835
,Events counts,go_on: 2835
,Time range,-0.500 – 0.700 s
,Baseline,-0.500 – 0.700 s


In [46]:
# Directory that receives the PAC figures and arrays.
pac_save_path = 'D:\\BonoKat\\research project\\motor_pac\\pac_results'

# Names of the EEG channels, taken from a copy so the loaded epochs are
# left untouched by this lookup.
eeg_channel_names = epochs_all_subs.copy().pick("eeg").ch_names

# Channels of interest (Fp / AF / F / FC / C / CP sites).
choi = [
    'Fp1', 'Fp2',
    'F1', 'F2', 'F3', 'F4', 'F5', 'F6',
    'AF3', 'AF4', 'AF7', 'AF8',
    'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6',
    'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
    'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6',
]

# Restrict the epochs (in place) to those channels; later cells rely on
# `epochs_all_subs` containing only this subset.
epochs_all_subs.pick(choi)

times = epochs_all_subs.times

# Frequency grids for the comodulogram.
theta_range = np.linspace(4, 8, num=20)     # phase: 4-8 Hz
gamma_range = np.linspace(30, 80, num=20)   # amplitude: 30-80 Hz

In [47]:
# Comodulogram estimator shared by all channels.
estimator = Comodulogram(
    fs=epochs_all_subs.info['sfreq'],
    low_fq_range=theta_range,  # Phase frequencies (theta)
    high_fq_range=gamma_range, # Amplitude frequencies (gamma)
    method='tort',
    progress_bar=True
    )

# get_data() is indexed here as (epochs, channels, samples).
all_channels_data = epochs_all_subs.get_data()

# One comodulogram per channel, stored as (low_fq, channel, high_fq).
pac_results = np.empty(
    (len(theta_range), all_channels_data.shape[1], len(gamma_range))
)

# Iterate over the channel names as they are actually ordered in the data,
# not over the requested list: this keeps every title and file name attached
# to the right data column even if pick() did not adopt the requested order.
for i, chan in enumerate(epochs_all_subs.ch_names):
    # All epochs of this channel are concatenated into a single 1 x N signal.
    # NOTE(review): filtering then runs across epoch boundaries; confirm this
    # is intended (Comodulogram.fit may also accept a 2-D
    # (n_epochs, n_points) array directly).
    data_flat = np.reshape(all_channels_data[:, i], -1)[None, :]
    pac = estimator.fit(
            data_flat,
            data_flat,
        )
    pac_results[:, i] = pac.comod_

    # Draw this channel's comodulogram with the estimator's own plot method
    fig = pac.plot(tight_layout=False, cmap='magma')
    # Add a title
    plt.title(f"PAC MI {chan}: {task}{task_stage}{block_name}")

    # Save the plot
    plt.savefig(os.path.join(pac_save_path, f"pac_mi_group{task}{task_stage}{block_name}_{chan}.png"), dpi=300, bbox_inches="tight")
    plt.show()
    # Release the figure once it is saved and shown; otherwise one open
    # figure per channel accumulates over the loop.
    plt.close(fig)

np.save(os.path.join(pac_save_path, f"pac_mi_group_TOPO{task}{task_stage}{block_name}.npy"), pac_results)

[........................................] 100% | 25.26 sec | comodulogram: tort 
[........................................] 100% | 24.72 sec | comodulogram: tort 
[........................................] 100% | 24.79 sec | comodulogram: tort 
[........................................] 100% | 25.11 sec | comodulogram: tort 
[........................................] 100% | 24.77 sec | comodulogram: tort 
[........................................] 100% | 25.18 sec | comodulogram: tort 
[........................................] 100% | 24.77 sec | comodulogram: tort 
[........................................] 100% | 24.83 sec | comodulogram: tort 
[........................................] 100% | 25.31 sec | comodulogram: tort 
[........................................] 100% | 24.65 sec | comodulogram: tort 
[........................................] 100% | 24.80 sec | comodulogram: tort 
[........................................] 100% | 24.84 sec | comodulogram: tort 
[...............

  fig, axs = plt.subplots(n_lines, n_columns, figsize=figsize)


[........................................] 100% | 26.59 sec | comodulogram: tort 
[........................................] 100% | 25.21 sec | comodulogram: tort 
[........................................] 100% | 25.13 sec | comodulogram: tort 
[........................................] 100% | 24.71 sec | comodulogram: tort 
[........................................] 100% | 25.86 sec | comodulogram: tort 
[........................................] 100% | 24.61 sec | comodulogram: tort 
[........................................] 100% | 24.61 sec | comodulogram: tort 
[........................................] 100% | 24.64 sec | comodulogram: tort 
[........................................] 100% | 24.59 sec | comodulogram: tort 


In [48]:
# Sanity check: the array written to disk must round-trip unchanged.
a = np.load(os.path.join(pac_save_path, f"pac_mi_group_TOPO{task}{task_stage}{block_name}.npy"))

# assert_array_equal also verifies that the shapes match and treats NaNs in
# the same positions as equal, which a plain `np.all(x == y)` does not; it
# raises AssertionError with a readable diff on mismatch.
np.testing.assert_array_equal(a, pac_results)

In [49]:
sim = np.empty((pac_results.shape[1], pac_results.shape[1]))
for i in range(pac_results.shape[1]):
    for j in range(pac_results.shape[1]):
        sim[i, j] = cosine_similarity_matrix(pac_results[:, i], pac_results[:, j])

%matplotlib qt
fig = plt.figure(figsize=(20, 15))
gs = gridspec.GridSpec(8, 8, figure=fig)

# Create axes with custom positions
ax1 = fig.add_subplot(gs[2:, :2])  # ax1 spans first 2 columns
ax2 = fig.add_subplot(gs[2:, 1:])  # ax2 spans last 2 columns


ax1.imshow(sim, cmap='magma')
ax1.set_xticks(np.arange(len(sim)))
ax1.set_yticks(np.arange(len(sim)))
ax1.set_xticklabels(
    epochs_all_subs.ch_names,
      rotation=45, ha="right",
      fontdict={'fontsize': 6}
    )  # Rotate for readability
ax1.set_yticklabels(epochs_all_subs.ch_names, fontdict={'fontsize': 8})

fig = plot_comodulogram(
    pac_results,
    theta_range,
    gamma_range,
    epochs_all_subs.info,
    layout_scale=1,
    ax=ax2,
    cmap='magma'
)

plt.savefig(os.path.join(pac_save_path, f"pac_mi_group_TOPO{task}{task_stage}{block_name}.png"), dpi=300, bbox_inches="tight")
plt.show()
