## STATS: Compare conditions at the sensor and source levels

In [2]:
import numpy as np
import os
from utils import check_paths

import mne
from mne.stats import permutation_cluster_1samp_test

import scipy
from scipy.stats import zscore
from scipy.sparse import coo_matrix, save_npz

import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib qt

**PLOTTING FUNCS**

In [2]:
def plot_rect_topo_from_epochs(data, epochs_info, cmap='YlGn', vmin=None, vmax=None, title='',
                               grid_size=14):
    """
    Plot per-channel scalar values as coloured squares on a rectangular grid.

    Channel positions come from ``mne.channels.make_eeg_layout(epochs_info)``.
    Each channel's 2-D layout position is rescaled to integer grid indices in
    ``[0, grid_size]`` and drawn as a labelled unit square.

    Parameters
    ----------
    data : array, shape (n_channels,)
        Scalar values per channel (e.g., PAC strength). ``data[i]`` is drawn
        for the i-th channel of the layout built from ``epochs_info``.
        NOTE(review): this assumes ``data`` follows the same channel order as
        ``epochs_info`` (the layout is built from it); confirm for new inputs.
    epochs_info : instance of mne.Info
        Info object from Epochs to extract channel positions.
    cmap : str or Colormap
        Colormap to use for values.
    vmin, vmax : float | None
        Limits for color scaling. ``None`` uses the NaN-ignoring min/max of ``data``.
    title : str
        Title for the plot.
    grid_size : int
        Largest grid index along each axis, so the grid has at most
        ``grid_size + 1`` rows/columns. The default of 14 reproduces the
        original fixed layout.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If ``data`` does not hold exactly one value per layout channel.
    """
    layout = mne.channels.make_eeg_layout(epochs_info)
    ch_names = layout.names
    pos_2d = layout.pos[:, :2]

    data = np.asarray(data)
    # A length mismatch (e.g. bad channels excluded from the layout) would
    # otherwise silently shift values onto the wrong electrodes.
    if len(data) != len(ch_names):
        raise ValueError(
            f"data has {len(data)} values but the layout has {len(ch_names)} channels"
        )

    def _to_grid(coord):
        # Rescale one coordinate axis to integer indices 0..grid_size;
        # a degenerate axis (all positions equal) maps to index 0.
        span = np.ptp(coord)
        if span == 0:
            return np.zeros(coord.shape, dtype=int)
        return np.round((coord - np.min(coord)) / span * grid_size).astype(int)

    x_idx = _to_grid(pos_2d[:, 0])
    y_idx = _to_grid(pos_2d[:, 1])

    nrows = y_idx.max() + 1
    ncols = x_idx.max() + 1

    fig, ax = plt.subplots(figsize=(ncols, nrows))
    ax.set_xlim(0, ncols)
    # y grows upward, so channels with larger layout y are drawn higher.
    ax.set_ylim(0, nrows)
    ax.set_xticks([])
    ax.set_yticks([])

    norm = plt.Normalize(vmin if vmin is not None else np.nanmin(data),
                         vmax if vmax is not None else np.nanmax(data))
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    for value, ch, row, col in zip(data, ch_names, y_idx, x_idx):
        rect = patches.Rectangle((col, row), 1, 1, facecolor=sm.to_rgba(value), edgecolor='black')
        ax.add_patch(rect)
        ax.text(col + 0.5, row + 0.5, ch, ha='center', va='center', fontsize=7)

    plt.colorbar(sm, ax=ax, shrink=0.8, label='Value')
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

    return fig, ax


