In [43]:
import sys
import os

print(sys.path)
sys.path.append("C:/Users/jz421/Desktop/GlobalLocal/IEEG_Pipelines/") #need to do this cuz otherwise ieeg isn't added to path...

# Get the absolute path to the directory containing the current script
# For GlobalLocal/src/analysis/preproc/make_epoched_data.py, this is GlobalLocal/src/analysis/preproc
try:
    # This will work if running as a .py script
    current_file_path = os.path.abspath(__file__)
    current_script_dir = os.path.dirname(current_file_path)
except NameError:
    # This will be executed if __file__ is not defined (e.g., in a Jupyter Notebook)
    # os.getcwd() often gives the directory of the notebook,
    # or the directory from which the Jupyter server was started.
    current_script_dir = os.getcwd()

# Navigate up three levels to get to the 'GlobalLocal' directory
project_root = os.path.abspath(os.path.join(current_script_dir, '..'))

# Add the 'GlobalLocal' directory to sys.path if it's not already there
if project_root not in sys.path:
    sys.path.insert(0, project_root) # insert at the beginning to prioritize it

from ieeg.navigate import channel_outlier_marker, trial_ieeg, crop_empty_data, outliers_to_nan
from ieeg.io import raw_from_layout, get_data
from ieeg.timefreq.utils import crop_pad
from ieeg.timefreq import gamma
from ieeg.calc.scaling import rescale

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from bids import BIDSLayout
import mne
import numpy as np
import copy
import seaborn as sns
import pandas as pd
from scipy import stats
import math

import pickle
from src.analysis.utils.general_utils import get_good_data
from src.analysis.preproc.make_epoched_data import trial_ieeg_rand_offset


