In [3]:
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif, SelectFromModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV, LeaveOneGroupOut
from sklearn.metrics import make_scorer, accuracy_score, f1_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.feature_selection import VarianceThreshold
from tqdm import tqdm

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from gait_modulation import FeatureExtractor2
from gait_modulation.utils.utils import *

In [None]:
# Load the preprocessed data
patient_epochs = load_pkl('results/pickles/patients_epochs.pickle')
subjects_event_idx_dict = load_pkl('results/pickles/subjects_event_idx_dict.pickle')
sfreq = patient_epochs['PW_EM59'].info['sfreq']

patient_names = np.array(list(patient_epochs.keys()))


In [5]:
# configuration for feature extraction
features_config = {
    'time_features': {
        # 'mean': True,
        # 'std': True,
        # 'median': True,
        # 'skew': True,
        # 'kurtosis': True,
        # 'rms': True
            # peak_to_peak = np.ptp(lfp_data, axis=2)
    },
    'freq_features': {
        'psd_raw': True,
            # psd_vals = np.abs(np.fft.rfft(lfp_data, axis=2))
        # 'psd_band_mean': True, band power!
        # 'psd_band_std': True,
        # 'spectral_entropy': True
    },
    # 'wavelet_features': {
    #     'energy': False
    # },
    # 'nonlinear_features': {
    #     'sample_entropy': True,
    #     'hurst_exponent': False
    # }
}


# Initialize the FeatureExtractor
feature_extractor = FeatureExtractor2(sfreq, features_config)

feature_handling = "flatten_chs"

# Extract features
feature_matrix, feature_idx_map = feature_extractor.extract_features(
    patient_epochs['PW_FH57'], feature_handling)

# Print the shape of the extracted features
print("Extracted features shape:", feature_matrix.shape)

# feature_extractor.select_feature(feature_matrix, 'freq_features_beta_psd_raw', feature_handling="flatten_chs").shape
# freq_bands = {
#     "delta": (0.5, 4),
#     "theta": (4, 8),
#     "alpha": (8, 12),
#     "beta": (20, 30),
#     "gamma": (30, 100)
# }

feature_matrix.shape, feature_idx_map

Extracted features shape: (1307, 294)


((1307, 294), {'freq_features_all_psd_raw': (0, 294)})

In [None]:
X, y, groups = [], [], []
for patient in patient_names:
    epochs = patient_epochs[patient]
    X_patient, y_patient = feature_extractor.extract_features_with_labels(
        epochs, feature_handling="flatten_chs"
    )
    X.append(X_patient)
    y.append(y_patient)
    groups.extend([patient] * len(y_patient))

X = np.concatenate(X, axis=0)
y = np.concatenate(y, axis=0)
assert len(X) == len(y) == len(groups), "Mismatch in lengths of X, y, and groups."


feature_selection_methods = {
    'select_k_best': SelectKBest(score_func=f_classif),
    'pca': PCA(),
    'model_based': SelectFromModel(RandomForestClassifier(n_estimators=100))
}

# Define candidate models for classification
models = {
    'logistic_regression': LogisticRegression(),
    'svm': SVC(probability=True),  # Enable predict_proba for SVM
    'random_forest': RandomForestClassifier(),
}

# Build a pipeline with placeholders for feature selection and classifier
# Remove constant features before feature selection (Remove features with zero variance)
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('variance_threshold', VarianceThreshold(threshold=0.0)),
    ('feature_selection', 'passthrough'),
    ('classifier', 'passthrough')
])

# Define parameter grid as a list of dictionaries
n_features = X.shape[1]

param_grid = [
    {
        'feature_selection': [feature_selection_methods['select_k_best']],
        'feature_selection__k': [min(n_features, 30)],  # Avoid 'all' if not feasible
        'classifier': [models['logistic_regression']],
        'classifier__C': [0.1],
        'classifier__penalty': ['l2']
    },
    {
        'feature_selection': [feature_selection_methods['pca']],
        'feature_selection__n_components': [min(n_features, n) for n in [5]],
        'classifier': [models['svm']],
        'classifier__C': [0.1],
        'classifier__kernel': ['linear', 'rbf']
    },
    {
        'feature_selection': [feature_selection_methods['model_based']],
        'classifier': [models['random_forest']],
        'classifier__n_estimators': [50],
        'classifier__max_depth': [5],
        'classifier__min_samples_split': [2]
    },
]