In [3]:
def plot_matrix_topo_from_epochs(data, epochs_info, cmap='YlGn', vmin=None, vmax=None, title='',
                                 grid_size=14):
    """
    Draw one small 2-D matrix (e.g. a PAC comodulogram) per channel at its layout position.

    Channel positions come from ``mne.channels.make_eeg_layout(epochs_info)``.
    They are rescaled to integer grid indices in ``[0, grid_size]``, and each
    channel's matrix is drawn with ``imshow`` into its unit cell.

    Parameters
    ----------
    data : ndarray, shape (n_channels, height, width)
        A 2-D matrix per channel. Each matrix is transposed before drawing,
        so its first axis runs along x and its second along y (for PAC arrays
        in this notebook, presumably phase on x and amplitude on y).
        NOTE(review): assumes the channel order matches ``epochs_info``.
    epochs_info : instance of mne.Info
        Info object to extract channel layout.
    cmap : str or Colormap
        Colormap to use.
    vmin, vmax : float or None
        Color scale limits. ``None`` uses the NaN-ignoring min/max of ``data``.
    title : str
        Title for the entire plot.
    grid_size : int
        Largest grid index along each axis. The default of 14 reproduces the
        original fixed layout.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If ``data`` is not 3-D or does not have one matrix per layout channel.
    """
    data = np.asarray(data)

    layout = mne.channels.make_eeg_layout(epochs_info)
    ch_names = layout.names
    pos_2d = layout.pos[:, :2]

    # A mismatch would silently draw matrices on the wrong electrodes.
    if data.ndim != 3 or data.shape[0] != len(ch_names):
        raise ValueError(
            f"data must have shape ({len(ch_names)}, height, width); got {data.shape}"
        )

    def _to_grid(coord):
        # Rescale one coordinate axis to integer indices 0..grid_size;
        # a degenerate axis (all positions equal) maps to index 0.
        span = np.ptp(coord)
        if span == 0:
            return np.zeros(coord.shape, dtype=int)
        return np.round((coord - np.min(coord)) / span * grid_size).astype(int)

    x_idx = _to_grid(pos_2d[:, 0])
    y_idx = _to_grid(pos_2d[:, 1])

    nrows = y_idx.max() + 1
    ncols = x_idx.max() + 1

    fig, ax = plt.subplots(figsize=(ncols, nrows))
    ax.set_xlim(0, ncols)
    # y grows upward, so channels with larger layout y are drawn higher.
    ax.set_ylim(0, nrows)
    ax.set_xticks([])
    ax.set_yticks([])

    # One shared normalization so every small matrix uses the same colour scale.
    norm = plt.Normalize(vmin if vmin is not None else np.nanmin(data),
                         vmax if vmax is not None else np.nanmax(data))
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    for matrix, ch, row, col in zip(data, ch_names, y_idx, x_idx):
        extent = (col, col + 1, row, row + 1)
        ax.imshow(matrix.T, cmap=cmap, norm=norm, extent=extent, origin='lower', aspect='auto')

        # Draw frame and label on top of the matrix.
        rect = patches.Rectangle((col, row), 1, 1, fill=False, edgecolor='black', linewidth=0.5)
        ax.add_patch(rect)
        ax.text(col + 0.5, row + 0.5, ch, ha='center', va='center', fontsize=6, color='black')

    plt.colorbar(sm, ax=ax, shrink=0.7, label='Matrix Value')
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

    return fig, ax


In [4]:
def plot_significant_topomap(T_obs, clusters, cluster_p_values, info, p_thresh=0.05, vlim=(-4, 4),
                             group=None, task=None, task_stage=None, block_name=None):
    """
    Plot a T-value topomap and mark electrodes belonging to significant clusters.

    Parameters
    ----------
    T_obs : array, shape (n_channels,)
        Observed T-values.
    clusters : list of boolean arrays, each shape (n_channels,)
        Cluster masks, as returned with ``out_type="mask"``.
    cluster_p_values : array
        P-value of each cluster, aligned with ``clusters``.
    info : instance of mne.Info
        EEG info with channel locations.
    p_thresh : float
        Clusters with ``p <= p_thresh`` are marked as significant.
    vlim : tuple of float | None
        ``(vmin, vmax)`` colour limits. ``None`` uses a symmetric range
        ``(-max|T|, max|T|)``.
    group, task, task_stage, block_name : str | None
        Pieces concatenated into the title. ``None`` for the last three is
        treated as an empty string.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    """
    task = '' if task is None else task
    task_stage = '' if task_stage is None else task_stage
    block_name = '' if block_name is None else block_name

    # Union of all significant cluster masks
    sig_mask = np.zeros_like(T_obs, dtype=bool)
    for cluster, p_val in zip(clusters, cluster_p_values):
        if p_val <= p_thresh:
            sig_mask |= cluster

    # plot_topomap expects a (vmin, vmax) pair, not a single number.
    if vlim is None:
        max_abs = np.nanmax(np.abs(T_obs))
        vlim = (-max_abs, max_abs)

    title = f'{group}_{task}{task_stage}{block_name}: Significant Electrodes'
    fig, ax = plt.subplots()

    # Plot topomap with significant electrodes highlighted
    im, _ = mne.viz.plot_topomap(
        T_obs,
        info,
        cmap='PiYG',
        vlim=vlim,
        show=False,
        mask=sig_mask,
        mask_params=dict(marker='h', markersize=15,
                         markerfacecolor='y',
                         markeredgecolor='k'),
        contours=0,
        axes=ax,
        sensors=False  # hide non-significant sensors; only masked ones are drawn
    )

    plt.colorbar(im, ax=ax, shrink=0.6, label='T-value')
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

    return fig, ax

**TWO CONDITIONS**

