In [4]:
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif, SelectFromModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV, LeaveOneGroupOut
from sklearn.metrics import make_scorer, accuracy_score, f1_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.feature_selection import VarianceThreshold
from tqdm import tqdm

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from gait_modulation import FeatureExtractor2
from gait_modulation.utils.utils import *
from itertools import product

In [5]:
# Load the preprocessed data
patient_epochs = load_pkl('results/pickles/patients_epochs.pickle')
subjects_event_idx_dict = load_pkl('results/pickles/subjects_event_idx_dict.pickle')
sfreq = patient_epochs['PW_EM59'].info['sfreq']

patient_names = np.array(list(patient_epochs.keys()))


In [6]:
# configuration for feature extraction
features_config = {
    'time_features': {
        # 'mean': True,
        # 'std': True,
        # 'median': True,
        # 'skew': True,
        # 'kurtosis': True,
        # 'rms': True
            # peak_to_peak = np.ptp(lfp_data, axis=2)
    },
    'freq_features': {
        'psd_raw': True,
            # psd_vals = np.abs(np.fft.rfft(lfp_data, axis=2))
        # 'psd_band_mean': True, band power!
        # 'psd_band_std': True,
        # 'spectral_entropy': True
    },
    # 'wavelet_features': {
    #     'energy': False
    # },
    # 'nonlinear_features': {
    #     'sample_entropy': True,
    #     'hurst_exponent': False
    # }
}


# Initialize the FeatureExtractor
feature_extractor = FeatureExtractor2(sfreq, features_config)

feature_handling = "flatten_chs"

# Extract features
feature_matrix, feature_idx_map = feature_extractor.extract_features(
    patient_epochs['PW_FH57'], feature_handling)

# Print the shape of the extracted features
print("Extracted features shape:", feature_matrix.shape)

# feature_extractor.select_feature(feature_matrix, 'freq_features_beta_psd_raw', feature_handling="flatten_chs").shape
# freq_bands = {
#     "delta": (0.5, 4),
#     "theta": (4, 8),
#     "alpha": (8, 12),
#     "beta": (20, 30),
#     "gamma": (30, 100)
# }

feature_matrix.shape, feature_idx_map

Extracted features shape: (1307, 294)


((1307, 294), {'freq_features_all_psd_raw': (0, 294)})

In [44]:
X, y, groups = [], [], []
for patient in patient_names:
    epochs = patient_epochs[patient]
    X_patient, y_patient = feature_extractor.extract_features_with_labels(
        epochs, feature_handling="flatten_chs"
    )
    X.append(X_patient)
    y.append(y_patient)
    groups.extend([patient] * len(y_patient))

X = np.concatenate(X, axis=0)
y = np.concatenate(y, axis=0)
assert len(X) == len(y) == len(groups), "Mismatch in lengths of X, y, and groups."
print(f"X shape: {X.shape}, y shape: {y.shape}, groups length: {len(groups)}")

# feature_selection_methods = {
#     'select_k_best': SelectKBest(score_func=f_classif),
#     'pca': PCA(),
#     'model_based': SelectFromModel(RandomForestClassifier(n_estimators=100))
# }

# Define candidate models for classification
models = {
    'logistic_regression': LogisticRegression(),
    'svm': SVC(probability=True),  # Enable predict_proba for SVM
    'random_forest': RandomForestClassifier(),
}

# Build a pipeline with placeholders for feature selection and classifier
# Remove constant features before feature selection (Remove features with zero variance)
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('variance_threshold', VarianceThreshold(threshold=0.0)),
    ('feature_selection', 'passthrough'),
    ('classifier', 'passthrough')
])

# Define parameter grid as a list of dictionaries
n_features = X.shape[1]

param_grid = [
    {
        # 'feature_selection': [feature_selection_methods['select_k_best']],
        # 'feature_selection__k': [min(n_features, 30)],  # Avoid 'all' if not feasible
        'classifier': [models['logistic_regression']],
        'classifier__C': [0.1],
        'classifier__penalty': ['l2']
    },
    {
        # 'feature_selection': [feature_selection_methods['pca']],
        # 'feature_selection__n_components': [min(n_features, n) for n in [5]],
        'classifier': [models['svm']],
        'classifier__C': [0.1],
        'classifier__kernel': ['linear', 'rbf']
    },
    {
        # 'feature_selection': [feature_selection_methods['model_based']],
        'classifier': [models['random_forest']],
        'classifier__n_estimators': [50],
        'classifier__max_depth': [5],
        'classifier__min_samples_split': [2]
    },
]