['c:\\Users\\jz421\\Desktop\\GlobalLocal', 'c:\\Users\\jz421', 'C:\\Users\\jz421\\Desktop\\GlobalLocal\\IEEG_Pipelines', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg\\python311.zip', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg\\DLLs', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg\\Lib', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg', '', 'C:\\Users\\jz421\\AppData\\Roaming\\Python\\Python311\\site-packages', 'C:\\Users\\jz421\\AppData\\Roaming\\Python\\Python311\\site-packages\\win32', 'C:\\Users\\jz421\\AppData\\Roaming\\Python\\Python311\\site-packages\\win32\\lib', 'C:\\Users\\jz421\\AppData\\Roaming\\Python\\Python311\\site-packages\\Pythonwin', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg\\Lib\\site-packages', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg\\Lib\\site-packages\\win32', 'c:\\Users\\jz421\\AppData\\Local\\anaconda3\\envs\\ieeg\\Lib\\site-packages\\win32\\lib', 'c:\\Users\\jz421\\AppData\\Local\\anac

In [29]:
if LAB_root is None:
    HOME = os.path.expanduser("~")
    if os.name == 'nt': # windows
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
    else: # mac
        LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box", "CoganLab")

layout = get_data(task, root=LAB_root)

### actually make the baselines so we can compare them later

In [22]:
def get_baseline(inst: mne.io.BaseRaw, baseline_event: str, base_times: tuple[float, float], pad_length: float, base_times_length: float, outliers: int, passband: tuple[float, float], dec_factor: int, rand_offset: bool):
    inst = inst.copy()
    inst.load_data()
    ch_type = inst.get_channel_types(only_data_chs=True)[0]
    inst.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    if rand_offset:
        trials = trial_ieeg_rand_offset(inst, baseline_event, base_times, base_times_length, pad_length, preload=True)
    else:
        adjusted_base_times = [base_times[0] - pad_length, base_times[1] + pad_length]
        trials = trial_ieeg(inst, baseline_event, adjusted_base_times, preload=True)

    outliers_to_nan(trials, outliers=outliers)
    HG_base = gamma.extract(trials, passband=passband, copy=False, n_jobs=1)
    crop_pad(HG_base, f"{pad_length}s")
    HG_base.decimate(dec_factor)

    del inst
    return HG_base

In [26]:
def make_baselines(subjects, task='GlobalLocal', baseline_events = {"experimentStart": (1, 101), "Stimulus": (-0.5, 0)}, LAB_root=None, channels=None, pad_length = 0.5, base_times_length = 0.5, outliers = 10, passband = (70,150), dec_factor = 8, rand_offset = False):
    '''
    Creates baselines for a list of subjects and saves them to a specified directory.
    
    Parameters:
    subjects: list of subjects to create baselines for
    task: task to create baselines for
    baseline_events: dictionary of baseline events and their corresponding times. If contains "randoffset", then random offset is applied to the baseline times.
    LAB_root: root directory of the lab data
    channels: channels to use for baseline creation
    pad_length: length of padding to use for baseline creation
    base_times_length: length of baseline times to use for baseline creation
    outliers: number of outliers to use for baseline creation
    passband: passband to use for baseline creation
    dec_factor: decimation factor to use for baseline creation
    '''
    if LAB_root is None:
        HOME = os.path.expanduser("~")
        if os.name == 'nt': # windows
            LAB_root = os.path.join(HOME, "Box", "CoganLab")
        else: # mac
            LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box", "CoganLab")

    layout = get_data(task, root=LAB_root)

    for sub in subjects:
        good = get_good_data(sub, layout)

        for baseline_event, base_times in baseline_events.items():
            HG_base = get_baseline(good, baseline_event, base_times, pad_length, base_times_length, outliers, passband, dec_factor, rand_offset)

            save_dir = os.path.join(layout.root, 'derivatives', 'tests', 'baselineTesting', baseline_event, sub)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            if rand_offset:
                output_name_base = f"{base_times_length}sec_within{base_times[0]}-{base_times[1]}sec_randoffset_{baseline_event}Base_decFactor_{dec_factor}_outliers_{outliers}_passband_{passband[0]}-{passband[1]}_padLength_{pad_length}s"
            else:
                output_name_base = f"{base_times_length}sec_within{base_times[0]}-{base_times[1]}sec_{baseline_event}Base_decFactor_{dec_factor}_outliers_{outliers}_passband_{passband[0]}-{passband[1]}_padLength_{pad_length}s"
            
            HG_base.save(os.path.join(save_dir, f"{sub}_{output_name_base}_base-epo.fif"), overwrite=True)

In [38]:
# subjects = ['D0057','D0059', 'D0063', 'D0065', 'D0069', 'D0071', 'D0077', 'D0090', 'D0094', 'D0100', 'D0102', 'D0103', 'D0107A', 'D0110', 'D0116', 'D0117', 'D0121']
subjects = ['D0057']
task = 'GlobalLocal'
baseline_events = {"experimentStart": (1, 101), "Stimulus": (-0.5, 0), "Stimulus": (-1, 0)}
LAB_root = None
channels = None
pad_length = 0.5
base_times_length = 0.5
outliers = 10
passband = (70,150)
dec_factor = 8
rand_offset = True

In [34]:
make_baselines(subjects, task, baseline_events, LAB_root, channels, pad_length, base_times_length, outliers, passband, dec_factor, rand_offset)

Extracting EDF parameters from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-01_desc-clean_ieeg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading events from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-01_desc-clean_events.tsv.
Reading channel info from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-01_desc-clean_channels.tsv.
Reading electrode coords from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_acq-01_space-ACPC_electrodes.tsv.
Extracting EDF parameters from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-02_desc-clean_ieeg.edf...
EDF file detected
Setting channel info str

  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)


Reading events from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-02_desc-clean_events.tsv.
Reading channel info from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-02_desc-clean_channels.tsv.
Reading electrode coords from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_acq-01_space-ACPC_electrodes.tsv.
Extracting EDF parameters from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-03_desc-clean_ieeg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)


Reading events from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-03_desc-clean_events.tsv.
Reading channel info from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-03_desc-clean_channels.tsv.
Reading electrode coords from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_acq-01_space-ACPC_electrodes.tsv.
Extracting EDF parameters from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-04_desc-clean_ieeg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)


