# Autoreject and Windowed Derivative Analysis

In [1]:
# Optional: load the `lab_black` IPython extension (code auto-formatting).
# Comment out the next line if you don't have nb_black installed.
%load_ext lab_black

In [46]:
import numpy as np
import scipy
import scipy.io
import pandas as pd
from pathlib import Path
import os
import collections
from natsort import natsorted
import json
import pickle
import warnings
import sys
from copy import copy, deepcopy

warnings.filterwarnings("ignore")

import mne
from mne import make_fixed_length_epochs

mne.set_log_level("ERROR")
from mne_bids import BIDSPath, get_entities_from_fname, get_entity_vals, read_raw_bids
import autoreject
from autoreject import AutoReject, read_auto_reject

from sklearn.preprocessing import OrdinalEncoder, LabelBinarizer, LabelEncoder
from sklearn.multiclass import OneVsRestClassifier
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    brier_score_loss,
    roc_curve,
    average_precision_score,
    roc_auc_score,
    f1_score,
    recall_score,
    jaccard_score,
    balanced_accuracy_score,
    accuracy_score,
    auc,
    precision_score,
    plot_precision_recall_curve,
    average_precision_score,
    precision_recall_curve,
    confusion_matrix,
    cohen_kappa_score,
    make_scorer,
    precision_recall_fscore_support,
)
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (
    StratifiedGroupKFold,
    cross_validate,
    StratifiedShuffleSplit,
    LeaveOneGroupOut,
)
from sklearn.utils import resample
import sklearn
from sklearn import preprocessing
from sklearn.pipeline import make_pipeline

import mne
from mne.time_frequency import read_tfrs

mne.set_log_level("ERROR")
from mne_bids import BIDSPath, get_entities_from_fname, get_entity_vals, read_raw_bids

from eztrack.io import read_derivative_npy

sys.path.append("../../")
from episcalp.features import spike_feature_vector
from episcalp.io.read import (
    load_persyst_spikes,
    load_reject_log,
    load_derivative_heatmaps,
    map_rejectlog_to_deriv,
)
from episcalp.utils.utils import NumpyEncoder

# if you installed sporf via README
from oblique_forests.sporf import ObliqueForestClassifier

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import re

# Reload edited project modules automatically before each cell runs, and
# render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define Data Directories

In [7]:
# Folder that contains the per-site dataset folders. The default is the
# original author's local path; set the EPISCALP_DATA_HOME environment
# variable to run the notebook on another machine without editing this cell.
data_home = Path(os.environ.get("EPISCALP_DATA_HOME", "/Users/adam2392/Johns Hopkins"))

# Root of the Johns Hopkins (JHH) scalp EEG dataset.
jhroot = data_home / "Scalp EEG JHH - Documents" / "bids"
# Root of the Jefferson scalp EEG dataset.
jeffroot = data_home / "Jefferson_Scalp - Documents" / "root"

# not ready yet
upmcroot = data_home / "UPMC_Scalp - Documents"

In [24]:
# Both sites keep their ICA derivatives under the same relative folder chain,
# so spell that chain out once and join it onto each dataset root.
ica_subdirs = ("derivatives", "ICA", "1-30Hz-30", "win-20")
jh_ica_root = jhroot.joinpath(*ica_subdirs)
jeff_ica_root = jeffroot.joinpath(*ica_subdirs)

In [3]:
# Settings that select which derivative folders are read.
reference = "monopolar"
radius = "1.25"

# Relative folder chains (below a dataset's "derivatives" folder) for the
# source-sink and fragility derivatives.
ss_deriv_chain = Path("sourcesink", reference)
frag_deriv_chain = Path("fragility", f"radius{radius}", reference)

# One relative folder chain per time-frequency band.
(
    delta_tfr_deriv_chain,
    theta_tfr_deriv_chain,
    alpha_tfr_deriv_chain,
    beta_tfr_deriv_chain,
) = (Path("tfr", band) for band in ("delta", "theta", "alpha", "beta"))

In [26]:
# Load the perturbation-matrix derivatives of every subject at each site.
# The two calls differ only in the dataset root, so the remaining arguments
# are shared.
heatmap_load_kwargs = dict(
    search_str="*desc-perturbmatrix*.npy",
    read_func=read_derivative_npy,
    subjects=None,
    verbose=True,
)
jh_dataset = load_derivative_heatmaps(
    jhroot / "derivatives" / frag_deriv_chain, **heatmap_load_kwargs
)
jeff_dataset = load_derivative_heatmaps(
    jeffroot / "derivatives" / frag_deriv_chain, **heatmap_load_kwargs
)