# Define scoring metrics
scoring = {
    'accuracy': make_scorer(accuracy_score),
    'f1': make_scorer(f1_score, average='weighted'),
}

# Add roc_auc only for models supporting predict_proba
if any(hasattr(ml_model, "predict_proba") for ml_model in models.values()):
    scoring['roc_auc'] = make_scorer(roc_auc_score, needs_proba=True, multi_class='ovr')

logo = LeaveOneGroupOut()
        
# Estimate total fits: n_splits * n_params
n_splits = logo.get_n_splits(X, y, groups)
param_values = param_grid[0].values()
candidates = list(product(*param_values))
n_candidates = len(candidates)
total_fits = n_splits * n_candidates
print(f"Fitting {n_splits} folds for each of {n_candidates} candidates, totalling {total_fits} fits")

ml_grid_search = GridSearchCV(
    pipeline,
    param_grid=param_grid,
    cv=logo,
    scoring=scoring,
    refit='f1' if 'f1' in scoring else 'accuracy',
    n_jobs=-1,
    verbose=3
)
ml_grid_search.fit(X, y, groups=groups)

X shape: (6083, 294), y shape: (6083,), groups length: 6083
Fitting 7 folds for each of 1 candidates, totalling 7 fits
Fitting 7 folds for each of 4 candidates, totalling 28 fits


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

[CV 4/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.558) f1: (test=0.560) roc_auc: (test=0.606) total time=   0.5s
[CV 2/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.627) f1: (test=0.523) roc_auc: (test=0.566) total time=   0.4s
[CV 5/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.639) f1: (test=0.541) roc_auc: (test=0.446) total time=   0.4s
[CV 1/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.788) f1: (test=0.744) roc_auc: (test=0.513) total time=   0.5s


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

