#  Analysis code for mVEP BCI EEG sensor domain analysis 

### Written by Joshua Kosnoff except where otherwise noted.

## Load in Preprocessed Data

In [1]:
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
import mne
import numpy as np
import pandas as pd

# Load the preprocessed data from disk into a DataFrame.
# NOTE: unpickling can execute arbitrary code; only load pickle files from a trusted source.
my_df = pd.read_pickle("Analysis_Rerun.pkl")

In [2]:
# Report the number of unique subjects overall and for each condition.
# The label of the third line now matches the condition it filters on ("tFUS-GP").
print("Total N: ", len(np.unique(my_df.subject)))
print("tFUS-GC N: ", len(np.unique(my_df["subject"].loc[my_df.condition == "tFUS-GC"])))
print("tFUS-GP N: ", len(np.unique(my_df["subject"].loc[my_df.condition == "tFUS-GP"])))
print("Non-Modulated N: ", len(np.unique(my_df["subject"].loc[my_df.condition == "Non-Modulated"])))
print("Decoupled-Sham N: ", len(np.unique(my_df["subject"].loc[my_df.condition == "Decoupled-Sham"])))

Total N:  25
tFUS-GC N:  25
tFUS-CP N:  17
Non-Modulated N:  24
Decoupled-Sham N:  19


In [3]:
# Silence mne logs: from here on only messages at the 'CRITICAL' level are emitted
mne.set_log_level('CRITICAL')

def normalize_channel_nomenclature(data_epoch, montage = "easycap-M1"):
    """
    Attach a montage to an epochs object, fixing upper-case frontal-pole names if needed.

    The montage is first applied with case-sensitive matching. If that raises a
    ValueError, the channels "FP1", "FP2" and "FPz" (whichever are present) are
    renamed to "Fp1", "Fp2" and "Fpz" and the montage is applied again.

    Inputs:
        data_epoch: object exposing `set_montage`, `rename_channels` and `ch_names`
            (MNE Epochs in this notebook). It is modified in place.
        montage: the montage to pass to `set_montage`. Default is "easycap-M1".

    Returns:
        The same `data_epoch` object that was passed in.

    Raises:
        ValueError: if the montage cannot be applied and either none of the
            upper-case names are present or the retry also fails.
    """
    try:
        data_epoch.set_montage(montage, match_case=True)
    except ValueError:
        upper_to_montage_case = {"FP1": "Fp1", "FP2": "Fp2", "FPz": "Fpz"}

        # Only rename channels that actually exist in this recording: passing a
        # name that is absent to rename_channels would raise and hide the real problem.
        present = {
            old: new
            for old, new in upper_to_montage_case.items()
            if old in data_epoch.ch_names
        }
        if not present:
            # Nothing to rename, so the montage failure has another cause; surface it.
            raise

        data_epoch.rename_channels(present)
        data_epoch.set_montage(montage, match_case=True)

    return data_epoch

## Permutation Cluster Test

Run a permutation cluster test on the EEG data. The original code for the cluster test was part of an MNE Python 
tutorial found here: https://mne.tools/stable/auto_tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.html. It was modified and functionalized by Joshua Kosnoff.

Please note that this test has well-documented limitations. It identifies significant **clusters**, but that does not guarantee that every point within a cluster is itself significant, only that the cluster as a whole is. As such, caution should be used when interpreting the results so as not to overstate their significance.