In [44]:
# Condition registry: condition name -> (task suffix, stage suffix, block suffix or None).
# The suffixes are concatenated into PAC file names by load_condition_data.
conditions = dict(
    # Baseline task (no block split)
    BL_go=("_BL", "_go", None),
    BL_plan=("_BL", "_plan", None),
    # MAIN task, baseline block only
    MAIN_go_baseline=("_MAIN", "_go", "_baseline"),
    MAIN_plan_baseline=("_MAIN", "_plan", "_baseline"),
    # MAIN task, adaptation block only
    MAIN_go_adaptation=("_MAIN", "_go", "_adaptation"),
    MAIN_plan_adaptation=("_MAIN", "_plan", "_adaptation"),
    # MAIN task with block=None: the loader averages the baseline and adaptation files
    MAIN_go_combined=("_MAIN", "_go", None),
    MAIN_plan_combined=("_MAIN", "_plan", None),
)

# Paired contrasts tested as (cond1, cond2); the statistic is cond1 - cond2.
comparisons = [
    # go vs plan within each task / block
    ("BL_go", "BL_plan"),
    ("MAIN_go_baseline", "MAIN_plan_baseline"),
    ("MAIN_go_adaptation", "MAIN_plan_adaptation"),
    ("MAIN_go_combined", "MAIN_plan_combined"),
    # MAIN vs BL within each stage
    ("MAIN_go_combined", "BL_go"),
    ("MAIN_plan_combined", "BL_plan"),
]


In [6]:
def impute_sources(sub_name, pac_t, adjacency=None):
    """
    Replace NaN-containing vertices with the mean PAC matrix of their valid neighbours.

    Parameters
    ----------
    sub_name : str
        Subject name, used only in progress messages.
    pac_t : ndarray, shape (n_vertices, n_phase, n_amp)
        PAC matrices with the vertex axis first. Not modified.
    adjacency : scipy sparse matrix, shape (n_vertices, n_vertices) | None
        Vertex adjacency; the stored entries of row ``v`` are the neighbours
        of ``v``. Any sparse format is accepted. ``None`` falls back to the
        module-level ``adj`` created in the source-level cell.

    Returns
    -------
    ndarray
        A copy of ``pac_t``. Every vertex with at least one NaN is replaced
        by the average of those neighbours that contain no NaN. A vertex
        without any such neighbour is left as is (still NaN) and reported.
    """
    if adjacency is None:
        adjacency = adj  # global built in the SOURCES cell
    # CSR exposes each row's neighbour list directly through indptr/indices.
    adjacency = adjacency.tocsr()

    pac_imputed = pac_t.copy()

    # Vertices with any NaN in their PAC matrix
    nan_mask = np.isnan(pac_t).any(axis=(1, 2))
    if not nan_mask.any():
        return pac_imputed

    print(f"Found {nan_mask.sum()} vertices with NaN values.")

    unfilled = 0
    for vtx in np.flatnonzero(nan_mask):
        neighbors = adjacency.indices[adjacency.indptr[vtx]:adjacency.indptr[vtx + 1]]
        valid_neighbors = neighbors[~nan_mask[neighbors]]
        if valid_neighbors.size == 0:
            # Nothing to average from; leave NaN rather than warn on an empty mean.
            unfilled += 1
            continue
        pac_imputed[vtx] = np.nanmean(pac_t[valid_neighbors], axis=0)

    if unfilled:
        print(f"WARNING: {unfilled} NaN vertices of {sub_name} have no valid neighbors and remain NaN.")
    print(f"NaN imputation for {sub_name} complete.")
    return pac_imputed

In [42]:
def _pac_dir(eeg_data_dir, sub_name, level):
    """Return the folder holding one subject's PAC files for the given level."""
    sub_dir = os.path.join(eeg_data_dir, sub_name)
    if level == 'sensors':
        return os.path.join(sub_dir, "pac_results")
    return os.path.join(sub_dir, "preproc", "analysis", 'source', 'PAC')


def _load_pac_file(pac_dir, sub_name, task, stage, block_suffix, level):
    """Load one PAC .npy file, move the channel/vertex axis first, impute NaNs at source level."""
    prefix = "pac_mi_TOPO" if level == 'sensors' else "PAC_MI_SOURCE"
    pac = np.load(os.path.join(pac_dir, f"{prefix}_{sub_name[-5:]}{task}{stage}{block_suffix}.npy"))
    # Files store channels/vertices on axis 1; swap it to the front.
    pac_t = np.transpose(pac, (1, 0, 2))
    if level == 'sources':
        pac_t = impute_sources(sub_name, pac_t)
    return pac_t


