In [None]:
from FeatureExtraction_v2 import *
import os
import pandas as pd
import numpy as np
import seaborn as sns

### Prepare data 

In [None]:
data_dir = '/Users/jazlynn/Downloads/neurons-smr-format-sorted/'

# One list per output column. Every processed recording appends exactly one value
# to each list, so the columns stay aligned when turned into a DataFrame below.
results_dict = {'Filename': [],
                'firing rate': [],
                'ifr_mean': [],
                'ifr_skew': [],
                'ifr_kurtosis': [],
                'ifr_fano_factor': [],
                'ISI_mean': [],
                'ISI_skew': [],
                'ISI_kurtosis': [],
                'ISI_cv': [],
                'num_bursts': [], 
                'mean_surprise': [],
                'burst_index': [],
                'mean_burst_duration': [],
                'var_burst_duration': [],
                'mean_interburst_duration': [],
                'var_interburst_duration': [],
                'max_peak_freq': [],
                'delta_band': [],
                'theta_band': [],
                'alpha_band': [],
                'beta_band': [],
                'gamma_band': []}

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting makes
# the row order of the saved CSV reproducible across machines and runs.
for file in sorted(os.listdir(data_dir)):
    if not file.endswith('.smr'):
        continue
    analogsignal, spike_times, sampling_frequency, time = load_spiketrain(os.path.join(data_dir, file))
    results_dict['Filename'].append(file)

    # Firing rate and instantaneous-firing-rate statistics.
    # NOTE(review): 0.05 and 0.1 are passed straight to calculate_instantaneous_firing_rate;
    # their meaning/units are defined in FeatureExtraction_v2 — confirm there.
    results_dict['firing rate'].append(get_firing_rate(spike_times, analogsignal, sampling_frequency))
    ifr_ls, time_bins = calculate_instantaneous_firing_rate(spike_times, analogsignal, sampling_frequency, 0.05, 0.1)
    fano_factor, ifr_mean, ifr_skew, ifr_kurtosis = get_ifr_metrics(ifr_ls)
    results_dict['ifr_mean'].append(ifr_mean)
    results_dict['ifr_skew'].append(ifr_skew)
    results_dict['ifr_kurtosis'].append(ifr_kurtosis)
    results_dict['ifr_fano_factor'].append(fano_factor)

    # Inter-spike-interval statistics (ISI_mode is returned but not stored).
    ISI_cv, ISI_skew, ISI_kurtosis, ISI_mean, ISI_mode = get_ISI_metrics(spike_times)
    results_dict['ISI_cv'].append(ISI_cv)
    results_dict['ISI_skew'].append(ISI_skew)
    results_dict['ISI_kurtosis'].append(ISI_kurtosis)
    results_dict['ISI_mean'].append(ISI_mean)

    # Burst features from the surprise-based detector.
    burst_dict = burst_detection_neuroexplorer(spike_times, 
                                               analogsignal, 
                                               sampling_frequency, 
                                               min_surprise = 5, 
                                               min_numspikes = 3)
    num_bursts, mean_surprise, burst_index_poisson, mean_burst_duration, var_burst_duration, mean_interburst_duration, var_interburst_duration = get_burst_metrics(burst_dict, spike_times)
    burst_index = get_burst_index(spike_times)
    results_dict['num_bursts'].append(num_bursts)
    results_dict['mean_surprise'].append(mean_surprise)
    results_dict['burst_index'].append(burst_index) # not using poisson surprise
    results_dict['mean_burst_duration'].append(mean_burst_duration)
    results_dict['var_burst_duration'].append(var_burst_duration)
    results_dict['mean_interburst_duration'].append(mean_interburst_duration)
    results_dict['var_interburst_duration'].append(var_interburst_duration)

    # Oscillation features: the strongest peak frequency plus a 0/1 flag for each
    # band name that appears in peak_bands.
    max_peak_freq, freq_peaks, peak_bands, freq_magnitude = get_synchrony_features2(spike_times, 
                                                                                   time_bin_size = 0.01, 
                                                                                   max_lag_time = 0.5, 
                                                                                   significance_level = 0.05, 
                                                                                   to_plot = False)
    results_dict['max_peak_freq'].append(max_peak_freq)
    for band in ('delta', 'theta', 'alpha', 'beta', 'gamma'):
        results_dict[band + '_band'].append(int(band in peak_bands))

results_df = pd.DataFrame(results_dict)
# The default integer index is written as well; later cells read it back as the
# 'Unnamed: 0' column and drop it from the feature list.
results_df.to_csv('ExtractedFeatures_v3.csv')
results_df.head()

### Model Training

In [None]:
# Per-recording metadata table; later cells join it to the features on 'Filename'
# and use its 'Patient ID', 'Target', 'Neuron' and 'Hemisphere' columns.
metadata = pd.read_csv(filepath_or_buffer='/Users/jazlynn/Downloads/bme1500-project-5-metadata.csv')
metadata

In [None]:
# Reload the cached feature table so this section runs without re-extracting features.
results_df = pd.read_csv('ExtractedFeatures_v3.csv')