In [4]:
#### Permutation cluster test #####
def permutation_cluster_test(epoch_list,
                             event_ids,
                             colors = "auto",
                             alpha = 0.05,
                             parametric=False,
                             ci = 0.95,
                             n_permutations=10000,
                             plot_combine_method = "mean",
                             time_frequency = False,
                             freqs = np.arange(3, 30, 3),
                             seed = 0,
                             test_thresh = 'auto',
                             test = 'f',
                             unit = "µV",
                             scientific_notation=False,
                             bins = 1,
                             yticks = 'auto',
                             tail = 1,
                             n_jobs = 1,
                             save_fig = "",
                            ):
    """
    Run a permutation cluster test on a list of MNE Epochs and return the result.

    Original code taken extensively from:
    - https://mne.tools/stable/auto_tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.html

    Functionalized and edited by Joshua Kosnoff

    Inputs:
        epoch_list: a list of MNE Epochs, one per condition to compare. Channel
            adjacency is taken from the first entry. In the time-domain branch the
            epochs are baseline-corrected IN PLACE (baseline = (None, 0)).
        event_ids: a list of names for each condition (currently unused).
        colors: optional; a dictionary of format {"Event_id": color code} (currently unused).
        alpha: the alpha value used to derive the automatic cluster-forming threshold.
        parametric: if True, call mne.stats.permutation_cluster_test; otherwise call
            mne.stats.spatio_temporal_cluster_test.
            NOTE(review): both calls receive n_permutations, so the flag name may be
            misleading -- confirm the intended distinction.
        ci: currently unused.
        n_permutations: the number of permutations to run.
        plot_combine_method: 'gfp', 'mean', 'median', or 'std' (currently unused).
            Default is 'mean'.
        time_frequency: boolean value for whether to consider the time-frequency domain
            (Morlet power per epoch) instead of the time domain.
        freqs: the frequency array to consider for time_frequency analysis.
        seed: the random seed to use in order to have repeatable findings.
        test_thresh: the cluster-forming threshold handed to MNE. With 'auto':
            time-frequency uses the fixed value 15; time domain uses the upper-tail
            critical value of the F (test='f') or t (test='t') distribution at `alpha`.
            'auto' cannot be resolved in the time domain when `test` is a callable.
            NOTE(review): the automatic t threshold always uses 1 - alpha,
            irrespective of `tail` -- confirm this is intended for two-tailed tests.
        test: the stat test to run: 'f', 't' (case-insensitive), or a callable that
            is passed to MNE as `stat_fun`.
        unit: currently unused.
        scientific_notation: currently unused.
        bins: how many timepoints to bin (average) together for time-frequency tests.
            Default is no averaging.
        yticks: currently unused.
        tail: passed through to the MNE cluster test.
        n_jobs: passed through to the MNE cluster test.
        save_fig: currently unused.

    Returns:
        The 4-tuple produced by the MNE cluster test: (F_obs, clusters, p_values, H0).
        The cluster p-values are also printed.

    Raises:
        ValueError: if `test` is an unrecognized string, or if test_thresh='auto' is
            requested in the time domain together with a callable `test`.
    """

    # Resolve the statistic. Strings select a built-in MNE statistic; anything
    # else is assumed to be a callable and handed to MNE unchanged.
    test_name = test.lower() if isinstance(test, str) else None
    if test_name == 'f':
        stat_func = mne.stats.f_oneway
    elif test_name == 't':
        stat_func = mne.stats.ttest_ind_no_p
    elif test_name is None:
        stat_func = test
    else:
        raise ValueError(
            "Unrecognized test %r: only 't', 'f' or a callable are supported" % (test,)
        )

    auto_thresh = isinstance(test_thresh, str) and test_thresh == 'auto'

    # Get channel adjacencies (the returned channel-name list is not needed here)
    adjacency, _ = mne.channels.find_ch_adjacency(epoch_list[0].info, ch_type='eeg')

    if time_frequency:

        # Accept any array-like for freqs (a plain list would break the division below)
        freqs = np.asarray(freqs)
        n_cycles = freqs / freqs[0]

        epochs_power = list()
        for condition in epoch_list:
            this_tfr = mne.time_frequency.tfr_morlet(
                condition,
                freqs,
                n_cycles=n_cycles,
                decim=1,
                average=False,
                return_itc=False,
            )

            if bins > 1:
                # decimate() supplies the container with the reduced time axis;
                # its samples are then replaced by the mean of each block of
                # `bins` consecutive original samples.
                decim_tfr = this_tfr.copy().decimate(bins)
                n_binned_times = decim_tfr._data.shape[-1]
                decim_tfr._data = np.stack(
                    [
                        np.mean(this_tfr._data[:, :, :, i * bins:(i + 1) * bins], axis=-1)
                        for i in range(n_binned_times)
                    ],
                    axis=-1,
                )
                this_tfr = decim_tfr

            this_tfr.apply_baseline(mode="mean", baseline=(None, 0))
            epochs_power.append(this_tfr.data)

        # transpose to (epochs, frequencies, times, channels)
        X = [np.transpose(x, (0, 2, 3, 1)) for x in epochs_power]

        # our data at each observation is of shape frequencies × times × channels
        adjacency = mne.stats.combine_adjacency(len(freqs), len(this_tfr.times), adjacency)

        # Note that a cluster forming statistic equal to the f distribution will spit
        # out huge clusters. As a result, we pick something larger, but also kind of arbitrary
        if auto_thresh:
            test_thresh = 15

    else:
        # Obtain the data as a 3D matrix and transpose it such that
        # the dimensions are as expected for the cluster permutation test:
        # n_epochs × n_times × n_channels
        # (apply_baseline modifies the caller's epochs in place)
        for epochs in epoch_list:
            epochs.apply_baseline(baseline=(None, 0))
        X = [np.transpose(epochs.get_data(), (0, 2, 1)) for epochs in epoch_list]

        if auto_thresh:
            if test_name is None:
                raise ValueError(
                    "test_thresh='auto' is only available for test='f' or test='t'; "
                    "pass an explicit test_thresh when test is a callable"
                )

            # scipy is only needed for the automatic threshold, so import it here
            from scipy import stats as sp_stats

            # Degrees of freedom: numerator = number of conditions - 1,
            # denominator = total number of observations - number of conditions.
            n_conditions = len(X)
            n_observations = sum(len(x) for x in X)
            dfn = n_conditions - 1
            dfd = n_observations - n_conditions

            # Note: we calculate 1 - alpha to get the critical value
            # on the right tail
            if test_name == 'f':
                test_thresh = sp_stats.f.ppf(1 - alpha, dfn=dfn, dfd=dfd)
            else:
                test_thresh = sp_stats.t.ppf(1 - alpha, df=dfd)

    # run the cluster based permutation analysis
    if parametric:
        cluster_stats = mne.stats.permutation_cluster_test(
            X,
            n_permutations=n_permutations,
            threshold=test_thresh,
            tail=tail,
            n_jobs=n_jobs,
            buffer_size=None,
            adjacency=adjacency,
            stat_fun = stat_func,
            seed=seed,
        )
    else:
        cluster_stats = mne.stats.spatio_temporal_cluster_test(
            X,
            n_permutations=n_permutations,
            threshold=test_thresh,
            tail=tail,
            n_jobs=n_jobs,
            buffer_size=None,
            stat_fun = stat_func,
            adjacency=adjacency,
            seed=seed,
        )

    F_obs, clusters, p_values, H0 = cluster_stats
    print(p_values)
    # NOTE: remember the caveats with respect to "significant" clusters when
    # selecting clusters with a p-value below alpha from the returned values.

    return F_obs, clusters, p_values, H0

