In [23]:
import numpy as np
import scipy
import scipy.io
import pandas as pd
from pathlib import Path
import os
import collections
import json
import warnings
import sys
from numpy import interp
from pprint import pprint
import joblib

from numpy.testing import assert_array_equal

warnings.filterwarnings("ignore")

from sklearn.preprocessing import OrdinalEncoder, LabelBinarizer, LabelEncoder
from sklearn.multiclass import OneVsRestClassifier
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    brier_score_loss,
    roc_curve,
    average_precision_score,
    roc_auc_score,
    f1_score,
    recall_score,
    jaccard_score,
    balanced_accuracy_score,
    accuracy_score,
    auc,
    precision_score,
    plot_precision_recall_curve,
    average_precision_score,
    precision_recall_curve,
    confusion_matrix,
    cohen_kappa_score,
    make_scorer,
    precision_recall_fscore_support
)
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (
    #StratifiedGroupKFold,
    cross_validate,
    StratifiedShuffleSplit,
    LeaveOneGroupOut,
)
from sklearn.utils import resample
import sklearn
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

import mne
from mne.time_frequency import read_tfrs

mne.set_log_level("ERROR")

sys.path.append("../../")
from episcalp.features import spike_feature_vector
from episcalp.io.read import load_anil_spikes
from episcalp.preprocess.montage import _standard_lobes
from episcalp.utils.utils import NumpyEncoder
from episcalp.cross_validate import exclude_subjects

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import re

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
def _get_exp_condition(subject, root):
    part_fname = os.path.join(root, "participants.tsv")
    df = pd.read_csv(part_fname, sep="\t")

    if not subject.startswith("sub-"):
        subject = f"sub-{subject}"

    return df[df["participant_id"] == subject]

In [25]:
def convert_experimental_cond_to_y(experimental_condition_list):
    """Encoder for y labels."""
    # Group name keys, assigned y-label values
    experimental_condition_map = {
        "non-epilepsy-normal-eeg": 0,
        "epilepsy-normal-eeg": 0,
        "epilepsy-abnormal-eeg": 1,
    }
    return [experimental_condition_map[cond] for cond in experimental_condition_list]

In [26]:
def combine_datasets(deriv_dataset):
    dataset = deriv_dataset[0]
    for deriv in deriv_dataset:
        for key in deriv.keys():
            if key not in dataset.keys():
                raise RuntimeError(
                    f"All keys in {dataset.keys()} must match every other derived dataset. "
                    f"{key}, {deriv.keys()}."
                )

    # convert to a dictionary of lists
    derived_dataset = {key: [] for key in dataset.keys()}
    for deriv in deriv_dataset:
        for key in derived_dataset.keys():
            derived_dataset[key].extend(deriv[key])
    return derived_dataset

# Define Data Directories

In [27]:
user = "patrick"
if user == "patrick":
    jhroot = Path("D:/OneDriveParent/OneDrive - Johns Hopkins/Shared Documents/bids")
    jeffroot = Path("D:/OneDriveParent/Johns Hopkins/Jefferson_Scalp - Documents/root")

    # not ready yet
    upmcroot = Path("D:/OneDriveParent/Johns Hopkins/UPMC_Scalp - Documents/scalp_study/root")
elif user == "adam":
    jhroot = Path("/Users/adam2392/Johns Hopkins/Scalp EEG JHH - Documents/bids/")
    jeffroot = Path("/Users/adam2392/Johns Hopkins/Jefferson_Scalp - Documents/root/")

    # not ready yet
    upmcroot = Path("/Users/adam2392/Johns Hopkins/UPMC_Scalp - Documents/")
elif user == "kristin":
    jhroot = Path()
    jeffroot = Path()
    upmcroot = Path()

In [28]:
bids_roots = [jhroot, jeffroot, upmcroot]
roots = [jhroot, jeffroot, upmcroot]

In [29]:
if user == "adam":
    deriv_dir = Path(
        "/Users/adam2392/Johns Hopkins/Scalp EEG JHH - Documents/derivatives"
    )