# Report which columns contain missing values, then replace every NaN with 0.
# NOTE(review): 0 is not a neutral value for every feature (e.g. ISI statistics).
print(results_df.isnull().any())
results_df = results_df.fillna(0)

# Inner join: only recordings present in both tables are kept.
combined_data = metadata.merge(results_df, on='Filename')
combined_data['type'] = combined_data['Neuron'] + '_' + combined_data['Hemisphere']
combined_data.head()

In [None]:
def _population(hemisphere, column, value):
    """Return rows of combined_data from one hemisphere ('L'/'R') where ``column == value``."""
    return combined_data.loc[(combined_data['Hemisphere'] == hemisphere) & (combined_data[column] == value)]

# One population per hemisphere x recording target; each is modelled separately below.
subset_left_STN = _population('L', 'Target', 'STN')
subset_left_GPi = _population('L', 'Target', 'GPi')
subset_right_STN = _population('R', 'Target', 'STN')
subset_right_GPi = _population('R', 'Target', 'GPi')

# Number of recordings of each neuron type in each hemisphere.
for neuron_type in ['STN', 'SNr', 'HFD', 'BOR']:
    for side, hemi in [('left', 'L'), ('right', 'R')]:
        print(side + ' ' + neuron_type, len(_population(hemi, 'Neuron', neuron_type)))

In [None]:
# Model inputs: every merged column except identifiers/metadata, the CSV index column
# ('Unnamed: 0'), the derived 'type' label and 'firing rate'.
excluded_columns = ['Filename', 'Unnamed: 0', 'Patient ID', 'Target', 'Neuron', 'Hemisphere', 'firing rate', 'type']
features = combined_data.columns.drop(excluded_columns)
print(features)

In [None]:
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from sklearn.svm import SVC

In [None]:
# For each hemisphere x target population: fit a random forest on all features, keep its
# 4 most important features, then compare a grid-searched random forest, an
# unregularised logistic-regression baseline and a (tuned) scaled RBF SVM on them.
for population in [subset_left_GPi, subset_right_GPi, subset_left_STN, subset_right_STN]:
    region = population['Target'].iloc[0]
    hemisphere = population['Hemisphere'].iloc[0]
    print('Classifying region ' + region + ' hemi ' + hemisphere)
    X_train, X_test, y_train, y_test = train_test_split(population[features],
                                                        population['Neuron'],
                                                        stratify=population['Neuron'],
                                                        test_size=0.2, random_state=42)

    ### random forest on all features ###
    RF = RandomForestClassifier(n_estimators=100, random_state=42)
    RF.fit(X_train, y_train)
    train_pred = RF.predict(X_train)
    test_pred = RF.predict(X_test)
    print('train accuracy', accuracy_score(y_train, train_pred))
    print('test accuracy', accuracy_score(y_test, test_pred))

    # Rank features by Gini importance (computed on the training split only) and keep the top 4.
    feature_importance = pd.DataFrame({'feature': RF.feature_names_in_, 'importance': RF.feature_importances_})
    feature_importance = feature_importance.sort_values(by='importance', ascending=False)
    pruned_features = feature_importance['feature'].iloc[:4].tolist()
    print('selected features', pruned_features)

    RF = RandomForestClassifier(n_estimators=100, random_state=42)
    RF.fit(X_train[pruned_features], y_train)
    train_pred = RF.predict(X_train[pruned_features])
    test_pred = RF.predict(X_test[pruned_features])
    print('pruned train accuracy', accuracy_score(y_train, train_pred))
    print('pruned test accuracy', accuracy_score(y_test, test_pred))

    ### grid-searched random forest on the pruned features ###
    parameters = {'n_estimators':[100,200,400], 'max_depth': [5,10,None],'min_samples_leaf': [0.001,0.01,0.1],'min_samples_split': [0.001,0.01,0.1]}
    RF_pruned = RandomForestClassifier(random_state=42)
    hypparam_opt = GridSearchCV(RF_pruned, parameters, cv=5, scoring='accuracy', refit=True, n_jobs=4)
    hypparam_opt.fit(X_train[pruned_features], y_train)

    opt_train_pred = hypparam_opt.predict(X_train[pruned_features])
    opt_test_pred = hypparam_opt.predict(X_test[pruned_features])
    ConfusionMatrixDisplay.from_estimator(hypparam_opt, X_test[pruned_features], y_test, labels=hypparam_opt.classes_, cmap='gray')
    plt.savefig('Confusion_' + region + ' hemi ' + hemisphere + '.svg')
    print('optimized train accuracy', accuracy_score(y_train, opt_train_pred))
    print('optimized test accuracy', accuracy_score(y_test, opt_test_pred))

    fig, ax = plt.subplots()
    sns.barplot(feature_importance, x='feature', y='importance', ax=ax)
    ax.set_ylabel('Gini importance')
    ax.tick_params(axis='x', labelrotation = 90)

    ### baseline model ###
    print('baseline logistic regression optimized features')
    LR = LogisticRegression(random_state=42, max_iter=150, penalty=None)
    LR.fit(X_train[pruned_features], y_train)
    train_pred = LR.predict(X_train[pruned_features])
    test_pred = LR.predict(X_test[pruned_features])
    print('baseline train accuracy', accuracy_score(y_train, train_pred))
    print('baseline test accuracy', accuracy_score(y_test, test_pred))

    ### SVM ###
    print('SVM model')
    print('SVM before optimization')
    # RBF SVMs are scale-sensitive, so features are standardised inside a pipeline
    # (the scaler is fit on the training split only). The previous unscaled SVC fit
    # here produced predictions that were immediately overwritten, so it was removed.
    pipe = Pipeline([('scaler', StandardScaler()), ('svc', SVC(gamma='auto', random_state=42))])
    pipe.fit(X_train[pruned_features], y_train)
    train_pred = pipe.predict(X_train[pruned_features])
    test_pred = pipe.predict(X_test[pruned_features])
    print('svm train accuracy', accuracy_score(y_train, train_pred))
    print('svm test accuracy', accuracy_score(y_test, test_pred))

    svm_parameters = {'svc__C': [0.1, 1, 10], 'svc__gamma': [1, 0.1, 0.01, 0.001]}
    pipe_pruned = Pipeline([('scaler', StandardScaler()), ('svc', SVC(random_state=42))])
    SVM_hypparam_opt = GridSearchCV(pipe_pruned, svm_parameters, cv=5, scoring='accuracy', refit=True, n_jobs=4)
    SVM_hypparam_opt.fit(X_train[pruned_features], y_train)

    opt_train_pred = SVM_hypparam_opt.predict(X_train[pruned_features])
    # GridSearchCV.predict accepts only X; the former pos_label keyword raised a TypeError.
    opt_test_pred = SVM_hypparam_opt.predict(X_test[pruned_features])
    print('SVM optimized train accuracy', accuracy_score(y_train, opt_train_pred))
    print('SVM optimized test accuracy', accuracy_score(y_test, opt_test_pred))

