# Autoreject and Spike Analysis

Here, we apply autoreject to the final ICA-preprocessed data (loading the previously computed reject logs) and then perform spike analysis.

In [1]:
# Auto-format code cells with black in JupyterLab (the `lab_black` extension comes
# from the nb_black package). Comment this line out if you don't have nb_black installed.
%load_ext lab_black

In [184]:
import numpy as np
import scipy
import scipy.io
import pandas as pd
from pathlib import Path
import os
import collections
from natsort import natsorted
import json
import pickle
import warnings
import sys
from copy import copy, deepcopy

warnings.filterwarnings("ignore")

import mne
from mne import make_fixed_length_epochs

mne.set_log_level("ERROR")
from mne_bids import BIDSPath, get_entities_from_fname, get_entity_vals, read_raw_bids
import autoreject
from autoreject import AutoReject, read_auto_reject

from sklearn.preprocessing import OrdinalEncoder, LabelBinarizer, LabelEncoder
from sklearn.multiclass import OneVsRestClassifier
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    brier_score_loss,
    roc_curve,
    average_precision_score,
    roc_auc_score,
    f1_score,
    recall_score,
    jaccard_score,
    balanced_accuracy_score,
    accuracy_score,
    auc,
    precision_score,
    plot_precision_recall_curve,
    average_precision_score,
    precision_recall_curve,
    confusion_matrix,
    cohen_kappa_score,
    make_scorer,
    precision_recall_fscore_support,
)
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (
    StratifiedGroupKFold,
    cross_validate,
    StratifiedShuffleSplit,
    LeaveOneGroupOut,
)
from sklearn.utils import resample
import sklearn
from sklearn import preprocessing
from sklearn.pipeline import make_pipeline

import mne
from mne.time_frequency import read_tfrs

mne.set_log_level("ERROR")
from mne_bids import BIDSPath, get_entities_from_fname, get_entity_vals, read_raw_bids

from eztrack.io import read_derivative_npy

sys.path.append("../../")
from episcalp.features import spike_feature_vector
from episcalp.io.read import load_persyst_spikes, load_reject_log
from episcalp.utils.utils import NumpyEncoder

# if you installed sporf via README
from oblique_forests.sporf import ObliqueForestClassifier

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import re

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define Data Directories

In [23]:
# Parent folder that holds the per-site data folders. Set the EPISCALP_DATA_ROOT
# environment variable to point elsewhere; it defaults to ~/Johns Hopkins instead
# of a hardcoded user-specific absolute path.
data_root = Path(os.environ.get("EPISCALP_DATA_ROOT", Path.home() / "Johns Hopkins"))

# BIDS roots for each recording site
jhroot = data_root / "Scalp EEG JHH - Documents" / "bids"
jeffroot = data_root / "Jefferson_Scalp - Documents" / "root"

# not ready yet
upmcroot = data_root / "UPMC_Scalp - Documents"

In [170]:
# Both sites keep their ICA-cleaned derivatives under the same relative sub-folder.
ica_deriv_subdir = Path("derivatives", "ICA", "1-30Hz-30", "win-20")
jh_ica_root = jhroot / ica_deriv_subdir
jeff_ica_root = jeffroot / ica_deriv_subdir

In [185]:
# Read the Persyst spike annotations for every subject at each site
# (JHH first, then Jefferson).
jh_spike_dataset, jeff_spike_dataset = (
    load_persyst_spikes(site_root) for site_root in (jhroot, jeffroot)
)