Loading data for subjects: ['jhh001', 'jhh002', 'jhh003', 'jhh004', 'jhh005', 'jhh006', 'jhh007', 'jhh008', 'jhh009', 'jhh010', 'jhh011', 'jhh012', 'jhh013', 'jhh014', 'jhh015', 'jhh016', 'jhh017', 'jhh018', 'jhh019', 'jhh020', 'jhh021', 'jhh022', 'jhh023', 'jhh024', 'jhh025', 'jhh026', 'jhh027', 'jhh028', 'jhh029', 'jhh030', 'jhh101', 'jhh102', 'jhh103', 'jhh104', 'jhh105', 'jhh106', 'jhh107', 'jhh108', 'jhh109', 'jhh110', 'jhh111', 'jhh112', 'jhh113', 'jhh114', 'jhh115', 'jhh116', 'jhh117', 'jhh118', 'jhh119', 'jhh120', 'jhh121', 'jhh122', 'jhh124', 'jhh125', 'jhh126', 'jhh127', 'jhh128', 'jhh201', 'jhh202', 'jhh203', 'jhh204', 'jhh205', 'jhh206', 'jhh207', 'jhh208', 'jhh209', 'jhh210', 'jhh211', 'jhh212', 'jhh213', 'jhh214', 'jhh215', 'jhh216', 'jhh217', 'jhh218', 'jhh219', 'jhh220', 'jhh221', 'jhh222', 'jhh223', 'jhh224', 'jhh225', 'jhh226', 'jhh227', 'jhh228', 'jhh229'] from /Users/adam2392/Johns Hopkins/Scalp EEG JHH - Documents/bids/derivatives/fragility/radius1.25/monopolar
Loa

In [18]:
# Inspect the metadata attached to the first loaded JHH derivative
# (channel names, sampling rate, reference, ...).
print(jh_dataset["data"][0].info)

<DerivativeInfo | 17 non-empty values
 DerivativeFileName: sub-jhh001_run-01_desc-perturbmatrix_eeg.npy
 ch_axis: 1 item (list)
 ch_names: Fp1, Fp2, F3, F4, P3, P4, O1, O2, F7, F8, T3, T4, T5, T6
 chs: 14 EEG
 datatype: eeg
 description: perturbmatrix
 meas_date: 2021-10-05 09:37:34 UTC
 model_parameters: 3 items (dict)
 nchan: 14
 rawsources: 1 item (list)
 reference: monopolar
 root: D:/OneDriveParent/OneDrive - Johns Hopkins/Shared ...
 sfreq: 0.8 Hz
 source_entities: 9 items (dict)
 source_info: 14 items (dict)
 sources: 1 item (list)
 t_axis: 1 item (list)
>


# Load Previously Run Autoreject and Filter Epochs

In [8]:
# Verbosity flag; presumably consumed by later cells — TODO confirm, it is
# not referenced by the cells immediately below.
verbose = False

In [51]:
# reformat the channel spike dataframes based on autoreject logs
for dataset, deriv_root in zip(
    [jh_dataset, jeff_dataset], [jh_ica_root, jeff_ica_root]
):
    subjects = dataset["subject"]

    # loop through each dataset and preprocess the channel spike
    # dataframes based on the autoreject log
    for idx in range(len(subjects)):
        subject = subjects[idx]

        # get the channel spike df and bids path
        data = dataset["data"][idx]
        bids_path = dataset["bids_path"][idx]
        bids_path.update(root=deriv_root)

        # load in the reject log
        raw = read_raw_bids(bids_path)
        reject_log = load_reject_log(bids_path)

        # get spikes not in bad epochs
        bad_epochs = reject_log.bad_epochs
        bad_epoch_idx = np.argwhere(bad_epochs)
        events = mne.make_fixed_length_events(
            raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0
        )

        bad_events = events[bad_epoch_idx, :]
        assert len(events) == len(bad_epochs)

        deriv = jh_dataset["data"][0]
        winsize = 500
        n_windows = deriv.shape[1]
        deriv_onsets = np.arange(n_windows) * winsize
        duration = 200

        # get the bad window indices
        bad_win_index = map_rejectlog_to_deriv(
            deriv_onsets,
            winsize,
            rejectlog_events=bad_events.squeeze(),
            rejectlog_duration=duration,
        )

        break
    break

In [52]:
# Sanity check of the window mapping for the single recording processed above:
# the bad window indices, the bad reject-log events they were derived from,
# and the onset of window 133 (winsize 500 * index 133) for comparison with
# the first bad-event onset.
print(bad_win_index)
print(bad_events.squeeze())
print(500 * 133)

[132, 133, 223, 224, 256, 257, 290, 291]
[[ 66400      0      1]
 [112000      0      1]
 [128400      0      1]
 [129200      0      1]
 [145400      0      1]
 [149200      0      1]]
66500