#### What if we don't split by hemisphere?

In [None]:
subset_STN = combined_data.loc[(combined_data['Target']=='STN')]
subset_GPi = combined_data.loc[(combined_data['Target']=='GPi')]

# Same workflow as the per-hemisphere cell, but with left and right hemispheres pooled.
for population in [subset_GPi, subset_STN]:
    # Both hemispheres are mixed here, so only the target region is reported
    # (printing the first row's hemisphere mislabelled the pooled population).
    print('Classifying region ' + population['Target'].iloc[0] + ' (both hemispheres)')
    X_train, X_test, y_train, y_test = train_test_split(population[features],
                                                        population['Neuron'],
                                                        stratify=population['Neuron'],
                                                        test_size=0.2, random_state=42)
    RF = RandomForestClassifier(n_estimators=100, random_state=42)
    RF.fit(X_train, y_train)
    train_pred = RF.predict(X_train)
    test_pred = RF.predict(X_test)
    print('train accuracy', accuracy_score(y_train, train_pred))
    print('test accuracy', accuracy_score(y_test, test_pred))
    ConfusionMatrixDisplay.from_estimator(RF, X_test, y_test, labels=RF.classes_, cmap='gray')

    # Rank features by Gini importance (training split only) and keep the top 4.
    feature_importance = pd.DataFrame({'feature': RF.feature_names_in_, 'importance': RF.feature_importances_})
    feature_importance = feature_importance.sort_values(by='importance', ascending=False)
    pruned_features = feature_importance['feature'].iloc[:4].tolist()
    print('selected features', pruned_features)

    RF = RandomForestClassifier(n_estimators=100, random_state=42)
    RF.fit(X_train[pruned_features], y_train)
    train_pred = RF.predict(X_train[pruned_features])
    test_pred = RF.predict(X_test[pruned_features])
    print('pruned train accuracy', accuracy_score(y_train, train_pred))
    print('pruned test accuracy', accuracy_score(y_test, test_pred))

    parameters = {'n_estimators':[100,200,400], 'max_depth': [5,10,None],'min_samples_leaf': [0.001,0.01,0.1],'min_samples_split': [0.001,0.01,0.1]}
    RF_pruned = RandomForestClassifier(random_state=42)
    hypparam_opt = GridSearchCV(RF_pruned, parameters, cv=5, scoring='accuracy', refit=True, n_jobs=4)
    hypparam_opt.fit(X_train[pruned_features], y_train)

    opt_train_pred = hypparam_opt.predict(X_train[pruned_features])
    opt_test_pred = hypparam_opt.predict(X_test[pruned_features])
    print('optimized train accuracy', accuracy_score(y_train, opt_train_pred))
    print('optimized test accuracy', accuracy_score(y_test, opt_test_pred))

    ### baseline model ###
    print('baseline logistic regression optimized features')
    LR = LogisticRegression(random_state=42, max_iter=150, penalty=None)
    LR.fit(X_train[pruned_features], y_train)
    train_pred = LR.predict(X_train[pruned_features])
    test_pred = LR.predict(X_test[pruned_features])
    print('baseline train accuracy', accuracy_score(y_train, train_pred))
    print('baseline test accuracy', accuracy_score(y_test, test_pred))

    ### SVM ###
    print('SVM model')
    print('SVM before optimization')
    # RBF SVMs are sensitive to feature scale; standardise inside a pipeline (scaler fit
    # on the training split only), matching the per-hemisphere cell.
    svm = Pipeline([('scaler', StandardScaler()), ('svc', SVC(gamma='auto', random_state=42))])
    svm.fit(X_train[pruned_features], y_train)
    train_pred = svm.predict(X_train[pruned_features])
    test_pred = svm.predict(X_test[pruned_features])
    print('svm train accuracy', accuracy_score(y_train, train_pred))
    print('svm test accuracy', accuracy_score(y_test, test_pred))

    svm_parameters = {'svc__C': [0.1, 1, 10], 'svc__gamma': [1, 0.1, 0.01, 0.001]}
    SVM_pruned = Pipeline([('scaler', StandardScaler()), ('svc', SVC(random_state=42, kernel='rbf'))])
    SVM_hypparam_opt = GridSearchCV(SVM_pruned, svm_parameters, cv=5, scoring='accuracy', refit=True, n_jobs=4)
    SVM_hypparam_opt.fit(X_train[pruned_features], y_train)

    opt_train_pred = SVM_hypparam_opt.predict(X_train[pruned_features])
    opt_test_pred = SVM_hypparam_opt.predict(X_test[pruned_features])
    print('SVM optimized train accuracy', accuracy_score(y_train, opt_train_pred))
    print('SVM optimized test accuracy', accuracy_score(y_test, opt_test_pred))