[CV 3/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.693) f1: (test=0.617) roc_auc: (test=0.486) total time=   0.5s
[CV 6/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.376) f1: (test=0.396) roc_auc: (test=0.459) total time=   0.6s
[CV 7/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2; accuracy: (test=0.590) f1: (test=0.458) roc_auc: (test=0.573) total time=   0.5s
[CV 5/7] END classifier=SVC(probability=True), classifier__C=0.1, classifier__kernel=linear; accuracy: (test=0.652) f1: (test=0.528) roc_auc: (test=0.454) total time= 1.0min
[CV 2/7] END classifier=SVC(probability=True), classifier__C=0.1, classifier__kernel=linear; accuracy: (test=0.621) f1: (test=0.508) roc_auc: (test=0.562) total time= 1.0min
[CV 3/7] END classifier=SVC(probability=True), classifier__C=0.1, classifier__kernel=linear; accuracy: (test=0.701) f1: (test=0.604) roc_auc: 

In [45]:
# Convert GridSearchCV results into a DataFrame
results_df = pd.DataFrame(ml_grid_search.cv_results_)
results_df
# results_df.T

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_classifier,param_classifier__C,param_classifier__penalty,param_classifier__kernel,param_classifier__max_depth,param_classifier__min_samples_split,...,split0_test_roc_auc,split1_test_roc_auc,split2_test_roc_auc,split3_test_roc_auc,split4_test_roc_auc,split5_test_roc_auc,split6_test_roc_auc,mean_test_roc_auc,std_test_roc_auc,rank_test_roc_auc
0,0.488288,0.053589,0.015921,0.007275,LogisticRegression(),0.1,l2,,,,...,0.513393,0.565704,0.485865,0.60607,0.446247,0.459151,0.573007,0.521348,0.056877,3
1,79.76092,18.070738,1.776426,0.668136,SVC(probability=True),0.1,,linear,,,...,0.511278,0.561797,0.490143,0.591351,0.454171,0.45156,0.556582,0.516698,0.050829,4
2,77.923221,11.742475,2.724574,1.367138,SVC(probability=True),0.1,,rbf,,,...,0.542293,0.558585,0.497294,0.640968,0.560947,0.501381,0.521092,0.54608,0.045352,2
3,4.46265,0.365126,0.046344,0.022316,RandomForestClassifier(),,,,5.0,2.0,...,0.43562,0.609424,0.53795,0.678118,0.520663,0.541688,0.535998,0.551352,0.069986,1


In [46]:
results_df.columns

Index(['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time',
       'param_classifier', 'param_classifier__C', 'param_classifier__penalty',
       'param_classifier__kernel', 'param_classifier__max_depth',
       'param_classifier__min_samples_split', 'param_classifier__n_estimators',
       'params', 'split0_test_accuracy', 'split1_test_accuracy',
       'split2_test_accuracy', 'split3_test_accuracy', 'split4_test_accuracy',
       'split5_test_accuracy', 'split6_test_accuracy', 'mean_test_accuracy',
       'std_test_accuracy', 'rank_test_accuracy', 'split0_test_f1',
       'split1_test_f1', 'split2_test_f1', 'split3_test_f1', 'split4_test_f1',
       'split5_test_f1', 'split6_test_f1', 'mean_test_f1', 'std_test_f1',
       'rank_test_f1', 'split0_test_roc_auc', 'split1_test_roc_auc',
       'split2_test_roc_auc', 'split3_test_roc_auc', 'split4_test_roc_auc',
       'split5_test_roc_auc', 'split6_test_roc_auc', 'mean_test_roc_auc',
       'std_test_roc_auc', 'rank_test

In [47]:
results_df[
    ['param_classifier', 
     'split0_test_accuracy', 'split1_test_accuracy', 'split2_test_accuracy', 'split3_test_accuracy', 'split4_test_accuracy', 'split5_test_accuracy', 'split6_test_accuracy', 
     'mean_test_accuracy', 'std_test_accuracy', 
     'rank_test_accuracy',]
].sort_values(by='rank_test_accuracy', ascending=False).T

Unnamed: 0,0,1,2,3
param_classifier,LogisticRegression(),SVC(probability=True),SVC(probability=True),RandomForestClassifier()
split0_test_accuracy,0.787879,0.787879,0.806061,0.8
split1_test_accuracy,0.627391,0.620505,0.632747,0.632747
split2_test_accuracy,0.693193,0.700863,0.717162,0.717162
split3_test_accuracy,0.558184,0.578997,0.613056,0.617786
split4_test_accuracy,0.638743,0.652206,0.647719,0.649215
split5_test_accuracy,0.375969,0.436047,0.757752,0.757752
split6_test_accuracy,0.589666,0.585106,0.582067,0.582067
mean_test_accuracy,0.610146,0.623086,0.679509,0.679533
std_test_accuracy,0.118055,0.101721,0.076217,0.07413


In [48]:
results_df[
    ['param_classifier',
     'split0_test_f1', 'split1_test_f1', 'split2_test_f1', 'split3_test_f1', 'split4_test_f1', 'split5_test_f1', 'split6_test_f1',
     'mean_test_f1', 'std_test_f1',
     'rank_test_f1']
].sort_values(by='rank_test_f1', ascending=False).T

Unnamed: 0,0,1,2,3
param_classifier,LogisticRegression(),SVC(probability=True),SVC(probability=True),RandomForestClassifier()
split0_test_f1,0.743842,0.729003,0.719504,0.716498
split1_test_f1,0.522895,0.507995,0.490423,0.490423
split2_test_f1,0.616862,0.603681,0.599037,0.599037
split3_test_f1,0.560225,0.579278,0.465994,0.517965
split4_test_f1,0.540849,0.528403,0.509237,0.515357
split5_test_f1,0.396386,0.471662,0.653321,0.666963
split6_test_f1,0.457714,0.43782,0.428303,0.428303
mean_test_f1,0.548396,0.55112,0.55226,0.562078
std_test_f1,0.103554,0.090019,0.099178,0.095148


In [49]:
results_df[
    ['param_classifier',
     'split0_test_roc_auc', 'split1_test_roc_auc', 'split2_test_roc_auc', 'split3_test_roc_auc', 'split4_test_roc_auc', 'split5_test_roc_auc', 'split6_test_roc_auc',
     'mean_test_roc_auc', 'std_test_roc_auc',
     'rank_test_roc_auc']
].sort_values(by='rank_test_roc_auc', ascending=False).T

Unnamed: 0,1,0,2,3
param_classifier,SVC(probability=True),LogisticRegression(),SVC(probability=True),RandomForestClassifier()
split0_test_roc_auc,0.511278,0.513393,0.542293,0.43562
split1_test_roc_auc,0.561797,0.565704,0.558585,0.609424
split2_test_roc_auc,0.490143,0.485865,0.497294,0.53795
split3_test_roc_auc,0.591351,0.60607,0.640968,0.678118
split4_test_roc_auc,0.454171,0.446247,0.560947,0.520663
split5_test_roc_auc,0.45156,0.459151,0.501381,0.541688
split6_test_roc_auc,0.556582,0.573007,0.521092,0.535998
mean_test_roc_auc,0.516698,0.521348,0.54608,0.551352
std_test_roc_auc,0.050829,0.056877,0.045352,0.069986