# Define scoring metrics
scoring = {
    'accuracy': make_scorer(accuracy_score),
    'f1': make_scorer(f1_score, average='weighted'),
}

# Add roc_auc only for models supporting predict_proba
if any(hasattr(ml_model, "predict_proba") for ml_model in models.values()):
    scoring['roc_auc'] = make_scorer(roc_auc_score, needs_proba=True, multi_class='ovr')

logo = LeaveOneGroupOut()
        
# Estimate total fits: n_splits * n_params
n_splits = logo.get_n_splits(X, y, groups)
n_params = len(param_grid)
total_fits = n_splits * n_params
print(f"Total fits: {total_fits}")
print(f"Number of splits: {n_splits}, Number of parameters: {n_params}")


ml_grid_search = GridSearchCV(
    pipeline,
    param_grid=param_grid,
    cv=logo,
    scoring=scoring,
    refit='f1' if 'f1' in scoring else 'accuracy',
    n_jobs=-1,
    verbose=3
)
ml_grid_search.fit(X, y, groups=groups)

Total fits: 21
Number of splits: 7, Number of parameters: 3
Fitting 7 folds for each of 4 candidates, totalling 28 fits




[CV 1/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2, feature_selection=SelectKBest(), feature_selection__k=30; accuracy: (test=0.770) f1: (test=0.738) roc_auc: (test=0.445) total time=   0.1s
[CV 2/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2, feature_selection=SelectKBest(), feature_selection__k=30; accuracy: (test=0.630) f1: (test=0.492) roc_auc: (test=0.512) total time=   0.1s
[CV 4/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2, feature_selection=SelectKBest(), feature_selection__k=30; accuracy: (test=0.558) f1: (test=0.549) roc_auc: (test=0.550) total time=   0.1s
[CV 3/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2, feature_selection=SelectKBest(), feature_selection__k=30; accuracy: (test=0.714) f1: (test=0.603) roc_auc: (test=0.505) total time=   0.1s
[CV 6/7] END classifier=LogisticRegression(), classifier__C=0.1, classifier__penalty=l2,

In [9]:
# Convert GridSearchCV results into a DataFrame
results_df = pd.DataFrame(ml_grid_search.cv_results_)
results_df
# results_df.T

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_classifier,param_classifier__C,param_classifier__penalty,param_feature_selection,param_feature_selection__k,param_classifier__kernel,...,split0_test_roc_auc,split1_test_roc_auc,split2_test_roc_auc,split3_test_roc_auc,split4_test_roc_auc,split5_test_roc_auc,split6_test_roc_auc,mean_test_roc_auc,std_test_roc_auc,rank_test_roc_auc
0,0.115433,0.033127,0.016001,0.013844,LogisticRegression(),0.1,l2,SelectKBest(),30.0,,...,0.445254,0.511885,0.504831,0.550254,0.423148,0.464716,0.526238,0.489475,0.042681,4
1,6.526261,0.902807,0.335205,0.147531,SVC(probability=True),0.1,,PCA(),,linear,...,0.554041,0.427993,0.471168,0.527423,0.466878,0.50177,0.57089,0.50288,0.047444,3
2,9.325265,2.861466,0.758177,0.336375,SVC(probability=True),0.1,,PCA(),,rbf,...,0.588346,0.560497,0.502696,0.509874,0.575323,0.516747,0.456482,0.529995,0.043292,2
3,10.009294,1.366763,0.018162,0.002325,RandomForestClassifier(),,,SelectFromModel(estimator=RandomForestClassifi...,,,...,0.412124,0.638823,0.538811,0.697252,0.534456,0.472859,0.539587,0.547702,0.088617,1