#### Can a model trained on one hemisphere generalise to the other (left → right and right → left)?

In [None]:
def _train_and_transfer(train_subset, transfer_subset, train_name, transfer_side):
    """Fit the pruned-feature random-forest workflow on one hemisphere and score it on the other.

    Steps: a stratified 80/20 split of ``train_subset``; a 100-tree random forest on all
    ``features``; pruning to that forest's 4 most important features; a 100-tree forest
    and a grid-searched forest on the pruned features; and finally prediction on every
    row of ``transfer_subset``. Accuracies are printed at each stage.

    Parameters
    ----------
    train_subset, transfer_subset : pandas.DataFrame
        Rows of ``combined_data`` for the training and the held-out hemisphere.
    train_name : str
        Printed as ``'train <train_name>'`` before the run.
    transfer_side : str
        Prefix of the printed transfer accuracy (``'left'`` or ``'right'``).

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` of the training-hemisphere split.
    """
    print('train ' + train_name)
    X_tr, X_te, y_tr, y_te = train_test_split(train_subset[features],
                                              train_subset['Neuron'],
                                              stratify=train_subset['Neuron'],
                                              test_size=0.2, random_state=42)

    # Random forest on all features, used to rank features by Gini importance.
    rf_all = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_all.fit(X_tr, y_tr)
    print('train accuracy', accuracy_score(y_tr, rf_all.predict(X_tr)))
    print('test accuracy', accuracy_score(y_te, rf_all.predict(X_te)))

    importance = pd.DataFrame({'feature': rf_all.feature_names_in_, 'importance': rf_all.feature_importances_})
    importance = importance.sort_values(by='importance', ascending=False)
    pruned = importance['feature'].iloc[:4].tolist()
    print('selected features', pruned)

    rf_pruned = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_pruned.fit(X_tr[pruned], y_tr)
    print('pruned train accuracy', accuracy_score(y_tr, rf_pruned.predict(X_tr[pruned])))
    print('pruned test accuracy', accuracy_score(y_te, rf_pruned.predict(X_te[pruned])))

    param_grid = {'n_estimators': [100, 200, 400], 'max_depth': [5, 10, None],
                  'min_samples_leaf': [0.001, 0.01, 0.1], 'min_samples_split': [0.001, 0.01, 0.1]}
    search = GridSearchCV(RandomForestClassifier(random_state=42), param_grid,
                          cv=5, scoring='accuracy', refit=True, n_jobs=4)
    search.fit(X_tr[pruned], y_tr)
    print('optimized train accuracy', accuracy_score(y_tr, search.predict(X_tr[pruned])))
    print('optimized test accuracy', accuracy_score(y_te, search.predict(X_te[pruned])))

    # Generalisation: score the tuned model on every cell of the other hemisphere.
    transfer_pred = search.predict(transfer_subset[pruned])
    print(transfer_side + ' accuracy', accuracy_score(transfer_subset['Neuron'], transfer_pred))
    return X_tr, X_te, y_tr, y_te


_train_and_transfer(subset_left_GPi, subset_right_GPi, 'left GPi', 'right')
_train_and_transfer(subset_left_STN, subset_right_STN, 'left STN', 'right')
_train_and_transfer(subset_right_GPi, subset_left_GPi, 'right GPi', 'left')
# Later cells reuse X_train/X_test/y_train/y_test, so keep the right-STN split at module level.
X_train, X_test, y_train, y_test = _train_and_transfer(subset_right_STN, subset_left_STN, 'right STN', 'left')


### Figures