elif user == "patrick":
    deriv_dir = Path(
        "D:/OneDriveParent/OneDrive - Johns Hopkins/Shared Documents/derivatives"
    )
elif user == "kristin":
    deriv_dir = Path()

In [32]:
anil_spikes_deriv_chain = Path("spikes_anil")
spike_thresh = np.arange(-1.3, -0.2, 0.1)
spike_thresh = [np.round(st, 1) for st in spike_thresh]
deriv_chains = [anil_spikes_deriv_chain / f"thresh-{st}" for st in spike_thresh]

In [33]:
deriv_chains

[WindowsPath('spikes_anil/thresh--1.3'),
 WindowsPath('spikes_anil/thresh--1.2'),
 WindowsPath('spikes_anil/thresh--1.1'),
 WindowsPath('spikes_anil/thresh--1.0'),
 WindowsPath('spikes_anil/thresh--0.9'),
 WindowsPath('spikes_anil/thresh--0.8'),
 WindowsPath('spikes_anil/thresh--0.7'),
 WindowsPath('spikes_anil/thresh--0.6'),
 WindowsPath('spikes_anil/thresh--0.5'),
 WindowsPath('spikes_anil/thresh--0.4'),
 WindowsPath('spikes_anil/thresh--0.3')]

# Load the Data (Once)

In [43]:
spike_datasets = {}

kwargs = {}
for tind, thresh in enumerate(spike_thresh):
    datasets = []
    for root in bids_roots:
        print(f"Loading spike data for {root}")
        dataset = load_anil_spikes(
            root / "derivatives" / deriv_chains[tind],
            search_str="*.json",
            subjects=None,
            verbose=True,
            **kwargs
        )
        datasets.append(dataset)
    spike_dataset = combine_datasets(datasets)
    spike_datasets[str(thresh)] = spike_dataset
print(len(dataset["subject"]))

