In [1]:
from FeatureExtraction import *
import os
import pandas as pd
import numpy as np

In [2]:
data_dir = '/Users/jazlynn/Downloads/neurons-smr-format-sorted/'

# Output columns, in the order they appear in the saved CSV.
feature_columns = ['Filename',
                   'firing rate',
                   'ifr_mean',
                   'ifr_skew',
                   'ifr_kurtosis',
                   'ifr_fano_factor',
                   'baseline_amplitude',
                   'peak_trough_amplitude',
                   'waveform_duration',
                   'waveform_halfwidth',  # unlikely to use because of NaNs
                   'waveform_PT_ratio',
                   'waveform_TP_time',
                   'ISI_mean',
                   'ISI_skew',
                   'ISI_kurtosis',
                   'ISI_cv',
                   'num_bursts',
                   'mean_surprise',
                   'burst_index',
                   'mean_burst_duration',
                   'var_burst_duration',
                   'mean_interburst_duration',
                   'var_interburst_duration',
                   'max_peak_freq',
                   'delta_band',
                   'theta_band',
                   'alpha_band',
                   'beta_band',
                   'gamma_band',
                   'delta_band_freq',
                   'theta_band_freq',
                   'delta_band_magnitude',
                   'theta_band_magnitude']

# One list per column; every processed recording appends exactly one value to each.
results_dict = {column: [] for column in feature_columns}


def _strongest_peak_in_band(freq_peaks, freq_magnitude, low, high):
    """Return ``(frequency, magnitude)`` of the strongest peak with ``low <= f < high``.

    Always returns two scalars so that each DataFrame cell holds a single number.
    Storing every in-band peak instead would put an empty array (no peak in the
    band) or a multi-element array (several peaks) into the column, which is
    written to the CSV as text such as '[]'.

    Parameters
    ----------
    freq_peaks : sequence of numbers
        Peak frequencies returned by ``get_synchrony_features``.
    freq_magnitude : sequence of numbers
        Peak magnitudes; assumed to be index-aligned with ``freq_peaks``
        (TODO confirm against ``get_synchrony_features``).
    low, high : number
        Band edges; the band is the half-open interval ``[low, high)``.

    Returns
    -------
    tuple
        ``(0, 0)`` when no peak lies in the band. When magnitudes are missing
        for some peaks, the first in-band frequency with magnitude ``0``.
        Otherwise the frequency and magnitude of the in-band peak with the
        largest magnitude.
    """
    in_band = [i for i, freq in enumerate(freq_peaks) if low <= freq < high]
    if not in_band:
        return 0, 0
    if len(freq_magnitude) < len(freq_peaks):
        # Magnitudes cannot be matched to peaks reliably; report the frequency only.
        return freq_peaks[in_band[0]], 0
    strongest = max(in_band, key=lambda i: freq_magnitude[i])
    return freq_peaks[strongest], freq_magnitude[strongest]


# os.listdir returns entries in arbitrary order; sort so the row order of the
# saved CSV is reproducible between runs.
for file in sorted(os.listdir(data_dir)):
    if not file.endswith('.smr'):
        continue

    analogsignal, spike_times, sampling_frequency, time = load_spiketrain(os.path.join(data_dir, file))

    # Collect this recording's values first and append them in one step at the
    # end, so the column lists keep equal lengths.
    row = {'Filename': file}

    # --- firing-rate features ---
    row['firing rate'] = get_firing_rate(spike_times, analogsignal, sampling_frequency)
    ifr_ls, time_bins = calculate_instantaneous_firing_rate(spike_times, analogsignal, sampling_frequency, 0.05, 0.1)
    fano_factor, ifr_mean, ifr_skew, ifr_kurtosis = get_ifr_metrics(ifr_ls)
    row['ifr_mean'] = ifr_mean
    row['ifr_skew'] = ifr_skew
    row['ifr_kurtosis'] = ifr_kurtosis
    row['ifr_fano_factor'] = fano_factor

    # --- amplitude and waveform-shape features ---
    peak_trough_amplitude, _ = get_mean_amplitude(spike_times, analogsignal, sampling_frequency, 25)
    baseline_amplitude, _ = get_mean_amplitude2(spike_times, analogsignal, sampling_frequency, 25)
    waveform_duration, waveform_halfwidth, waveform_PT_ratio, waveform_TP_time = get_ecephys_waveform_metrics(spike_times, analogsignal, 25, sampling_frequency)
    row['baseline_amplitude'] = baseline_amplitude
    row['peak_trough_amplitude'] = peak_trough_amplitude
    row['waveform_duration'] = waveform_duration
    row['waveform_halfwidth'] = waveform_halfwidth
    row['waveform_PT_ratio'] = waveform_PT_ratio
    row['waveform_TP_time'] = waveform_TP_time

    # --- inter-spike-interval features (ISI_mode is computed but not stored) ---
    ISI_cv, ISI_skew, ISI_kurtosis, ISI_mean, ISI_mode = get_ISI_metrics(spike_times)
    row['ISI_cv'] = ISI_cv
    row['ISI_skew'] = ISI_skew
    row['ISI_kurtosis'] = ISI_kurtosis
    row['ISI_mean'] = ISI_mean

    # --- burst features ---
    burst_dict = burst_detection_neuroexplorer(spike_times,
                                               analogsignal,
                                               sampling_frequency,
                                               min_surprise = 5,
                                               min_numspikes = 3)
    (num_bursts, mean_surprise, burst_index,
     mean_burst_duration, var_burst_duration,
     mean_interburst_duration, var_interburst_duration) = get_burst_metrics(burst_dict, spike_times)
    row['num_bursts'] = num_bursts
    row['mean_surprise'] = mean_surprise
    row['burst_index'] = burst_index
    row['mean_burst_duration'] = mean_burst_duration
    row['var_burst_duration'] = var_burst_duration
    row['mean_interburst_duration'] = mean_interburst_duration
    row['var_interburst_duration'] = var_interburst_duration

    # --- oscillation / synchrony features ---
    max_peak_freq, freq_peaks, peak_bands, freq_magnitude = get_synchrony_features(spike_times,
                                                                                   time_bin_size = 0.01,
                                                                                   max_lag_time = 0.5,
                                                                                   avg_window_size = 5,
                                                                                   significance_level = 0.05,
                                                                                   to_plot = False)
    row['max_peak_freq'] = max_peak_freq
    # Number of detected peaks labelled with each band name.
    for band in ('delta', 'theta', 'alpha', 'beta', 'gamma'):
        row[band + '_band'] = peak_bands.count(band)

    # only have delta and theta bands (expected)
    row['delta_band_freq'], row['delta_band_magnitude'] = _strongest_peak_in_band(freq_peaks, freq_magnitude, 0.1, 4)
    row['theta_band_freq'], row['theta_band_magnitude'] = _strongest_peak_in_band(freq_peaks, freq_magnitude, 4, 8)

    # Fail loudly before touching results_dict if a column was forgotten above.
    missing_columns = [column for column in feature_columns if column not in row]
    if missing_columns:
        raise KeyError('no value computed for columns: {}'.format(missing_columns))
    for column in feature_columns:
        results_dict[column].append(row[column])