In [None]:
# Pairwise plots of three firing-pattern features for each hemisphere x target
# population, coloured by neuron type.
for population in [subset_left_GPi, subset_right_GPi, subset_left_STN, subset_right_STN]:
    sns.pairplot(population[['ISI_mean', 'ifr_mean', 'mean_surprise', 'Neuron']], hue='Neuron')

In [None]:
# Same three features, split by hemisphere within one neuron type at a time:
# all cells labelled STN, then all cells labelled SNr (any recording target).
subset_STN_STN = combined_data.loc[combined_data['Neuron']=='STN']
subset_STN_SNr = combined_data.loc[combined_data['Neuron']=='SNr']
for population in [subset_STN_STN, subset_STN_SNr]:
    sns.pairplot(population[['ISI_mean', 'ifr_mean', 'mean_surprise', 'Hemisphere']], hue='Hemisphere')

In [None]:

from sklearn.svm import SVC
# NOTE(review): uses whatever X_train/X_test/y_train/y_test the last-run split cell left
# behind (the right-STN split when the notebook is run top to bottom).
svm_clf = SVC(gamma='auto', random_state=42)
svm_clf.fit(X_train, y_train)

SVM_train_pred = svm_clf.predict(X_train)
SVM_test_pred = svm_clf.predict(X_test)

splits = [('train', y_train, SVM_train_pred), ('test', y_test, SVM_test_pred)]

# Per-class scores (one value per class, in sorted label order).
for split, y_true, y_pred in splits:
    print(split + ' precision:', precision_score(y_true, y_pred, average=None))
    print(split + ' recall:', recall_score(y_true, y_pred, average=None))
    print(split + ' f1:', f1_score(y_true, y_pred, average=None))
    matrix = confusion_matrix(y_true, y_pred)
    # Diagonal / row sum = fraction of each true class predicted correctly. This is a
    # per-class figure (it equals per-class recall), not the overall accuracy printed
    # below, so it is labelled separately.
    print(split + ' per-class accuracy:', matrix.diagonal() / matrix.sum(axis=1))

print('macro averaged')
for split, y_true, y_pred in splits:
    print(split + ' precision:', precision_score(y_true, y_pred, average='macro'))
    print(split + ' recall:', recall_score(y_true, y_pred, average='macro'))
    print(split + ' f1:', f1_score(y_true, y_pred, average='macro'))
    print(split + ' accuracy', accuracy_score(y_true, y_pred))

In [None]:
from sklearn.ensemble import GradientBoostingClassifier
# NOTE(review): trains on whatever X_train/y_train the last-run split cell left behind
# (the right-STN split when the notebook is run top to bottom).
GBDT = GradientBoostingClassifier(n_estimators=1000, learning_rate=0.005, max_depth=None, random_state=42,verbose=2).fit(X_train, y_train)

GBDT_train_pred = GBDT.predict(X_train)
GBDT_test_pred = GBDT.predict(X_test)

splits = [('train', y_train, GBDT_train_pred), ('test', y_test, GBDT_test_pred)]

# Per-class scores (one value per class, in sorted label order).
for split, y_true, y_pred in splits:
    print(split + ' precision:', precision_score(y_true, y_pred, average=None))
    print(split + ' recall:', recall_score(y_true, y_pred, average=None))
    print(split + ' f1:', f1_score(y_true, y_pred, average=None))
    matrix = confusion_matrix(y_true, y_pred)
    # Diagonal / row sum = fraction of each true class predicted correctly (per-class
    # recall), not the overall accuracy printed below, so it is labelled separately.
    print(split + ' per-class accuracy:', matrix.diagonal() / matrix.sum(axis=1))

print('macro averaged')
for split, y_true, y_pred in splits:
    print(split + ' precision:', precision_score(y_true, y_pred, average='macro'))
    print(split + ' recall:', recall_score(y_true, y_pred, average='macro'))
    print(split + ' f1:', f1_score(y_true, y_pred, average='macro'))
    print(split + ' accuracy', accuracy_score(y_true, y_pred))

In [None]:
from sklearn.feature_selection import RFE

# Recursive feature elimination on the right-STN population: for each target size
# 1..9, drop the least important feature one at a time from a 200-tree random forest,
# then print the surviving feature names and the held-out accuracy.
# NOTE(review): choosing the feature count by test accuracy uses the test split for
# model selection; use cross-validation (e.g. RFECV) for an unbiased estimate.
X_train, X_test, y_train, y_test = train_test_split(
    subset_right_STN[features],
    subset_right_STN['Neuron'],
    stratify=subset_right_STN['Neuron'],
    test_size=0.2,
    random_state=42,
)

for n_keep in range(1, 10):
    RF = RandomForestClassifier(n_estimators=200, random_state=18)
    rfe = RFE(estimator=RF, n_features_to_select=n_keep, step=1, verbose=0)
    rfe.fit(X_train, y_train)
    kept_features = rfe.get_feature_names_out()
    held_out_accuracy = accuracy_score(y_test, rfe.predict(X_test))
    print(kept_features)
    print(held_out_accuracy)



### Figures for export (saved as SVG)