In [5]:
# Start this cell from the pickled data again
my_df = pd.read_pickle("Analysis_Rerun.pkl")

# Step 1: give every epochs object the same channel naming / montage
my_df["data_epoch"] = my_df["data_epoch"].apply(normalize_channel_nomenclature)

# Step 2: gather the channel names of every epochs object
epoch_list = my_df["data_epoch"].to_list()
all_channels = [epochs.ch_names for epochs in epoch_list]

# Step 3: keep only the channels that every epochs object has
shared_channels = list(set.intersection(*(set(names) for names in all_channels)))

# Step 4: restrict every epochs object to those shared channels
my_df["data_epoch"] = my_df["data_epoch"].apply(
    lambda epochs: epochs.pick_channels(shared_channels)
)

# One list of per-subject epochs for each condition
tFUSs, baselines, shams, uscs = [], [], [], []

conditions = ["Non-Modulated", "tFUS-GC", "Decoupled-Sham", "tFUS-GP"]
lists = [baselines, tFUSs, shams, uscs]  # same order as `conditions`
n_subjects = 0

for subject in np.unique(my_df.subject):
    subj_df = my_df.loc[my_df.subject == subject]

    # Only subjects with exactly four distinct condition labels are included
    if len(np.unique(subj_df.condition)) != 4:
        continue

    n_subjects += 1
    for bucket, condition in zip(lists, conditions):
        condition_epochs = subj_df.loc[subj_df.condition == condition, "data_epoch"].to_list()
        bucket.append(mne.concatenate_epochs(condition_epochs, add_offset=True))


# Within each condition, equalize the epoch counts across the per-subject entries
for per_subject_epochs in (tFUSs, baselines, shams, uscs):
    mne.epochs.equalize_epoch_counts(per_subject_epochs, method='mintime')

# Collapse each condition's per-subject list into a single Epochs object
tFUSs = mne.concatenate_epochs(tFUSs)
baselines = mne.concatenate_epochs(baselines)
shams = mne.concatenate_epochs(shams)
uscs = mne.concatenate_epochs(uscs)

In [6]:
# Conditions to compare, trimmed to a common epoch count
epoch_list = [baselines, tFUSs, shams, uscs]
mne.epochs.equalize_epoch_counts(epoch_list, method='mintime')

# Short labels (same order as epoch_list) and the color assigned to each label
event_ids = ['NM', 'tFUS-GC', 'DS', 'tFUS-GP']
colors = {"tFUS-GC": "C1", "NM": "C0", 'DS': "C2", "tFUS-GP": 'C3'}