Reading events from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-04_desc-clean_events.tsv.
Reading channel info from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_task-GlobalLocal_acq-01_run-04_desc-clean_channels.tsv.
Reading electrode coords from C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\clean\sub-D0057\ieeg\sub-D0057_acq-01_space-ACPC_electrodes.tsv.


  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)
  new_raw = read_raw_bids(bids_path=BIDS_path, verbose=verbose)


<RawEDF | sub-D0057_task-GlobalLocal_acq-01_run-01_desc-clean_ieeg.edf, 178 x 8243200 (4025.0 s), ~201 kB, data not loaded>
outlier round 1 channels: ['RAMT8']
outlier round 2 channels: ['RAMT8', 'RPI16']
outlier round 2 channels: ['RAMT8', 'RPI16', 'LAMT12']
Reading 0 ... 6198083  =      0.000 ...  3026.408 secs...
Applying average reference.
Applying a custom ('sEEG',) reference.
Applying average reference.
Applying a custom ('sEEG',) reference.
Used Annotations descriptions: ['1', '2', 'Response/c25.0/n25.0/BigLetters/SmallLetters/Taskl/TargetLetters/Responded1.0/ParticipantResponse115.0/CorrectResponse115.0/TrialCount225.0/BlockTrialCount1.0/ReactionTime2149.0/Accuracy1.0/D57', 'Response/c25.0/r25.0/BigLetterh/SmallLetterh/Taskg/TargetLetterh/Responded1.0/ParticipantResponse114.0/CorrectResponse114.0/TrialCount227.0/BlockTrialCount3.0/ReactionTime883.0/Accuracy1.0/D57', 'Response/c25.0/r25.0/BigLetterh/SmallLetterh/Taskg/TargetLetterh/Responded1.0/ParticipantResponse114.0/CorrectRe

100%|██████████| 1/1 [00:00<00:00,  3.45it/s]
  HG_base.decimate(dec_factor)


Applying average reference.
Applying a custom ('sEEG',) reference.
Used Annotations descriptions: ['1', '2', 'Response/c25.0/n25.0/BigLetters/SmallLetters/Taskl/TargetLetters/Responded1.0/ParticipantResponse115.0/CorrectResponse115.0/TrialCount225.0/BlockTrialCount1.0/ReactionTime2149.0/Accuracy1.0/D57', 'Response/c25.0/r25.0/BigLetterh/SmallLetterh/Taskg/TargetLetterh/Responded1.0/ParticipantResponse114.0/CorrectResponse114.0/TrialCount227.0/BlockTrialCount3.0/ReactionTime883.0/Accuracy1.0/D57', 'Response/c25.0/r25.0/BigLetterh/SmallLetterh/Taskg/TargetLetterh/Responded1.0/ParticipantResponse114.0/CorrectResponse114.0/TrialCount228.0/BlockTrialCount4.0/ReactionTime766.0/Accuracy1.0/D57', 'Response/c25.0/r25.0/BigLetterh/SmallLetterh/Taskg/TargetLetterh/Responded1.0/ParticipantResponse114.0/CorrectResponse114.0/TrialCount241.0/BlockTrialCount17.0/ReactionTime499.0/Accuracy1.0/D57', 'Response/c25.0/r25.0/BigLetterh/SmallLetterh/Taskg/TargetLetterh/Responded1.0/ParticipantResponse114.0/C

100%|██████████| 449/449 [02:15<00:00,  3.32it/s]
  HG_base.decimate(dec_factor)


load in some baselines for testing, and also load in the signal to compare them to

In [39]:
loaded_baselines = {}
loaded_signals = {}

for sub in subjects:
    signal_path = os.path.join(layout.root, 'derivatives', 'freqFilt', 'figs', sub, f'{sub}_Stimulus_0.5sec_within1sec_randoffset_preStimulusBase_decFactor_8_outliers_10_passband_70.0-150.0_padLength_0.5s_stat_func_ttest_ind_equal_var_False_HG_ev1-epo.fif')
    loaded_signals[sub] = mne.read_epochs(signal_path)
    loaded_baselines[sub] = {}

    for baseline_event, base_times in baseline_events.items():
        save_dir = os.path.join(layout.root, 'derivatives', 'tests', 'baselineTesting', baseline_event, sub)

        if rand_offset:
            output_name_base = f"{base_times_length}sec_within{base_times[0]}-{base_times[1]}sec_randoffset_{baseline_event}Base_decFactor_{dec_factor}_outliers_{outliers}_passband_{passband[0]}-{passband[1]}_padLength_{pad_length}s"
        else:
            output_name_base = f"{base_times_length}sec_within{base_times[0]}-{base_times[1]}sec_{baseline_event}Base_decFactor_{dec_factor}_outliers_{outliers}_passband_{passband[0]}-{passband[1]}_padLength_{pad_length}s"
        
        HG_base = mne.read_epochs(os.path.join(save_dir, f"{sub}_{output_name_base}_base-epo.fif"))
        loaded_baselines[sub][baseline_event] = HG_base

Reading C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\freqFilt\figs\D0057\D0057_Stimulus_0.5sec_within1sec_randoffset_preStimulusBase_decFactor_8_outliers_10_passband_70.0-150.0_padLength_0.5s_stat_func_ttest_ind_equal_var_False_HG_ev1-epo.fif ...


    Found the data of interest:
        t =   -1000.00 ...    1500.00 ms
        0 CTF compensation matrices available
Not setting metadata
449 matching events found
No baseline correction applied
0 projection items activated
Reading C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\tests\baselineTesting\experimentStart\D0057\D0057_0.5sec_within1-101sec_randoffset_experimentStartBase_decFactor_8_outliers_10_passband_70-150_padLength_0.5s_base-epo.fif ...
    Found the data of interest:
        t =    1000.00 ...    1500.00 ms
        0 CTF compensation matrices available
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Reading C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\tests\baselineTesting\Stimulus\D0057\D0057_0.5sec_within-1-0sec_randoffset_StimulusBase_decFactor_8_outliers_10_passband_70-150_padLength_0.5s_base-epo.fif ...
    Found the data of interest:
        t =   -1000.00 ...    -500.

now let's play around with the loaded data

signal vs baseline

In [None]:
# --- Main Loop ---
# Loop through each subject
for sub in subjects:
    signal = loaded_signals[sub]
    print(f"--- Processing Subject: {sub} ---")

    # Loop through each baseline condition you've loaded
    for baseline_event, base_times in baseline_events.items():
        base = loaded_baselines[sub][baseline_event]
        
        # Define save directory for the output figures
        save_dir = os.path.join(layout.root, 'derivatives', 'tests', 'baselineTesting', baseline_event, sub, 'figs')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Find channels that are common to both signal and baseline epochs
        common_channels = sorted(list(set(signal.ch_names) & set(base.ch_names)))
        
        print(f"\n-- Comparing with baseline: {baseline_event} {base_times} --")
        print(f"Found {len(common_channels)} common channels. Generating plots...")

        # --- Grid Plotting Setup ---
        n_channels = len(common_channels)
        if n_channels == 0:
            print("No common channels found, skipping.")
            continue
            
        rows = 6
        cols = 10
        plots_per_fig = rows * cols
        n_figs = math.ceil(n_channels / plots_per_fig)

        # Create a figure for each block of 60 channels
        for fig_num in range(n_figs):
            # Create the subplot grid for the current figure
            fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4))
            
            # Get the slice of channels for this figure
            start_idx = fig_num * plots_per_fig
            end_idx = min(start_idx + plots_per_fig, n_channels)
            channels_for_fig = common_channels[start_idx:end_idx]

            # --- Channel-level Processing and Plotting ---
            for i, ch_name in enumerate(channels_for_fig):
                
                # Determine the current subplot axis
                ax = axes.flat[i]

                # --- 1. Extract and Prepare Data ---
                # Get all data points for the current channel across all trials and time
                signal_data = signal.get_data(picks=ch_name).flatten()
                base_data = base.get_data(picks=ch_name).flatten()
                
                # Remove any potential NaN values from outlier rejection
                signal_data = signal_data[~np.isnan(signal_data)]
                base_data = base_data[~np.isnan(base_data)]
                
                # Continue to next channel if no data is present after cleaning
                if signal_data.size == 0 or base_data.size == 0:
                    ax.set_title(f'{ch_name}\n(No data)', fontsize=8)
                    ax.axis('off')
                    continue

                # # --- 2. Perform Statistical Tests ---
                # # Welch's t-test, which does not assume equal variance
                # t_stat, p_val = stats.ttest_ind(signal_data, base_data, equal_var=False, nan_policy='omit')
                
                # # Simple mean difference
                # mean_diff = np.mean(signal_data) - np.mean(base_data)
                
                # --- 3. Visualize the Distributions ---
                # Create a combined DataFrame for easier plotting with seaborn
                df_signal = pd.DataFrame({'Activity': signal_data, 'Type': 'Signal'})
                df_base = pd.DataFrame({'Activity': base_data, 'Type': f'Baseline'})
                df_combined = pd.concat([df_signal, df_base])
                
                # Create the KDE plot on the specific subplot axis
                sns.kdeplot(data=df_combined, x='Activity', hue='Type', fill=True, common_norm=False, ax=ax, legend=i==0)
                
                # --- 4. Annotate and Finalize Subplot ---
                ax.set_title(f'Channel: {ch_name}', fontsize=10)
                ax.set_xlabel('Time-Averaged Activity', fontsize=8)
                ax.set_ylabel('Density', fontsize=8)
                ax.tick_params(axis='x', labelsize=7)
                ax.tick_params(axis='y', labelsize=7)

                # # Add statistical results to the plot for quick reference
                # stats_text = (f"t={t_stat:.2f}, p={p_val:.3f}\n"
                #               f"Δ={mean_diff:.2f}")
                # ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, fontsize=8,
                #         verticalalignment='top', bbox=dict(boxstyle='round,pad=0.3', fc='wheat', alpha=0.6))

            # --- Figure-level Cleanup and Saving ---
            # Turn off unused subplots
            for i in range(n_channels, plots_per_fig):
                axes.flat[i].axis('off')

            fig.suptitle(f'Trial Distribution Comparison for Subject: {sub}\nBaseline: {baseline_event} {base_times} (Page {fig_num + 1}/{n_figs})', fontsize=20)
            fig.tight_layout(rect=[0, 0.03, 1, 0.96]) # Adjust for suptitle
            
            # Save the complete figure
            save_name = f"{sub}_{baseline_event}_time_averaged_signal_vs_baseline_page_{fig_num + 1}.png"
            save_path = os.path.join(save_dir, save_name)
            plt.savefig(save_path, dpi=300)
            print(f"Saved figure to: {save_path}")
            plt.close(fig) # Close the figure to free up memory