def _load_subject_condition(eeg_data_dir, sub_name, task, stage, block, level):
    """Return one subject's PAC array for a condition, averaging both MAIN blocks when block is None."""
    pac_dir = _pac_dir(eeg_data_dir, sub_name, level)
    if block is None and task == "_MAIN":
        block_data = [_load_pac_file(pac_dir, sub_name, task, stage, b, level)
                      for b in ("_baseline", "_adaptation")]
        blocks_arr = np.stack(block_data, axis=0)
        print(f'Blocks stacked shape: {blocks_arr.shape}')
        # Average across blocks so the per-subject shape matches single-block conditions.
        return np.mean(blocks_arr, axis=0)
    # Single block, or the BL task (which has no block suffix)
    bname = "" if block is None else block
    return _load_pac_file(pac_dir, sub_name, task, stage, bname, level)


def load_condition_data(eeg_data_dir, subjects, condition_key, conditions, level):
    """
    Load every subject's PAC data for one condition and normalize it.

    Parameters
    ----------
    eeg_data_dir : str
        Folder containing one sub-folder per subject.
    subjects : sequence of str
        Subject folder names. ``sub_name[-5:]`` is used as the ID inside
        file names (e.g. ``'sub01'`` for ``'s1_pac_sub01'``).
    condition_key : str
        Key into ``conditions``.
    conditions : dict
        Maps a condition name to ``(task, stage, block)`` file-name suffixes.
        For ``task == '_MAIN'`` with ``block is None`` the ``'_baseline'``
        and ``'_adaptation'`` files are averaged.
    level : {'sensors', 'sources'}
        Which PAC files to read.

    Returns
    -------
    ndarray
        For ``'sensors'``: shape (n_subjects, n_channels, n_freqs_a, n_freqs_b)
        (presumably phase x amplitude frequencies). Each subject's PAC is
        z-scored across channels, separately for every frequency pair.
        For ``'sources'``: shape (n_subjects, n_vertices). PAC is averaged
        over both frequency axes, then normalized with one mean/std pooled
        over all subjects and vertices.

    Raises
    ------
    ValueError
        If ``level`` is not ``'sensors'`` or ``'sources'``, or ``subjects`` is empty.
    """
    # An unknown level would otherwise read source paths without NaN imputation.
    if level not in ('sensors', 'sources'):
        raise ValueError(f"level must be 'sensors' or 'sources', got {level!r}")
    if len(subjects) == 0:
        raise ValueError("subjects is empty; nothing to load")

    task, stage, block = conditions[condition_key]
    pac_list = [_load_subject_condition(eeg_data_dir, sub_name, task, stage, block, level)
                for sub_name in subjects]

    if level == 'sensors':
        # Z-score within each subject across channels (axis 0 of the per-subject array).
        pac_zscore_all = np.stack([zscore(pac_t, axis=0, nan_policy='omit') for pac_t in pac_list], axis=0)
        print('z-scored PAC array shape:', pac_zscore_all.shape)
        return pac_zscore_all

    pac_all = np.stack(pac_list, axis=0)
    print('PAC array shape:', pac_all.shape)

    # Source level: average over frequencies, then global normalization
    # across all subjects and vertices for this condition.
    pac_all_ave = np.mean(pac_all, axis=(2, 3))
    return (pac_all_ave - np.nanmean(pac_all_ave)) / np.nanstd(pac_all_ave)

def load_two_conditions(eeg_data_dir, subjects, cond1, cond2, conditions, level):
    """
    Load the group arrays of two conditions and return them as ``(data1, data2)``.

    Both conditions are read with the same ``subjects`` list and ``level``,
    loading ``cond1`` first and then ``cond2``.
    """
    return tuple(
        load_condition_data(eeg_data_dir, subjects, key, conditions, level=level)
        for key in (cond1, cond2)
    )

def iterate_comparisons(eeg_data_dir, subjects, conditions, comparisons, level):
    """
    Generate ``(cond1, cond2, data1, data2)`` for every pair in ``comparisons``.

    Data for a pair is loaded only when the generator advances to that pair.
    """
    for pair in comparisons:
        first, second = pair
        first_data, second_data = load_two_conditions(
            eeg_data_dir, subjects, first, second, conditions, level=level
        )
        yield first, second, first_data, second_data


# SENSORS

In [21]:
group = 'Y'
eeg_data_dir = f'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set\\{group}'
# Keep only subject folders (a stray file would otherwise be treated as a subject)
# and sort them so subject order, and therefore saved arrays, is reproducible.
subjects = sorted(
    name for name in os.listdir(eeg_data_dir)
    if os.path.isdir(os.path.join(eeg_data_dir, name))
)

# Create directories for saving stats
group_save_path = f'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set\\{group} group'
pac_stats_save_path = os.path.join(group_save_path, 'pac_stats', 'conditions')
check_paths(pac_stats_save_path)