Loading data for subjects: ['jhh001', 'jhh002', 'jhh003', 'jhh004', 'jhh005', 'jhh006', 'jhh007', 'jhh008', 'jhh009', 'jhh010', 'jhh011', 'jhh012', 'jhh013', 'jhh014', 'jhh015', 'jhh016', 'jhh017', 'jhh018', 'jhh019', 'jhh020', 'jhh021', 'jhh022', 'jhh023', 'jhh024', 'jhh025', 'jhh026', 'jhh027', 'jhh028', 'jhh029', 'jhh030', 'jhh101', 'jhh102', 'jhh103', 'jhh104', 'jhh105', 'jhh106', 'jhh107', 'jhh108', 'jhh109', 'jhh110', 'jhh111', 'jhh112', 'jhh113', 'jhh114', 'jhh115', 'jhh116', 'jhh117', 'jhh118', 'jhh119', 'jhh120', 'jhh121', 'jhh122', 'jhh124', 'jhh125', 'jhh126', 'jhh127', 'jhh128', 'jhh201', 'jhh202', 'jhh203', 'jhh204', 'jhh205', 'jhh206', 'jhh207', 'jhh208', 'jhh209', 'jhh210', 'jhh211', 'jhh212', 'jhh213', 'jhh214', 'jhh215', 'jhh216', 'jhh217', 'jhh218', 'jhh219', 'jhh220', 'jhh221', 'jhh222', 'jhh223', 'jhh224', 'jhh225', 'jhh226', 'jhh227', 'jhh228', 'jhh229']
Loading data for subjects: ['jeff001', 'jeff002', 'jeff101', 'jeff102', 'jeff201', 'jeff202', 'jeff203', 'jeff20

In [150]:
# Peek at one example subject's spike table. Index 5 is an arbitrary choice;
# `.head()` keeps the output short for subjects with many spikes.
example_idx = 5
print(jh_spike_dataset["subject"][example_idx])
display(jh_spike_dataset["data"][example_idx].head())

Unnamed: 0,onset,duration,description,sample,ch_name,perception,height,n_secs
0,274.0,0.155,Spike O2-Av12 perception:0.61 height:92,54800,O2,0.61,92,1100.0
1,448.0,0.125,Spike P3-T34 perception:0.59 height:145,89600,P3,0.59,145,1100.0


# Load Previously Run Autoreject Logs and Filter Epochs

In [168]:
# Set to True to print each subject that is skipped below because it has no spikes.
verbose: bool = False

In [186]:
# Annotate each subject's channel spike dataframe with the fixed-length epoch
# every spike falls in (`epoch_idx`) and whether autoreject marked that epoch
# as bad (`bad_epoch`).
for dataset, deriv_root in zip(
    [jh_spike_dataset, jeff_spike_dataset], [jh_ica_root, jeff_ica_root]
):
    subjects = dataset["subject"]

    for idx in range(len(subjects)):
        subject = subjects[idx]

        # get the channel spike df and bids path; the bids path is re-rooted
        # (in place) at the ICA derivatives so the matching raw/reject log load
        data = dataset["data"][idx]
        bids_path = dataset["bids_path"][idx]
        bids_path.update(root=deriv_root)

        # subjects without spikes get no annotation columns
        if data.empty:
            if verbose:
                print(f"Skipping: {subject} with {bids_path}")
            continue

        # load the ICA-cleaned raw (only used to rebuild the epoch grid) and
        # the autoreject reject log
        raw = read_raw_bids(bids_path)
        reject_log = load_reject_log(bids_path)
        bad_epochs = np.asarray(reject_log.bad_epochs, dtype=bool)

        # rebuild the 1 s non-overlapping epoch grid; the assert below checks it
        # lines up one-to-one with the reject log's epochs
        events = mne.make_fixed_length_events(
            raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0
        )
        assert len(events) == len(bad_epochs)
        # onsets are in increasing order, which searchsorted relies on
        epoch_onsets = events[:, 0]

        # NOTE(review): assumes the spike `sample` column uses the same sample
        # origin as the events (first_samp is 0 for EDF) — confirm for other formats.
        epoch_indices = []
        in_bad_epoch = []
        for sample in data["sample"]:
            # index of the last epoch whose onset is <= the spike sample
            epoch_idx = int(np.searchsorted(epoch_onsets, sample, side="right")) - 1
            if epoch_idx < 0:
                raise ValueError(
                    f"{subject}: spike at sample {sample} precedes the first "
                    f"epoch onset {epoch_onsets[0]}"
                )
            epoch_indices.append(epoch_idx)
            in_bad_epoch.append(bool(bad_epochs[epoch_idx]))

        # assign whole columns at once (avoids chained-indexing writes that can
        # silently hit a copy) and store the annotated dataframe
        data = data.assign(epoch_idx=epoch_indices, bad_epoch=in_bad_epoch)
        dataset["data"][idx] = data

In [187]:
# Sanity check of the annotation loop above. NOTE: `data`, `subject`, `idx`,
# `sample` and `raw` are loop variables left over from the previous cell, so this
# only inspects whichever subject the loop visited last.
if "bad_epoch" in data.columns:
    display(data.head())
    print(subject)
    print(dataset["bids_path"][idx])
    print(sample)
    print(raw)
    print(len(data))
    # number of spikes that fall inside an epoch autoreject marked as bad
    print(int(data["bad_epoch"].eq(True).sum()))
else:
    # the last subject had no spikes, so it was never annotated
    print(f"Last subject {subject} has no annotated spikes; nothing to inspect.")

Unnamed: 0,onset,duration,description,sample,ch_name,perception,height,n_secs,epoch_idx,bad_epoch
0,40.0,0.08,Spike T3-Av12 perception:0.26 height:66,20000,T3,0.26,66,1021.0,40,False
1,48.0,0.135,Spike T3-Av12 perception:0.31 height:80,24000,T3,0.31,80,1021.0,48,False
2,78.0,0.13,Spike T3-Av12 perception:0.15 height:68,39000,T3,0.15,68,1021.0,78,False
3,89.0,0.125,Spike T3-Av12 perception:0.4 height:59,44500,T3,0.4,59,1021.0,89,False
4,133.0,0.095,Spike T3-Av12 perception:0.11 height:83,66500,T3,0.11,83,1021.0,133,False


jeff210
/Users/adam2392/Johns Hopkins/Jefferson_Scalp - Documents/root/derivatives/ICA/1-30Hz-30/win-20/sub-jeff210/ses-1/eeg/sub-jeff210_ses-1_run-1_eeg.edf
479000
<RawEDF | sub-jeff210_ses-1_run-1_eeg.edf, 14 x 510500 (1021.0 s), ~20 kB, data not loaded>
61
0


# Create Feature Matrix from Spikes

In [188]:
# Combine the JHH and Jefferson spike datasets into a single per-subject dataset
# so spike features can be computed across both sites. Everything is deep-copied
# so that the per-site datasets stay untouched.
mismatched_keys = set(jh_spike_dataset) ^ set(jeff_spike_dataset)
if mismatched_keys:
    raise ValueError(
        f"Spike datasets have mismatched keys and cannot be merged: {sorted(mismatched_keys)}"
    )

dataset = deepcopy(jh_spike_dataset)
for key, item in jeff_spike_dataset.items():
    dataset[key].extend(deepcopy(item))

# the per-subject fields used below must stay aligned with the subject list
for key in ("data", "ch_names"):
    assert len(dataset[key]) == len(dataset["subject"]), f"{key} misaligned with subjects"

In [190]:
# Now, let's remove all spikes that fall inside bad epochs so the spike
# features below are computed only from epochs autoreject kept.
for idx in range(len(dataset["subject"])):
    ch_spike_df = dataset["data"][idx]

    # subjects without spikes were never annotated with a "bad_epoch" column
    if ch_spike_df.empty:
        continue

    # keep only rows from GOOD epochs; `.eq(False)` rather than `~` because the
    # column may hold Python bools in an object-dtype column
    dataset["data"][idx] = ch_spike_df.loc[ch_spike_df["bad_epoch"].eq(False)]

In [198]:
# Compute one spike feature vector per subject from its (clean) spike table and
# channel list, then stack them row-wise (one row per subject).
feature_rows = []
for idx in range(len(dataset["subject"])):
    ch_spike_df, ch_names = dataset["data"][idx], dataset["ch_names"][idx]
    spike_feature_vec = spike_feature_vector(ch_spike_df, ch_names)
    feature_rows.append(spike_feature_vec)
spike_features = np.array(feature_rows)

print(spike_features.shape)

(125, 6)