Loading spike data for D:\OneDriveParent\OneDrive - Johns Hopkins\Shared Documents\bids
Loading data for subjects: ['jhh001', 'jhh002', 'jhh003', 'jhh004', 'jhh005', 'jhh006', 'jhh007', 'jhh008', 'jhh009', 'jhh010', 'jhh011', 'jhh012', 'jhh013', 'jhh014', 'jhh015', 'jhh016', 'jhh017', 'jhh018', 'jhh019', 'jhh020', 'jhh021', 'jhh022', 'jhh023', 'jhh024', 'jhh025', 'jhh026', 'jhh027', 'jhh028', 'jhh029', 'jhh030', 'jhh031', 'jhh032', 'jhh033', 'jhh034', 'jhh035', 'jhh036', 'jhh037', 'jhh038', 'jhh101', 'jhh102', 'jhh103', 'jhh104', 'jhh105', 'jhh106', 'jhh107', 'jhh108', 'jhh109', 'jhh110', 'jhh111', 'jhh112', 'jhh113', 'jhh114', 'jhh115', 'jhh116', 'jhh117', 'jhh118', 'jhh119', 'jhh120', 'jhh121', 'jhh122', 'jhh124', 'jhh125', 'jhh126', 'jhh127', 'jhh128', 'jhh129', 'jhh130', 'jhh131', 'jhh132', 'jhh133', 'jhh134', 'jhh135', 'jhh136', 'jhh137', 'jhh138', 'jhh139', 'jhh140', 'jhh141', 'jhh142', 'jhh143', 'jhh144', 'jhh201', 'jhh202', 'jhh204', 'jhh205', 'jhh206', 'jhh207', 'jhh209', 'jhh

Loading data for subjects: []
Loading spike data for D:\OneDriveParent\OneDrive - Johns Hopkins\Shared Documents\bids
Loading data for subjects: ['jhh001', 'jhh002', 'jhh003', 'jhh004', 'jhh005', 'jhh006', 'jhh007', 'jhh008', 'jhh009', 'jhh010', 'jhh011', 'jhh012', 'jhh013', 'jhh014', 'jhh015', 'jhh016', 'jhh017', 'jhh018', 'jhh019', 'jhh020', 'jhh021', 'jhh022', 'jhh023', 'jhh024', 'jhh025', 'jhh026', 'jhh027', 'jhh028', 'jhh029', 'jhh030', 'jhh031', 'jhh032', 'jhh033', 'jhh034', 'jhh035', 'jhh036', 'jhh037', 'jhh038', 'jhh101', 'jhh102', 'jhh103', 'jhh104', 'jhh105', 'jhh106', 'jhh107', 'jhh108', 'jhh109', 'jhh110', 'jhh111', 'jhh112', 'jhh113', 'jhh114', 'jhh115', 'jhh116', 'jhh117', 'jhh118', 'jhh119', 'jhh120', 'jhh121', 'jhh122', 'jhh124', 'jhh125', 'jhh126', 'jhh127', 'jhh128', 'jhh129', 'jhh130', 'jhh131', 'jhh132', 'jhh133', 'jhh134', 'jhh135', 'jhh136', 'jhh137', 'jhh138', 'jhh139', 'jhh140', 'jhh141', 'jhh142', 'jhh143', 'jhh144', 'jhh201', 'jhh202', 'jhh204', 'jhh205', 'jhh

In [45]:
len(spike_datasets["-1.3"]['subject'])

283

In [41]:
spike_dataset["subject"]

['jhh001',
 'jhh002',
 'jhh003',
 'jhh004',
 'jhh005',
 'jhh006',
 'jhh007',
 'jhh008',
 'jhh009',
 'jhh010',
 'jhh011',
 'jhh012',
 'jhh013',
 'jhh014',
 'jhh015',
 'jhh016',
 'jhh017',
 'jhh018',
 'jhh019',
 'jhh020',
 'jhh021',
 'jhh022',
 'jhh023',
 'jhh024',
 'jhh025',
 'jhh026',
 'jhh027',
 'jhh027',
 'jhh027',
 'jhh027',
 'jhh028',
 'jhh028',
 'jhh028',
 'jhh028',
 'jhh029',
 'jhh030',
 'jhh031',
 'jhh032',
 'jhh033',
 'jhh034',
 'jhh035',
 'jhh036',
 'jhh037',
 'jhh038',
 'jhh101',
 'jhh102',
 'jhh103',
 'jhh104',
 'jhh105',
 'jhh106',
 'jhh107',
 'jhh108',
 'jhh109',
 'jhh110',
 'jhh111',
 'jhh112',
 'jhh113',
 'jhh114',
 'jhh115',
 'jhh116',
 'jhh117',
 'jhh118',
 'jhh119',
 'jhh120',
 'jhh121',
 'jhh122',
 'jhh124',
 'jhh124',
 'jhh124',
 'jhh124',
 'jhh125',
 'jhh125',
 'jhh126',
 'jhh127',
 'jhh127',
 'jhh127',
 'jhh127',
 'jhh128',
 'jhh129',
 'jhh130',
 'jhh131',
 'jhh132',
 'jhh133',
 'jhh134',
 'jhh135',
 'jhh136',
 'jhh137',
 'jhh138',
 'jhh139',
 'jhh140',
 'jhh141',

# Define Run parameters

In [21]:
metric_names = [f"spikes_thresh-{thresh}" for thresh in spike_thresh]
metric_mapping = {}
[metric_mapping.update({f"spikes_thresh-{thresh}": spike_datasets[str(thresh)]}) for thresh in spike_thresh]
custom_feats = []
custom_feat_names = []


KeyError: '-1.3'

In [22]:
metric_mapping

{'spikes_thresh--1.5': {'data': [{'Fp1': 0.0,
    'Fp2': 0.0,
    'F3': 0.0010193679918450561,
    'F4': 0.0010193679918450561,
    'C3': 0.004077471967380225,
    'C4': 0.0,
    'P3': 0.0,
    'P4': 0.0010193679918450561,
    'O1': 0.0010193679918450561,
    'O2': 0.004077471967380225,
    'F7': 0.004077471967380225,
    'F8': 0.0,
    'T3': 0.0050968399592252805,
    'T4': 0.0010193679918450561,
    'T5': 0.0020387359836901123,
    'T6': 0.0020387359836901123,
    'Fz': 0.0,
    'Cz': 0.0010193679918450561,
    'Pz': 0.0020387359836901123,
    'E': 0.0,
    'F9': 0.0,
    'F10': 0.0,
    'M1': 0.0,
    'M2': 0.0,
    'PO7': 0.0,
    'PO8': 0.0,
    'X1': 0.0010193679918450561,
    'X2': 0.0050968399592252805,
    'X3': 0.0061162079510703364,
    'X4': 0.004077471967380225,
    'X5': 0.004077471967380225,
    'X6': 0.0061162079510703364,
    'X7': 0.011213047910295617,
    'SpO2': 0.0,
    'EtCO2': 0.0,
    'DC03': 0.0,
    'DC04': 0.0,
    'DC05': 0.0,
    'DC06': 0.0,
    'Pulse': 0

In [None]:
categorical_exclusion_criteria = {
    "exp_condition": None,
    "final_diagnosis": None,
    "epilepsy_type": ["generalized"],
    "epilepsy_hemisphere": None,
    "epilepsy_lobe": None,
}

continuous_exclusion_criteria = {
    "age": None,
    "num_aeds": None,
}

train_size = 0.7
random_state = 12345

clf_name = "lr"

lr_model_params = {
    "n_jobs": -1,
    "random_state": random_state,
    "penalty": "l1",
    "solver": "liblinear"
}
model_params = {
    "lr": lr_model_params
}

exp_name = "anil_spike_detection_feats"

# Run Experiment in a Loop

In [7]:
scaler = StandardScaler()
y_enc = LabelBinarizer()

In [8]:
model_names = {"lr": "logisticregression"}

In [None]:
scoring_funcs = {
    "balanced_accuracy": make_scorer(balanced_accuracy_score),
    "cohen_kappa_score": make_scorer(cohen_kappa_score),
    "roc_auc": "roc_auc",  # roc_auc_score,
    "f1": "f1",  # f1_score,
    "recall": "recall",  # makerecall_score,
    "precision": "precision",  # precision_score,
    "jaccard": "jaccard",  # jaccard_score,
    "average_precision": "average_precision",  # average_precision_score,
    "neg_brier_score": "neg_brier_score",  # brier_score_loss,
}

scoring = scoring_funcs
print(scoring)

In [None]:
fname = deriv_dir / "spikes" / clf_name / f"{exp_name}_features.csv"
fname.parent.mkdir(exist_ok=True, parents=True)

print(f"File {fname} exists {fname.exists()}")

In [12]:
idx = 0
dfs = []
for mind, metric_name in metric_names:
    print(f"Using metric: {metric_name}")
    features = []
    for idx in range(len(spike_dataset["subject"])):
        dataset = metric_mapping[metric_name].copy()
        data = dataset["data"][idx]
        ch_names = dataset["ch_names"][idx]
        feature_vec = spike_feature_vect_dict(data, ch_names=ch_names)
        features.append(feature_vec)
    
    features = np.array(features)
    print(f"Features has shape: {features}")
    
    dataset = spike_dataset
    
    # get the y labels
    subjects = np.array(dataset["subject"])
    roots = dataset["roots"]
    
    exp_conditions = []
    for subject, root in zip(subjects, roots):
        subj_df = _get_exp_condition(subject, root)
        exp_condition = subj_df["exp_condition"].values[0]
        exp_conditions.append(exp_condition)
    
    # encode the y label
    y = np.array(convert_experimental_cond_to_y(np.array(exp_conditions)))
    X = features
    
    # Further subset the subjects if desired
    X, y, keep_subjects = exclude_subjects(
        X, y, subjects, bids_roots, categorical_exclusion_criteria, continuous_exclusion_criteria
    )

    print(X.shape, y.shape, keep_subjects.shape)
    
    steps = []
    if clf_name == "lr":
        clf = LogisticRegression(**lr_model_params)
        steps.append(scaler)
    
    steps.append(clf)
    clf = make_pipeline(*steps)
    
    clf.fit(X, y)
    
    scoring_funcs = {
        "balanced_accuracy": balanced_accuracy_score,
        "cohen_kappa_score": cohen_kappa_score,
        "roc_auc": roc_auc_score,  #  "roc_auc",  # roc_auc_score,
        "f1": f1_score,
        "recall": recall_score,
        "specificity": recall_score,
        "precision": precision_score,
        "jaccard": jaccard_score,
        "average_precision": average_precision_score,
        "neg_brier_score": brier_score_loss,
        "cohen_kappa_score": cohen_kappa_score,
        #     'specificity': '',
    }

    # evaluate the model performance
    train_scores = dict()
    for score_name, score_func in scoring_funcs.items():
        y_pred_proba = clf.predict_proba(X)
        if score_name == "specificity":
            score_func = make_scorer(score_func, pos_label=0)
        else:
            score_func = make_scorer(score_func)
        score = score_func(clf, X, y)

        train_scores[score_name] = score

    for idx in np.unique(y):
        print(f"Class {idx} has ", len(np.argwhere(y == idx)))
    y_pred = clf.predict(X)

    cv_scores = {}
    for cv_name, cv in cvs.items():
        # run cross-validation
        scores = cross_validate(
            clf,
            X,
            y,
            groups=keep_subjects,
            cv=cv,
            scoring=scoring,
            return_estimator=True,
            return_train_score=False,
            n_jobs=-1,
            error_score="raise",
        )

        # get the estimators
        estimators = scores.pop("estimator")
        cv_scores[cv_name] = scores

    
    for idx in np.unique(y):
        print(f"Class {idx} has {len(np.argwhere(y == idx))}")
    y_pred = clf.predict(X)

    result_df = pd.DataFrame()
    idx = 0
    
    result_df["exp"] = ""
    result_df.at[1, "exp"] = idx
    # Changed heatmaps to metrics to account for spikes and other non-heatmap metrics
    result_df["metrics"] = ""
    result_df.at[1, "metrics"] = metric_name
    result_df["data_shape"] = str(X.shape)
    result_df["clf"] = clf_name
    result_df["predictions"] = ""
    result_df.at[1, "predictions"] = "["+", ".join(y_pred) + "]"
    result_df["labels"] = ""
    result_df.at[1, "labels"] = "["+", ".join(y)+"]"
    
    for name, score in train_scores.items():
        result_df[f"train_{name}"] = score

    for name, scores in cv_scores.items():
        for metric, score in scores.items():
            if not metric.startswith("test_"):
                continue

            result_df[f"{name}_{metric}"] = ""
            result_df.at[1, f"{name}_{metric}"] = score
            result_df[f"{name}_{metric}_avg"] = np.mean(score)
                result_df[f"{name}_{metric}_std"] = np.std(score)
    
    dfs.append(result_df)
result_df = pd.concat(dfs)
display(result_df)
result_df.to_csv(fname, index=None)

SyntaxError: invalid syntax (<ipython-input-12-9c9fe391fadb>, line 65)

# Plotting Results

In [None]:
clf_name = "lr"
fname = deriv_dir / "normaleeg" / clf_name / f"{exp_name}_features.csv"
print(fname)

result_df = pd.read_csv(fname, index_col=None)

In [None]:
display(result_df.head())

In [None]:
fig, ax = plt.subplots()

y = "train_roc_auc"
x = np.arange(len(result_df))
ax.plot(x, result_df[y], "*")
ax.set(title=f"{clf_name} ", xlabel="Exp indices", ylabel=y)
ax.axhline([0.5], ls="--")

In [None]:
fig, ax = plt.subplots()

y = "stratifiedshuffle_test_roc_auc_avg"
x = np.arange(len(result_df))
ax.plot(x, result_df[y], "*")
ax.set(title=f"{clf_name} ", xlabel="Exp indices", ylabel=y)
ax.axhline([0.5], ls="--")