# Create directories for saving figures
fig_group_path = os.path.join(group_save_path, 'pac_stats', 'figs')
fig_group_save_path = os.path.join(fig_group_path, group)
fig_task_save_path = os.path.join(fig_group_path, 'conditions')
check_paths(fig_task_save_path)

############### CREATE ADJACENCY MATRIX FOR STATISTICAL TEST #############
# find_ch_adjacency first attempts to find an existing "neighbor"
# (adjacency) file for the given sensor layout.
# If such a file doesn't exist, an adjacency matrix is computed on the fly
# using Delaunay triangulation.

# Load one epoch file to get channel info
epochs_path = os.path.join(eeg_data_dir, subjects[0], 'preproc', 'analysis')
epochs = mne.read_epochs(os.path.join(epochs_path, f"{subjects[0]}_BL_epochs_plan-epo.fif"), preload=True)
eeg_channel_names = epochs.copy().pick("eeg").ch_names
epochs.pick(eeg_channel_names)

# Create channel adjacency matrix
sensor_adjacency, ch_names = mne.channels.find_ch_adjacency(epochs.info, "eeg")
print(f'Adjacency matrix shape: {sensor_adjacency.shape}')

# Statistical settings shared by every comparison
tail = 0  # two-tailed test
n_permutations = 10000
alpha = 0.05  # cluster-level significance threshold used for reporting