In [None]:
# Right-STN feature pairplot, coloured by neuron type, exported as SVG.
fig = sns.pairplot(
    subset_right_STN[['ISI_mean', 'ifr_mean', 'mean_surprise', 'Neuron', 'var_burst_duration']],
    hue='Neuron',
    palette=['tab:green', 'tab:orange'],
    plot_kws={'alpha': 0.7},
)
fig.savefig('Pairplot_right_STN.svg')

In [None]:
# Right-GPi feature pairplot, coloured by neuron type, exported as SVG.
fig = sns.pairplot(
    subset_right_GPi[['ISI_mean', 'ifr_mean', 'mean_surprise', 'Neuron', 'num_bursts']],
    hue='Neuron',
    palette=['tab:blue', 'tab:red'],
    plot_kws={'alpha': 0.7},
)
fig.savefig('Pairplot_right_GPi.svg')

In [None]:
# Left-STN feature pairplot, coloured by neuron type, exported as SVG.
fig = sns.pairplot(
    subset_left_STN[['ISI_mean', 'ifr_mean', 'mean_surprise', 'Neuron', 'var_burst_duration']],
    hue='Neuron',
    palette=['tab:green', 'tab:orange'],
    plot_kws={'alpha': 0.7},
)
fig.savefig('Pairplot_left_STN.svg')

In [None]:
# Left-GPi pairplot with mean_burst_duration outliers (outside mean ± 2 SD) excluded.
# The mask is computed here instead of reading an 'outlier' column, because that
# column is only created by a later cell and this cell failed on Restart & Run All.
left_gpi_burst = subset_left_GPi['mean_burst_duration']
left_gpi_low = np.mean(left_gpi_burst) - 2 * np.std(left_gpi_burst)
left_gpi_high = np.mean(left_gpi_burst) + 2 * np.std(left_gpi_burst)
left_gpi_is_outlier = (left_gpi_burst < left_gpi_low) | (left_gpi_burst > left_gpi_high)
subset_left_GPi_nooutlier = subset_left_GPi.loc[~left_gpi_is_outlier]
fig = sns.pairplot(subset_left_GPi_nooutlier[['ISI_mean', 'ifr_mean', 'mean_surprise','Neuron','mean_burst_duration']],hue='Neuron',plot_kws=dict(alpha=0.7),palette=['tab:blue','tab:red'])
fig.savefig('Pairplot_left_GPi.svg')

In [None]:
# Flag left-GPi cells whose mean burst duration lies outside mean ± 2 SD
# (population SD, ddof=0, as computed by np.std) with outlier = 1, else 0.
burst_duration = subset_left_GPi['mean_burst_duration']
outlier_min = np.mean(burst_duration) - 2 * np.std(burst_duration)
outlier_max = np.mean(burst_duration) + 2 * np.std(burst_duration)
# subset_left_GPi is a .loc selection of combined_data: take an explicit copy before adding
# a column, and set the values in one vectorised assignment. The former per-element
# `subset_left_GPi['outlier'].iloc[i] = 1` was chained assignment, which never writes back
# under pandas Copy-on-Write.
subset_left_GPi = subset_left_GPi.copy()
subset_left_GPi['outlier'] = ((burst_duration < outlier_min) | (burst_duration > outlier_max)).astype(int)

In [None]:
from scipy.signal import find_peaks, butter, filtfilt
# Band-pass settings for the example-signal cells below.
lowcut, highcut = 300, 6000  # corner frequencies in Hz
fs = 12500  # sampling rate in Hz used to design the filter (assumed to match the recordings)

def butter_bandpass(lowcut, highcut, fs, order=4):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')
    return b, a