results_df = pd.DataFrame(results_dict)
# The index is written as an unnamed first column; it is read back as 'Unnamed: 0'.
results_df.to_csv('ExtractedFeatures_v2.csv')
# Last expression of the cell, so the preview is actually displayed.
results_df.head()

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe',
  ret = ret.dtype.type(ret / rcount)
  autocorrelogram /= np.sum(autocorrelogram)


### Model Training

In [40]:
# Per-neuron annotation table; the 'Neuron' column is later used as the
# classification label and 'Filename' as the join key.
metadata_csv = '/Users/jazlynn/Downloads/bme1500-project-5-metadata.csv'
metadata = pd.read_csv(metadata_csv)
metadata

Unnamed: 0,Filename,Patient ID,Target,Neuron,Hemisphere
0,neuron_001.smr,1,STN,STN,R
1,neuron_002.smr,2,GPi,BOR,L
2,neuron_003.smr,3,GPi,BOR,R
3,neuron_004.smr,3,GPi,BOR,L
4,neuron_005.smr,3,GPi,BOR,L
...,...,...,...,...,...
355,neuron_356.smr,143,STN,STN,R
356,neuron_357.smr,143,STN,SNr,R
357,neuron_358.smr,144,STN,SNr,L
358,neuron_359.smr,145,STN,STN,R


In [65]:
# Load the saved feature table and replace every missing value with 0.
# NOTE(review): this reads 'ExtractedFeatures_v1.csv', while the extraction
# cell writes 'ExtractedFeatures_v2.csv' - confirm which version is intended.
results_df = pd.read_csv('ExtractedFeatures_v1.csv').fillna(0)
# Masking of cells whose text is '[]' is left disabled:
# results_df = results_df.mask(results_df.applymap(str).eq('[]'))

# Inner join on 'Filename': only recordings present in both tables are kept.
combined_data = metadata.merge(results_df, on='Filename')
combined_data

Unnamed: 0.1,Filename,Patient ID,Target,Neuron,Hemisphere,Unnamed: 0,firing rate,ifr_mean,ifr_skew,ifr_kurtosis,...,max_peak_freq,delta_band,theta_band,alpha_band,beta_band,gamma_band,delta_band_freq,theta_band_freq,delta_band_magnitude,theta_band_magnitude
0,neuron_001.smr,1,STN,STN,R,300,13.520379,13.469388,1.020497,-0.052522,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.087242,0.020579
1,neuron_002.smr,2,GPi,BOR,L,317,35.229484,35.046154,-0.017229,-0.250216,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.163782,0.049172
2,neuron_003.smr,3,GPi,BOR,R,314,18.319115,18.237179,0.172007,-0.105805,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.090107,0.040045
3,neuron_004.smr,3,GPi,BOR,L,352,16.787130,16.759142,0.123386,-0.008873,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.102068,0.035840
4,neuron_005.smr,3,GPi,BOR,L,351,10.208298,10.153374,0.343673,-0.355903,...,2.969697,1,1,0,0,0,2.969697,5.393939,0.038644,0.012286
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
355,neuron_356.smr,143,STN,STN,R,17,54.666667,54.500000,0.267870,0.560545,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.085823,0.034168
356,neuron_357.smr,143,STN,SNr,R,6,63.066667,62.933333,0.774957,0.042015,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.137768,0.036451
357,neuron_358.smr,144,STN,SNr,L,329,133.600000,133.533333,-0.154999,0.141760,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.101890,0.031467
358,neuron_359.smr,145,STN,STN,R,340,60.076461,59.403974,0.301433,0.545548,...,2.969697,1,1,0,0,0,2.969697,4.909091,0.088075,0.035464