--- Processing Subject: D0057 ---

-- Comparing with baseline: experimentStart (1, 101) --
Found 175 common channels. Generating plots...
Saved figure to: C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\tests\baselineTesting\experimentStart\D0057\D0057_experimentStart_signal_vs_baseline_page_1.png


KeyboardInterrupt: 

all baselines vs signal

In [None]:
# --- Main Loop ---
# Loop through each subject
for sub in subjects:
    signal = loaded_signals[sub]
    print(f"--- Processing Subject: {sub} ---")

    # Define a single save directory for the multi-baseline comparison plots
    save_dir = os.path.join(layout.root, 'derivatives', 'tests', 'baselineTesting', sub, 'figs')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # First, find the set of channels that are common to the signal AND ALL baseline files
    # This ensures every plot on the grid is valid for all distributions.
    common_channels = set(signal.ch_names)
    for baseline_event in baseline_events.keys():
        base = loaded_baselines[sub][baseline_event]
        common_channels.intersection_update(base.ch_names)
    
    # Sort the channels for consistent ordering
    common_channels = sorted(list(common_channels))
    
    print(f"Found {len(common_channels)} channels common to signal and all baselines. Generating plots...")

    # --- Grid Plotting Setup ---
    n_channels = len(common_channels)
    if n_channels == 0:
        print("No common channels found across all conditions, skipping subject.")
        continue
        
    rows = 6
    cols = 10
    plots_per_fig = rows * cols
    n_figs = math.ceil(n_channels / plots_per_fig)

    # --- Figure Generation Loop ---
    # Create a new figure for each block of 60 channels
    for fig_num in range(n_figs):
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4.5))
        
        # Get the slice of channels for the current figure
        start_idx = fig_num * plots_per_fig
        end_idx = min(start_idx + plots_per_fig, n_channels)
        channels_for_fig = common_channels[start_idx:end_idx]

        # --- Channel-level Processing and Plotting ---
        for i, ch_name in enumerate(channels_for_fig):
            ax = axes.flat[i]

            # --- 1. Aggregate Data from Signal and ALL Baselines ---
            dfs_to_concat = []

            # a) Get signal data
            signal_data = signal.get_data(picks=ch_name).flatten()
            signal_data = signal_data[~np.isnan(signal_data)]
            if signal_data.size > 0:
                dfs_to_concat.append(pd.DataFrame({'Activity': signal_data, 'Type': 'Signal'}))

            # b) Get data for each baseline
            for baseline_event, base_times in baseline_events.items():
                base = loaded_baselines[sub][baseline_event]
                base_data = base.get_data(picks=ch_name).flatten()
                base_data = base_data[~np.isnan(base_data)]
                
                # Create a concise label for the legend
                baseline_label = f'Base: {baseline_event} {base_times}'
                if base_data.size > 0:
                    dfs_to_concat.append(pd.DataFrame({'Activity': base_data, 'Type': baseline_label}))

            # Continue to next channel if no data was aggregated
            if not dfs_to_concat:
                ax.set_title(f'{ch_name}\n(No data)', fontsize=8)
                ax.axis('off')
                continue

            # c) Combine into a single DataFrame for plotting
            df_combined = pd.concat(dfs_to_concat)

            # --- 2. Visualize the Distributions ---
            # Create a color palette to distinguish signal from baselines
            palette = ["#E41A1C"] + sns.color_palette("viridis_r", len(baseline_events))

            # Create the KDE plot on the specific subplot axis
            # The legend is only created for the very first plot to avoid clutter
            sns.kdeplot(data=df_combined, x='Activity', hue='Type', palette=palette, 
                        fill=True, common_norm=False, ax=ax, legend=(i == 0))
            
            # --- 3. Annotate and Finalize Subplot ---
            ax.set_title(f'Channel: {ch_name}', fontsize=10)
            ax.set_xlabel('High Gamma Activity', fontsize=8)
            ax.set_ylabel('Density', fontsize=8)
            ax.tick_params(axis='x', labelsize=7, rotation=30)
            ax.tick_params(axis='y', labelsize=7)
            # You can uncomment the statistics section here if you decide to add it back

        # --- Figure-level Cleanup and Saving ---
        # Turn off any unused subplots in the grid
        for i in range(len(channels_for_fig), plots_per_fig):
            axes.flat[i].axis('off')

        fig.suptitle(f'Signal vs. All Baselines for Subject: {sub} (Page {fig_num + 1}/{n_figs})', fontsize=20)
        fig.tight_layout(rect=[0, 0.03, 1, 0.96]) # Adjust layout to make room for suptitle
        
        # Save the complete figure
        save_name = f"{sub}_multi_baseline_vs_signal_comparison_page_{fig_num + 1}.png"
        save_path = os.path.join(save_dir, save_name)
        plt.savefig(save_path, dpi=300)
        print(f"Saved figure to: {save_path}")
        plt.close(fig) # Close the figure to free up memory