def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Zero-phase Butterworth band-pass filter.

    The filter is designed as second-order sections (SOS) and applied forwards
    and backwards with ``sosfiltfilt``. SciPy recommends SOS over the (b, a)
    transfer-function form for general filtering, because band-pass designs in
    (b, a) form can become numerically unstable as the order grows or the
    edges approach 0 or Nyquist.

    Parameters
    ----------
    data : array_like
        Signal to filter (along its last axis).
    lowcut, highcut : float
        Lower and upper pass-band edges, in Hz.
    fs : float
        Sampling frequency of `data`, in Hz.
    order : int, optional
        Butterworth order (default 4).

    Returns
    -------
    ndarray
        Filtered signal, same shape as `data`.
    """
    from scipy.signal import butter, sosfiltfilt

    sos = butter(order, [lowcut, highcut], btype='band', fs=fs, output='sos')
    return sosfiltfilt(sos, data)

# Example recording used for the methods figures.
neuron_name = 'neuron_012'
analogsignal, spike_times, sampling_frequency, time = load_spiketrain(os.path.join(data_dir, neuron_name + '.smr'))
# Filter at the recording's own sampling rate rather than the hard-coded `fs`,
# so the pass band is placed correctly for any file (later cells also use
# `sampling_frequency` to convert between samples and seconds).
filtered_data = butter_bandpass_filter(analogsignal, lowcut, highcut, float(sampling_frequency))

##### PLOT SIGNALS #####
# Top: unfiltered trace; middle: band-passed trace; bottom: spike times from the file.
fig, ax = plt.subplots(3, sharex=True)
fig.suptitle(neuron_name, fontsize=16)

ax[0].plot(time, analogsignal, 'tab:blue')
ax[0].set_xlim(1, 5)
ax[0].set_ylabel('Voltage (mV)')

ax[1].plot(time, filtered_data, 'tab:blue', alpha=0.7)
ax[1].set_ylabel('Voltage (mV)')

ax[2].eventplot(spike_times, color='black')
ax[2].set_xlabel("Time (s)")
fig.savefig('Example_filtering.svg')

# fig, ax = plt.subplots(3, sharex = True, )
# fig.suptitle(neuron_name, fontsize=16)    

# ax[0].plot(time,analogsignal,'tab:blue')
# ax[0].set_xlim(3,3.5)
# ax[0].set_ylabel('Voltage')

# ax[1].plot(time,filtered_data,'tab:blue', alpha=0.7)
# ax[1].set_ylabel('Voltage')

# ax[2].eventplot(spike_times, color='black')
# ax[2].set_xlabel("Time (s)")

In [None]:
# Threshold-based spike detection on the band-passed trace: keep local minima of
# `filtered_data` at or below its 0.7th percentile (peaks of the negated signal
# whose height is at least minus that percentile).
spike_peaks, _ = find_peaks(-filtered_data, height=-np.percentile(filtered_data, 0.7))

fig, ax = plt.subplots(3, sharex=True)
fig.suptitle(neuron_name, fontsize=16)

ax[0].plot(time, filtered_data, 'tab:blue', alpha=0.7)
ax[0].set_xlim(3, 3.5)
ax[0].set_ylabel('Voltage (mV)')

ax[1].plot(time, filtered_data, 'tab:blue', alpha=0.7)
# Mark detections on the same trace and time base they were found in. Using the
# unfiltered `analogsignal` amplitude placed the markers off the plotted trace,
# and `index / sampling_frequency` assumed `time` starts at 0.
ax[1].scatter(np.asarray(time)[spike_peaks], filtered_data[spike_peaks], color='red', s=10, alpha=0.9)
ax[1].set_ylabel('Voltage (mV)')

ax[2].eventplot(spike_times, color='black')
ax[2].set_xlabel("Time (s)")
fig.savefig('Example_spiketrain.svg')


In [None]:
# Cut a window of the band-passed signal centred on each sorted spike time
# (one row per spike, 2 * window_size samples per row).
window_size = 25  # samples taken on each side of the spike sample
# Convert spike times to sample indices with the recording's own sampling rate
# (previously hard-coded as 12500); truncation to int matches the old int().
spike_idx = (np.asarray(spike_times, dtype=float) * float(sampling_frequency)).astype(int)
# Keep only spikes whose full window lies inside the recording. Spikes near
# either edge used to yield a short or empty slice (a negative start counts
# from the end of the array) that could not fill the row, raising ValueError.
in_bounds = (spike_idx - window_size >= 0) & (spike_idx + window_size <= len(filtered_data))
window_offsets = np.arange(-window_size, window_size)
spike_waveforms = filtered_data[spike_idx[in_bounds][:, None] + window_offsets]

# Reduce each spike waveform to its first three principal components, then pick
# the k-means cluster count (from 2-5) with the highest silhouette score.
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
pca = PCA(n_components=3)
waveform_features = pca.fit_transform(spike_waveforms)

cluster_numbers = [2, 3, 4, 5]
si_score = []
for n_clusters in cluster_numbers:
    # random_state fixes the centroid initialisation so the choice is reproducible.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
    cluster_labels = kmeans.fit_predict(waveform_features)
    si_score.append(silhouette_score(waveform_features, cluster_labels, metric='euclidean'))

opt_n_clusters = cluster_numbers[np.argmax(si_score)]
print(opt_n_clusters)

kmeans = KMeans(n_clusters=opt_n_clusters, random_state=0, n_init='auto')
cluster_labels = kmeans.fit_predict(waveform_features)

fig = plt.figure()
# Create the 3-D axes on this figure explicitly instead of via the pyplot state machine.
ax = fig.add_subplot(projection='3d')
scatter = ax.scatter(waveform_features[:, 0], waveform_features[:, 1], waveform_features[:, 2],
                     c=cluster_labels, cmap='bwr', alpha=0.6)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
# ax.legend() already attaches the legend to the axes; the former extra
# ax.add_artist(legend1) registered it a second time, so it was drawn twice.
legend1 = ax.legend(*scatter.legend_elements(), title="Clusters")
fig.savefig('Example_neuron_sortfeatures.svg')

In [None]:
# Per-cluster waveform overlays: faint individual spikes plus the bold cluster
# mean, one panel per k-means cluster, then a final panel with every spike.
# Colours come from the same 'bwr' colormap (normalised over the label range)
# used by the PC-space scatter, so each cluster keeps one colour across both
# figures. Previously label 0 was blue in the scatter but red here, and any
# clusters beyond labels 0 and 1 were left out.
cluster_ids = np.unique(cluster_labels)
label_span = cluster_ids.max() - cluster_ids.min()
cluster_cmap = plt.get_cmap('bwr')

# Time axis in ms relative to the window centre, derived from the sampling rate
# (the old linspace(-2, 2, 50) spaced the samples 4/49 ms apart instead of 1/fs).
n_wave_samples = spike_waveforms.shape[1]
t_ms = (np.arange(n_wave_samples) - n_wave_samples // 2) / float(sampling_frequency) * 1e3

panels = [
    (f'Cluster {k}',
     spike_waveforms[cluster_labels == k, :],
     cluster_cmap((k - cluster_ids.min()) / label_span if label_span else 0.5))
    for k in cluster_ids
]
panels.append(('All spikes', spike_waveforms, 'black'))

fig, axes = plt.subplots(1, len(panels))
for panel_ax, (title, waves, colour) in zip(np.atleast_1d(axes), panels):
    # One plot call draws every spike (the columns of waves.T).
    panel_ax.plot(t_ms, waves.T, alpha=0.05, color=colour, linewidth=1)
    panel_ax.plot(t_ms, waves.mean(axis=0), color=colour, linewidth=2, alpha=0.8)
    panel_ax.set_title(title)
    panel_ax.set_box_aspect(1)

fig.supxlabel('Time (ms)')
fig.supylabel('Voltage (mV)')
fig.tight_layout()
fig.savefig('Example_cluster.svg')
# for spike in range(len(spike_waveforms)):
#     plt.plot(spike_waveforms[spike,:],alpha=0.1,color='tab:blue',linewidth = 1)
    
# plt.plot(np.mean(spike_waveforms,axis=0),color='black')

In [None]:
# Spike-sorting quality metrics for every .smr recording in data_dir.
window = 25  # samples taken on each side of a spike when averaging its waveform


def _sorting_quality(path, half_window=window):
    """Compute sorting-quality metrics for one recording.

    Works on local names, so the loop below no longer overwrites the
    notebook-level `analogsignal`, `spike_times`, `sampling_frequency` and
    `time` that the example-neuron cells rely on.

    Parameters
    ----------
    path : str
        Path to an .smr file readable by ``load_spiketrain``.
    half_window : int
        Samples taken on each side of a spike.

    Returns
    -------
    snr : float
        Peak-to-peak amplitude of the mean waveform divided by twice the SD of
        the residuals (every waveform minus the mean waveform). The previous
        denominator was the SD of the mean waveform itself, which reflects the
        spike's own shape rather than noise. NaN with fewer than two complete
        waveforms or zero residual SD. Computed on the signal as returned by
        ``load_spiketrain`` (not band-passed here).
    isi_isolation : float
        1 / coefficient of variation of the inter-spike intervals.
    mean_isi : float
        Mean inter-spike interval (NaN with fewer than two spikes).
    n_refractory : int
        Number of inter-spike intervals <= 0.001 (1 ms if spike times are in s).
    """
    signal, spikes, fs, _ = load_spiketrain(path)
    signal = np.asarray(signal, dtype=float)
    spike_s = np.asarray(spikes, dtype=float)

    # Only spikes whose whole window lies inside the recording contribute.
    centres = (spike_s * float(fs)).astype(int)
    ok = (centres - half_window >= 0) & (centres + half_window <= len(signal))
    waveforms = signal[centres[ok][:, None] + np.arange(-half_window, half_window)]

    snr = np.nan
    if len(waveforms) >= 2:
        mean_waveform = waveforms.mean(axis=0)
        noise_sd = np.std(waveforms - mean_waveform)
        if noise_sd > 0:
            snr = (mean_waveform.max() - mean_waveform.min()) / (2 * noise_sd)

    isi = np.diff(spike_s)
    if isi.size == 0:
        return snr, np.nan, np.nan, 0
    isi_mean = isi.mean()
    # Keep the original arithmetic (1 / CV), including inf/NaN for degenerate
    # trains, but without RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        isi_isolation = 1 / (isi.std() / isi_mean)
    return snr, isi_isolation, isi_mean, int(np.count_nonzero(isi <= 0.001))


snr_ls, isi_isol_ls, mean_isi_ls, num_refrac_ls = [], [], [], []
# sorted() makes the per-file order reproducible (os.listdir order is arbitrary).
for file in sorted(os.listdir(data_dir)):
    if not file.endswith('.smr'):
        continue
    snr, isi_isol, mean_isi, n_refrac = _sorting_quality(os.path.join(data_dir, file))
    snr_ls.append(snr)
    isi_isol_ls.append(isi_isol)
    mean_isi_ls.append(mean_isi)
    num_refrac_ls.append(n_refrac)

# Histograms of the per-recording quality metrics, side by side, plus their means.
fig, ax = plt.subplots(1, 2)
metric_panels = [
    (snr_ls, 'Mean signal-to-noise ratio'),
    (isi_isol_ls, 'ISI isolation'),
]
for panel, (values, xlabel) in zip(ax, metric_panels):
    sns.histplot(values, ax=panel)
    panel.set_xlabel(xlabel)
    panel.set_box_aspect(1)

fig.tight_layout()
# fig.savefig('spikesort_metrics.svg')

print('mean snr', np.mean(snr_ls))
print('isolation ', np.mean(isi_isol_ls))