# Font sizes shared by the plotting settings below
MEDIUM_SIZE = 15
BIGGER_SIZE = 17

plt.style.use(['default'])
for group, settings in (
    ('figure', {'dpi': 300.0}),
    ('font', {'family': 'Arial', 'size': BIGGER_SIZE}),
    ('axes', {'titlesize': BIGGER_SIZE}),    # axes title
    ('axes', {'labelsize': 60}),             # x and y labels
    ('xtick', {'labelsize': MEDIUM_SIZE}),   # x tick labels
    ('ytick', {'labelsize': MEDIUM_SIZE}),   # y tick labels
    ('legend', {'fontsize': MEDIUM_SIZE}),   # legend
    ('figure', {'titlesize': BIGGER_SIZE}),  # figure title
):
    plt.rc(group, **settings)

new_rc_params = {
    'text.usetex': False,
    "svg.fonttype": 'none',
}
matplotlib.rcParams.update(new_rc_params)


# Settings for the repeated-measures ANOVA: one factor ("A") with one level per condition
pthresh = 0.05
return_pvals = False
factor_levels = [len(epoch_list)]
effects = "A"

def stat_fun(*args):
    """
    Repeated-measures ANOVA statistic for the cluster test.

    The per-condition arrays are stacked, the first two axes are swapped, and
    the result is passed to mne.stats.f_mway_rm using the module-level
    factor_levels, effects and return_pvals. The first element of what
    f_mway_rm returns is handed back.
    """
    stacked = np.asarray(args)
    observations_first = np.swapaxes(stacked, 0, 1)

    anova_result = mne.stats.f_mway_rm(
        observations_first,
        factor_levels=factor_levels,
        effects=effects,
        return_pvals=return_pvals,
    )
    return anova_result[0]

# Cluster-forming F threshold for the repeated-measures design at pthresh.
# NOTE(review): the threshold is computed from n_subjects, while epoch_list holds
# concatenated single-trial epochs -- confirm this is the intended number of observations.
f_thresh = mne.stats.f_threshold_mway_rm(n_subjects, factor_levels, effects, pthresh)

cluster_test_options = dict(
    colors=colors,
    alpha=pthresh,
    parametric=False,
    ci=0.68,  # 68% CI = 1 standard deviation
    n_permutations=1000,
    plot_combine_method="mean",
    time_frequency=False,
    test_thresh=f_thresh,
    # freqs = np.arange(4, 40, 1), # allll the frequencies
    test=stat_fun,
    scientific_notation=False,
    n_jobs=-1,
    save_fig="permutation_cluster_test.svg",
)

# Left as a bare expression so the notebook displays the returned tuple
permutation_cluster_test(epoch_list, event_ids, **cluster_test_options)

[0.788 1.    1.    0.682 1.    1.    1.    0.122 0.022 0.995 1.    0.798
 1.    1.    1.    0.996 0.997 1.    1.    1.    0.994 1.    1.    1.
 1.    1.    0.543 1.    0.187 1.    1.    0.666 1.    0.998 1.    1.
 0.016 1.    0.18  1.    1.    1.    1.    1.    0.591 0.114 0.151 1.
 1.    1.    0.959 1.    0.879 0.897 0.109 0.993 0.998 1.    0.999 1.   ]


(array([[0.87748372, 0.42521734, 1.65509451, ..., 2.48942255, 1.16476077,
         0.26861626],
        [1.15965171, 0.23422854, 3.0097074 , ..., 1.44591533, 0.33762425,
         0.65644142],
        [2.58505437, 2.14768988, 3.11183911, ..., 5.27429506, 4.43494986,
         0.04865458],
        ...,
        [0.40809228, 1.45820247, 2.04404695, ..., 2.41216308, 0.63144951,
         0.58947512],
        [0.40784303, 0.45820385, 1.92093082, ..., 1.8460844 , 0.25955528,
         0.66777014],
        [1.34306046, 0.40099067, 4.63518897, ..., 2.6156346 , 1.27722419,
         0.79666245]]),
 [(array([0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3]),
   array([ 4,  2, 20, 31, 39,  4, 33,  2, 20, 31, 39, 45, 59, 20])),
  (array([0]), array([9])),
  (array([0]), array([50])),
  (array([0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3]),
   array([57, 14, 49, 57,  8, 19, 43, 57,  9, 14, 49, 58, 28, 19, 28])),
  (array([1]), array([54])),
  (array([2]), array([38])),
  (array([4]), array([25])),
  (array([ 4

In [7]:
# Display the session information (via the session_info package) for reproducibility
import session_info
session_info.show()