### average across trials and timepoints but not channels, plot avg mean and stdev for each channel

In [20]:
data = HG_base_stimulus.get_data()

mean = np.mean(data, axis=(0,2))
stdev = np.std(data,axis=(0,2))

print("Mean: ", mean)
print("Standard Deviation: ", stdev)

# Extract the channel names from the Epochs object
channel_names = HG_base_stimulus.ch_names

# Create a bar plot
plt.figure(figsize=(10, 5))
plt.bar(channel_names, mean, yerr=stdev, align='center', alpha=0.75, ecolor='black', capsize=5)

# Customize the plot
plt.ylabel('Mean Amplitude')
plt.title('Mean and Standard Deviation of Channels')
plt.xticks(rotation=90)  # Rotate the channel names for better readability

# Display the plot
plt.tight_layout()
plt.show()

Mean:  [2.35596596e-05            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan            nan            nan            nan
 7.60397970e-06 7.19475317e-06 6.55293008e-06 6.04449749e-06
            nan            nan            nan            nan
            nan            nan            nan            nan
            nan  

KeyboardInterrupt: 

### average across trials, channels, and time points

In [21]:
data = HG_base.get_data()

mean = np.nanmean(data, axis=(0,1,2))
stdev = np.nanstd(data,axis=(0,1,2))

print("Mean of experiment start baseline: ", mean)
print("Standard Deviation of experiment start baseline: ", stdev)

Mean of experiment start baseline:  7.11109295788856e-06
Standard Deviation of experiment start baseline:  2.467975522575185e-06


In [22]:
data = HG_base_stimulus.get_data()

mean = np.nanmean(data, axis=(0,1,2))
stdev = np.nanstd(data,axis=(0,1,2))

print("Mean of fixation cross baseline: ", mean)
print("Standard Deviation of fixation cross baseline: ", stdev)

Mean of fixation cross baseline:  6.0587577502376924e-06
Standard Deviation of fixation cross baseline:  4.567779012040867e-06


### average across trials and channels but not timepoints

In [23]:
data = HG_base.get_data()

mean = np.nanmean(data, axis=(0,1))
stdev = np.nanstd(data,axis=(0,1))

print("Mean: ", mean)
print("Standard Deviation: ", stdev)

# Extract the channel names from the Epochs object
channel_names = HG_base.ch_names


# Now, create a plot
plt.figure(figsize=(10, 5))
plt.plot(HG_base.times, mean, label='Mean')
plt.fill_between(HG_base.times, mean - stdev, mean + stdev, alpha=0.2, label='STD')

# Customize the plot
plt.ylabel('Mean Amplitude')
plt.xlabel('Time (s)')
plt.title('Mean and Standard Deviation of timepoints across trials and channels')
plt.legend(loc='upper right')

# Display the plot
plt.tight_layout()
plt.show()

Mean:  [6.21817376e-06 6.23033093e-06 6.24244690e-06 ... 7.53210802e-06
 7.54389375e-06 7.55552610e-06]
Standard Deviation:  [1.16503499e-06 1.17829252e-06 1.19179487e-06 ... 2.64761902e-06
 2.64209893e-06 2.63672973e-06]


KeyboardInterrupt: 

In [24]:
data = HG_base_stimulus.get_data()

mean = np.nanmean(data, axis=(0,1))
stdev = np.nanstd(data,axis=(0,1))

print("Mean: ", mean)
print("Standard Deviation: ", stdev)

# Extract the channel names from the Epochs object
channel_names = HG_base_stimulus.ch_names


# Now, create a plot
plt.figure(figsize=(10, 5))
plt.plot(HG_base_stimulus.times, mean, label='Mean')
plt.fill_between(HG_base_stimulus.times, mean - stdev, mean + stdev, alpha=0.2, label='STD')

# Customize the plot
plt.ylabel('Mean Amplitude')
plt.xlabel('Time (s)')
plt.title('Mean and Standard Deviation of timepoints across trials and channels')
plt.legend(loc='upper right')

# Display the plot
plt.tight_layout()
plt.show()

Mean:  [5.96941022e-06 5.96859328e-06 5.96779262e-06 ... 6.24856718e-06
 6.24944321e-06 6.25032414e-06]
Standard Deviation:  [4.04477745e-06 4.03667350e-06 4.02874892e-06 ... 4.98159151e-06
 4.97979517e-06 4.97793714e-06]


KeyboardInterrupt: 