############### RUN STATISTICAL COMPARISON #############
for cond1, cond2, data1, data2 in iterate_comparisons(eeg_data_dir, subjects, conditions, comparisons, level='sensors'):
    print(f"Running stats for {cond1} vs {cond2}...")

    # Paired difference per subject: (n_subjects, n_channels, ph_freqs, amp_freqs)
    pac_zscore_diff = data1 - data2
    print('z-scored diff PAC array shape:', pac_zscore_diff.shape)

    # Average z-scored PAC over phase and amplitude frequencies -> (n_subjects, n_channels)
    pac_zscore_diff_ave = np.mean(pac_zscore_diff, axis=(2, 3))
    print(pac_zscore_diff_ave.shape)

    # Save the frequency-averaged PAC difference
    np.save(os.path.join(pac_stats_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_ZSCORE_freqs_ave.npy"), pac_zscore_diff_ave)

    ############# PLOT AND SAVE Z-SCORED PAC AVERAGED ACROSS PARTICIPANTS #############
    # Save through the returned Figure: after plt.show() the pyplot "current"
    # figure is only reliably this one when interactive mode is on.
    pac_plot, ax1 = plot_rect_topo_from_epochs(np.mean(pac_zscore_diff_ave, axis=0), epochs.info,
                                               title=f'{group}_{cond1}_vs_{cond2}: Averaged z-scored PAC MI',
                                               cmap='PiYG', vmin=-0.5, vmax=0.5)
    pac_plot.savefig(os.path.join(fig_task_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_PAC_MI_AVE_TOPO.png"), dpi=300)

    pac_ave_plot, ax2 = plot_matrix_topo_from_epochs(np.mean(pac_zscore_diff, axis=0), epochs.info,
                                                     title=f'{group}_{cond1}_vs_{cond2}: z-scored PAC MI',
                                                     cmap='PiYG', vmin=-0.5, vmax=0.5)
    pac_ave_plot.savefig(os.path.join(fig_task_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_PAC_MI_TOPO.png"), dpi=300)

    ############# RUN CLUSTER-BASED PERMUTATION TEST #############
    # Cluster-forming threshold: the t-value corresponding to p=0.01, two-tailed.
    # For a two-tailed test the p-value is halved (both tails are used), the
    # degrees of freedom are the number of subjects minus 1, and 1 - 0.01/2
    # gives the critical t-value on the right tail.
    # Alternative: threshold=dict(start=0, step=0.2) for TFCE (more conservative).
    degrees_of_freedom = pac_zscore_diff.shape[0] - 1
    t_thresh = scipy.stats.t.ppf(1 - 0.01 / 2, df=degrees_of_freedom)

    T_obs, clusters, cluster_p_values, H0 = permutation_cluster_1samp_test(
        pac_zscore_diff_ave,
        n_permutations=n_permutations,
        threshold=t_thresh,
        tail=tail,
        adjacency=sensor_adjacency,
        out_type="mask",
        max_step=1,
        verbose=True,
    )

    # Save the results
    np.save(os.path.join(pac_stats_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_freqs_ave_T_obs.npy"), T_obs)
    np.save(os.path.join(pac_stats_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_freqs_ave_clusters.npy"), np.array(clusters, dtype=object))
    np.save(os.path.join(pac_stats_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_freqs_ave_cluster_p_values.npy"), cluster_p_values)
    np.save(os.path.join(pac_stats_save_path, f"pac_mi_{group}_{cond1}_vs_{cond2}_freqs_ave_H0.npy"), H0)

    # SANITY CHECKS
    print(f't_thresh = {t_thresh}')
    print(f'T_obs_mean = {T_obs.mean()}')
    print(f'cluster_p_values = {cluster_p_values}')

    significant_clusters = [i for i, p in enumerate(cluster_p_values) if p < alpha]
    print(f"{cond1} vs {cond2} found {len(significant_clusters)} significant clusters")

    ####### PLOT THE RESULTS #######
    plot_significant_topomap(T_obs, clusters, cluster_p_values, epochs.info, group=group, task=f'{cond1}_vs_{cond2}')
    plt.savefig(os.path.join(fig_task_save_path, f"pac_cluster_stats_{group}_{cond1}_vs_{cond2}_freq_ave_TOPO.png"), dpi=300)

Reading D:\BonoKat\research project\# study 1\eeg_data\set\Y\s1_pac_sub01\preproc\analysis\s1_pac_sub01_BL_epochs_plan-epo.fif ...
    Read a total of 1 projection items:
        Average EEG reference (1 x 60) active
    Found the data of interest:
        t =    -500.00 ...     500.00 ms
        0 CTF compensation matrices available
Adding metadata with 10 columns
99 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Could not find a adjacency matrix for the data. Computing adjacency based on Delaunay triangulations.
-- number of adjacent vertices : 60
Adjacency matrix shape: (60, 60)
z-scored PAC array shape: (23, 60, 20, 20)
z-scored PAC array shape: (23, 60, 20, 20)
Running stats for BL_go vs BL_plan...
z-scored diff PAC array shape: (23, 60, 20, 20)
(23, 60)
stat_fun(H1): min=-1.9718746677943155 max=4.073190748102221
Running initial clustering …
Found 1 cluster


  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.1128458908976368
cluster_p_values = [0.0024]
BL_go vs BL_plan found 1 significant clusters
z-scored PAC array shape: (23, 60, 20, 20)
z-scored PAC array shape: (23, 60, 20, 20)
Running stats for MAIN_go_baseline vs MAIN_plan_baseline...
z-scored diff PAC array shape: (23, 60, 20, 20)
(23, 60)
stat_fun(H1): min=-2.5916333765074397 max=3.383751714351727
Running initial clustering …
Found 2 clusters


  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.04032781932156266
cluster_p_values = [0.1237 0.1414]
MAIN_go_baseline vs MAIN_plan_baseline found 0 significant clusters
z-scored PAC array shape: (23, 60, 20, 20)
z-scored PAC array shape: (23, 60, 20, 20)
Running stats for MAIN_go_adaptation vs MAIN_plan_adaptation...
z-scored diff PAC array shape: (23, 60, 20, 20)
(23, 60)
stat_fun(H1): min=-3.103132891377297 max=2.9219920131463466
Running initial clustering …
Found 2 clusters


  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.16341196456330243
cluster_p_values = [0.0165 0.2063]
MAIN_go_adaptation vs MAIN_plan_adaptation found 1 significant clusters
z-scored PAC array shape: (23, 60, 20, 20)
z-scored PAC array shape: (23, 60, 20, 20)
Running stats for MAIN_go_combined vs MAIN_plan_combined...
z-scored diff PAC array shape: (23, 60, 20, 20)
(23, 60)
stat_fun(H1): min=-2.8333873900868785 max=3.905083696848699
Running initial clustering …
Found 3 clusters


  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.0967945016733594
cluster_p_values = [0.043  0.0489 0.3857]
MAIN_go_combined vs MAIN_plan_combined found 2 significant clusters
z-scored PAC array shape: (23, 60, 20, 20)
z-scored PAC array shape: (23, 60, 20, 20)
Running stats for MAIN_go_combined vs BL_go...
z-scored diff PAC array shape: (23, 60, 20, 20)
(23, 60)
stat_fun(H1): min=-3.985420401633887 max=3.8595566351648096
Running initial clustering …
Found 5 clusters


  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = 0.01755145336977496
cluster_p_values = [0.144  0.2796 0.0041 0.0043 0.234 ]
MAIN_go_combined vs BL_go found 2 significant clusters
z-scored PAC array shape: (23, 60, 20, 20)
z-scored PAC array shape: (23, 60, 20, 20)
Running stats for MAIN_plan_combined vs BL_plan...
z-scored diff PAC array shape: (23, 60, 20, 20)
(23, 60)
stat_fun(H1): min=-4.168021460504441 max=4.221074647721624
Running initial clustering …
Found 4 clusters


  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = 0.06656733633161584
cluster_p_values = [0.2912 0.0039 0.1088 0.0028]
MAIN_plan_combined vs BL_plan found 2 significant clusters


_________________________________

# SOURCES

In [45]:
group = 'Y'
eeg_data_dir = f'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set\\{group}'
# Keep only subject folders and sort them so subject order is reproducible.
subjects = sorted(
    name for name in os.listdir(eeg_data_dir)
    if os.path.isdir(os.path.join(eeg_data_dir, name))
)

# Create directories for saving stats
group_save_path = f'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set\\{group} group'
group_source_save_path = os.path.join(group_save_path, 'sources')
check_paths(group_source_save_path)
pac_stats_path = os.path.join(group_source_save_path, 'pac_stats')
check_paths(pac_stats_path)
pac_stats_save_path = os.path.join(pac_stats_path, 'conditions')
check_paths(pac_stats_save_path)

# # Create directories for saving figures
# fig_group_path = os.path.join(group_source_save_path, 'pac_stats', 'figs', 'sources')
# check_paths(fig_group_path)
# fig_group_save_path = os.path.join(fig_group_path, group)
# fig_task_save_path = os.path.join(fig_group_path, 'conditions')
# check_paths(fig_task_save_path)

############### CREATE ADJACENCY MATRIX FOR STATISTICAL TEST #############
# Vertex adjacency of the fsaverage ico-4 source space, built from its
# triangulation. `adj` (CSR) is also read by impute_sources() for
# neighbour-based NaN filling, so it must stay a module-level name.
src_fname = 'D:\\BonoKat\\research project\\# study 1\\mri_data\\fs_output\\freesurfer\\sub_dir\\Y\\fsaverage_bem\\bem\\fsaverage-ico4-src.fif'
src = mne.read_source_spaces(src_fname)
# src.plot(subjects_dir='D:\\BonoKat\\research project\\# study 1\\mri_data\\fs_output\\freesurfer\\sub_dir\\Y')
source_adjacency = mne.spatial_src_adjacency(src)
adj = source_adjacency.tocsr()  # CSR for fast row (neighbour) access
print('Adjacency shape:', source_adjacency.shape)

# Statistical settings shared by every comparison
tail = 0  # two-tailed test
n_permutations = 10000
alpha = 0.05  # cluster-level significance threshold used for reporting

############### RUN STATISTICAL COMPARISON #############
for cond1, cond2, data1, data2 in iterate_comparisons(eeg_data_dir, subjects, conditions, comparisons, level='sources'):
    print(f"Running stats for {cond1} vs {cond2}...")

    # At source level the loader already averages over phase/amplitude
    # frequencies and normalizes, so the difference is (n_subjects, n_vertices).
    pac_zscore_diff = data1 - data2
    print('z-scored diff PAC array shape:', pac_zscore_diff.shape)

    # Save the per-subject, frequency-averaged source-level difference
    np.save(os.path.join(pac_stats_save_path, f"PAC_MI_SOURCE_{group}_{cond1}_vs_{cond2}_ZSCORE_freqs_ave.npy"), pac_zscore_diff)

    ############# RUN CLUSTER-BASED PERMUTATION TEST #############
    # Cluster-forming threshold: the t-value corresponding to p=0.01, two-tailed.
    # For a two-tailed test the p-value is halved, the degrees of freedom are
    # the number of subjects minus 1, and 1 - 0.01/2 gives the critical t-value
    # on the right tail.
    # Alternative: threshold=dict(start=0, step=0.2) for TFCE (more conservative).
    degrees_of_freedom = pac_zscore_diff.shape[0] - 1
    t_thresh = scipy.stats.t.ppf(1 - 0.01 / 2, df=degrees_of_freedom)

    T_obs, clusters, cluster_p_values, H0 = permutation_cluster_1samp_test(
        pac_zscore_diff,
        n_permutations=n_permutations,
        threshold=t_thresh,
        tail=tail,
        adjacency=adj,
        out_type="mask",
        max_step=1,
        verbose=True,
    )

    # Save the results
    np.save(os.path.join(pac_stats_save_path, f"PAC_MI_SOURCE_{group}_{cond1}_vs_{cond2}_freqs_ave_T_obs.npy"), T_obs)
    np.save(os.path.join(pac_stats_save_path, f"PAC_MI_SOURCE_{group}_{cond1}_vs_{cond2}_freqs_ave_clusters.npy"), np.array(clusters, dtype=object))
    np.save(os.path.join(pac_stats_save_path, f"PAC_MI_SOURCE_{group}_{cond1}_vs_{cond2}_freqs_ave_cluster_p_values.npy"), cluster_p_values)
    np.save(os.path.join(pac_stats_save_path, f"PAC_MI_SOURCE_{group}_{cond1}_vs_{cond2}_freqs_ave_H0.npy"), H0)

    # SANITY CHECKS
    print(f't_thresh = {t_thresh}')
    print(f'T_obs_mean = {T_obs.mean()}')
    print(f'cluster_p_values = {cluster_p_values}')

    significant_clusters = [i for i, p in enumerate(cluster_p_values) if p < alpha]
    print(f"{cond1} vs {cond2} found {len(significant_clusters)} significant clusters")


    Reading a source space...
    Computing patch statistics...
    Patch information added...
    [done]
    Reading a source space...
    Computing patch statistics...
    Patch information added...
    [done]
    2 source spaces read
-- number of adjacent vertices : 5124
Adjacency shape: (5124, 5124)
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complete.
Found 9 vertices with NaN values.
NaN imputation for s1_pac_sub22 complete.
Found 17 vertices with NaN values.
NaN imputation for s1_pac_sub24 complete.
Found 15 vertices with NaN values.
NaN imputation for s1_pac_sub26 complete.
Found 17 vertices with NaN values.
NaN imputation for s1_pac_sub29 complete.
Found 9 vertices with NaN values.
NaN imputation for s1_pac_sub32 complete.
Found 8

  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.032718782647744006
cluster_p_values = [0.995  1.     1.     1.     0.9981 0.4587 1.     0.1346 0.9999 0.9364
 1.     0.7163 1.     1.     0.5713 0.9999 1.     1.     0.972  1.
 1.     0.5104 1.     1.     1.     1.     0.8678 1.     1.     1.
 1.     0.9863 1.     0.9994 1.     0.9484 0.9985 1.     1.     0.9999
 1.     1.     0.9416 0.9535 1.     0.7749 1.     1.     0.9094 1.
 1.     1.    ]
BL_go vs BL_plan found 0 significant clusters
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complete.
Found 9 vertices with NaN values.
NaN imputation for s1_pac_sub22 complete.
Found 17 vertices with NaN values.
NaN imputation for s1_pac_sub24 complete.
Found 15 vertices with NaN values.
NaN imputation for 

  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = 0.0008521713092835002
cluster_p_values = [1.     0.9998 1.     1.     1.     0.9954 0.9996 0.9412 1.     1.
 1.     1.     0.7106 1.     1.     1.     1.     1.     1.     0.9216
 0.9762 1.     1.     1.     1.     1.     0.9995 1.     1.     1.    ]
MAIN_go_baseline vs MAIN_plan_baseline found 0 significant clusters
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complete.
Found 9 vertices with NaN values.
NaN imputation for s1_pac_sub22 complete.
Found 17 vertices with NaN values.
NaN imputation for s1_pac_sub24 complete.
Found 15 vertices with NaN values.
NaN imputation for s1_pac_sub26 complete.
Found 17 vertices with NaN values.
NaN imputation for s1_pac_sub29 complete.
Found 9 vertices with NaN v

  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.01554557780358233
cluster_p_values = [0.9999 1.     1.     1.     1.     1.     1.     1.     0.9341 0.9745
 0.9972 1.     1.     1.     0.9992 1.     0.9981 1.     1.     0.9998
 0.9343 1.     0.9998 0.9844]
MAIN_go_adaptation vs MAIN_plan_adaptation found 0 significant clusters
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complete.
Found 13 vertices wit

  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = -0.0035445533477220133
cluster_p_values = [1.     1.     1.     1.     1.     0.9998 0.9905 1.     1.     0.9896
 1.     1.     1.     1.     0.9765 1.     0.9695 1.     0.8977 0.9995
 1.     0.9989 1.    ]
MAIN_go_combined vs MAIN_plan_combined found 0 significant clusters
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complete.
Found 13 vertices with NaN val

  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = 0.005764545521538693
cluster_p_values = [1.     0.9977 1.     1.     1.     1.     1.     0.2038 1.     0.9935
 1.     0.879  1.     0.9616 0.9962 1.     0.9925 1.     0.9977]
MAIN_go_combined vs BL_go found 0 significant clusters
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Found 5 vertices with NaN values.
NaN imputation for s1_pac_sub01 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub07 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Found 3 vertices with NaN values.
NaN imputation for s1_pac_sub10 complete.
Blocks stacked shape: (2, 5124, 20, 20)
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complete.
Found 13 vertices with NaN values.
NaN imputation for s1_pac_sub11 complet

  0%|          | Permuting : 0/9999 [00:00<?,       ?it/s]

t_thresh = 2.818756060596369
T_obs_mean = 0.0008294374481221358
cluster_p_values = [1.     0.757  1.     1.     1.     0.9999 1.     1.     1.     1.
 0.9999 1.     1.     0.9999 1.     0.9991 1.     1.     0.8414 0.8809
 0.8138 0.9983 1.     1.     0.9983 1.     1.     1.    ]
MAIN_plan_combined vs BL_plan found 0 significant clusters