In [86]:
# Columns of results_df that are not fed to the classifier: the join key, the
# saved index column, the per-band peak counts and the per-band peak frequencies.
excluded_columns = ['Filename', 'Unnamed: 0',
                    'delta_band', 'theta_band', 'alpha_band', 'beta_band', 'gamma_band',
                    'delta_band_freq', 'theta_band_freq']

# Index.drop raises KeyError if any of the excluded labels is absent.
features = results_df.columns.drop(excluded_columns)
print(features)

Index(['firing rate', 'ifr_mean', 'ifr_skew', 'ifr_kurtosis',
       'ifr_fano_factor', 'baseline_amplitude', 'peak_trough_amplitude',
       'ISI_mean', 'ISI_skew', 'ISI_kurtosis', 'ISI_cv', 'num_bursts',
       'mean_surprise', 'burst_index', 'mean_burst_duration',
       'var_burst_duration', 'mean_interburst_duration',
       'var_interburst_duration', 'max_peak_freq', 'delta_band_magnitude',
       'theta_band_magnitude'],
      dtype='object')


In [93]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

In [97]:
# Hold out 20% of the neurons for testing; the label is the 'Neuron' column.
# NOTE(review): the split is neither stratified by class nor grouped by
# 'Patient ID', so neurons from one patient can land in both train and test -
# confirm this is acceptable for the intended evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    combined_data[features],
    combined_data['Neuron'],
    test_size=0.2,
    random_state=42,
)

RF = RandomForestClassifier(n_estimators=250, random_state=42)
RF.fit(X_train, y_train)
train_pred = RF.predict(X_train)
test_pred = RF.predict(X_test)

# (label, scorer) pairs and (split name, true labels, predictions) triples
# shared by both reporting passes below.
scorers = (('precision', precision_score),
           ('recall', recall_score),
           ('f1', f1_score))
splits = (('train', y_train, train_pred),
          ('test', y_test, test_pred))

# Per-class scores (average=None returns one value per class).
for split_name, y_true, y_hat in splits:
    for metric_name, scorer in scorers:
        print(f'{split_name} {metric_name}:', scorer(y_true, y_hat, average=None))

# Macro-averaged scores, followed by overall accuracy for each split.
print('macro averaged')
for split_name, y_true, y_hat in splits:
    for metric_name, scorer in scorers:
        print(f'{split_name} {metric_name}:', scorer(y_true, y_hat, average='macro'))
    print(f'{split_name} accuracy', accuracy_score(y_true, y_hat))

train precision: [1. 1. 1. 1.]
train recall: [1. 1. 1. 1.]
train f1: [1. 1. 1. 1.]
test precision: [0.83333333 0.55555556 0.75       0.73333333]
test recall: [0.86956522 0.38461538 0.9        0.6875    ]
test f1: [0.85106383 0.45454545 0.81818182 0.70967742]
macro averaged
train precision: 1.0
train recall: 1.0
train f1: 1.0
train accuracy 1.0
test precision: 0.7180555555555556
test recall: 0.7104201505016723
test f1: 0.7083671304673363
test accuracy 0.75


In [89]:
# Display the feature names used to train RF (the columns of combined_data[features]).
features

Index(['firing rate', 'ifr_mean', 'ifr_skew', 'ifr_kurtosis',
       'ifr_fano_factor', 'baseline_amplitude', 'peak_trough_amplitude',
       'ISI_mean', 'ISI_skew', 'ISI_kurtosis', 'ISI_cv', 'num_bursts',
       'mean_surprise', 'burst_index', 'mean_burst_duration',
       'var_burst_duration', 'mean_interburst_duration',
       'var_interburst_duration', 'max_peak_freq', 'delta_band_magnitude',
       'theta_band_magnitude'],
      dtype='object')

In [90]:
# Display the fitted forest's per-feature importances. RF was fitted on
# combined_data[features], so entries should follow the order of `features`.
RF.feature_importances_

array([0.08802733, 0.09329092, 0.05104145, 0.03462707, 0.03123744,
       0.06427898, 0.06101794, 0.09359919, 0.04185321, 0.03696402,
       0.04806953, 0.05720707, 0.05705631, 0.05398315, 0.04848607,
       0.03068796, 0.02976692, 0.02576777, 0.00230611, 0.02537848,
       0.02535308])