## Gait Video Study 
### Traditional ML algorithms on task generalization framework 1: W to WT to classify HOA/MS/PD strides and subjects 
#### Remember to add the original count of frames in a single stride (before down sampling via smoothing) for each stride as an additional artificial feature to add information about speed of the subject to the model

1. Save the optimal hyperparameters, confusion matrices and ROC curves for each algorithm.
2. Make sure to not use x, y, z, confidence = 0, 0, 0, 0 as points for the model since they are simply missing values and not data points, so make sure to treat them before inputting to model 
3. Make sure to normalize (mean-subtract) the features before feeding them to the model.
4. We use summary statistics (range, CoV and asymmetry between the right and left limbs) as the input features for the traditional models, which require a fixed-size 1D input for each training/testing set sample.


In [1]:
# 33 subjects in total (~10 per group) 
# 4500 strides - 2000 strides - 200 groups for 10 strides per group
# 90 features - 36 CoV, 36 Range, 18 asymmetry
# Default + Dimensionality reduction - 3D space
# Try top 10 features 
# Subject generalization is where the overfitting issue is tested - If we get good results, that means we are not 
# overfitting 

In [2]:
from imports import *

In [3]:
#Root folder of the gait video data (machine-specific absolute Windows path)
path = 'C:\\Users\\Rachneet Kaur\\Box\\Gait Video Project\\GaitVideoData\\video\\'
#CSV with the features used by the traditional ML methods
data_path = ''.join([path, 'traditional_methods_dataframe.csv'])

#The first CSV column is used as the row index
data = pd.read_csv(data_path, index_col=0)
display(data.head(n=5))

Unnamed: 0,key,cohort,trial,scenario,video,PID,stride_number,frame_count,label,right hip-x-CoV,...,ankle-z-asymmetry,heel-x-asymmetry,heel-y-asymmetry,heel-z-asymmetry,toe 1-x-asymmetry,toe 1-y-asymmetry,toe 1-z-asymmetry,toe 2-x-asymmetry,toe 2-y-asymmetry,toe 2-z-asymmetry
0,GVS_212_T_T1_1,HOA,BW,SLWT,GVS_212_T_T1,212,1,46,0,0.046077,...,0.604591,0.233432,0.168252,0.036246,0.031073,0.631236,0.495529,0.26328,0.686741,0.458813
1,GVS_212_T_T1_2,HOA,BW,SLWT,GVS_212_T_T1,212,2,39,0,0.021528,...,0.092555,0.472563,0.293185,0.266283,0.117045,0.7248,0.15791,0.537486,0.338966,0.228945
2,GVS_212_T_T1_3,HOA,BW,SLWT,GVS_212_T_T1,212,3,56,0,0.034394,...,0.058939,0.451513,0.132201,0.132919,0.105341,0.338124,0.179538,0.422522,0.188858,0.082375
3,GVS_212_T_T1_4,HOA,BW,SLWT,GVS_212_T_T1,212,4,53,0,0.028511,...,0.115101,0.299212,0.00191,0.038243,0.027405,0.150478,0.143856,0.039233,0.32659,0.165196
4,GVS_212_T_T1_5,HOA,BW,SLWT,GVS_212_T_T1,212,5,44,0,0.025213,...,0.311598,0.079393,0.000535,0.307031,0.088778,0.117577,0.291998,0.254334,0.005358,0.379653


### Utility functions 

In [23]:
def keep_subjects_common_across_train_test(trial_train, trial_test):
    '''
    Restrict the training and test sets to the subjects (PIDs) that are present in both.

    Input:
        trial_train: dataframe with a 'PID' column (e.g. strides of trial W)
        trial_test: dataframe with a 'PID' column (e.g. strides of trial WT)
    Output:
        (trial_train_reduced, trial_test_reduced): new dataframes in which the rows of every
        subject missing from the other set are removed. In both outputs 'PID' becomes the
        first column and the row index is renumbered from 0.
    Side effects: prints the subject counts before/after the reduction and the dropped PIDs.
    '''
    #Compute the unique PIDs once; the sets give constant time membership tests below
    train_pids = trial_train['PID'].unique()
    test_pids = trial_test['PID'].unique()
    train_pid_set = set(train_pids)
    test_pid_set = set(test_pids)
    print ('Number of subjects in training and test sets:', len(train_pids), len(test_pids))

    #Try to use same subjects in trials W and WT for testing on same subjects we train on
    print ('Subjects in test set, which are not in training set')
    #PIDs missing in training set (trial W) but are present in the test set (trial WT)
    pids_missing_training = [pid for pid in test_pids if pid not in train_pid_set]
    print (pids_missing_training)

    #Deleting the subjects from the test set that are missing in the training set 
    trial_test_reduced = trial_test.set_index('PID').drop(pids_missing_training).reset_index()

    print ('Subjects in training set, which are not in test set')
    #PIDs missing in test set (trial WT) but are present in the training set (trial W)
    pids_missing_test = [pid for pid in train_pids if pid not in test_pid_set]
    print (pids_missing_test)

    #Deleting the subjects from the training set that are missing in the test set 
    trial_train_reduced = trial_train.set_index('PID').drop(pids_missing_test).reset_index()

    print ('Number of subjects in training and test sets after reduction:', len(trial_train_reduced['PID'].unique()), \
           len(trial_test_reduced['PID'].unique()))
    return trial_train_reduced, trial_test_reduced

In [24]:
#Standardize the data before ML methods 
#Take care that the testing set is not used while normalizing the training set; otherwise the training set 
#indirectly contains information about the test set
def normalize(dataframe, n_type): 
    '''
    Compute the column-wise offset and scale used to normalize a dataframe.

    Input: dataframe, type of normalization ('z' selects z-score; any other value selects min-max)
    Output: (mean, sd)
        z-score: column-wise mean and standard deviation
        min-max: column-wise minimum and range (max - min)
    The normalized data is then (dataframe - mean)/sd. Compute the statistics on the training set
    only and apply the same ones to the test set, so the test set does not leak into training.
    '''
    if (n_type == 'z'): #z-score normalization 
        mean = dataframe.mean()
        sd = dataframe.std()
    else: #min-max normalization
        #The names are kept for a uniform return value: here 'mean' is the minimum and 'sd' the range
        mean = dataframe.min()
        sd = dataframe.max()-mean
    return mean, sd

In [None]:
def models(trainX, trainY, testX, testY, model_name = 'random_forest'):
    '''
    Tune the requested classifier on the training set and evaluate it on the test set.

    training set: trainX, trainY
    testing set: testX, testY
    model: model_name, one of 'random_forest', 'adaboost', 'kernel_svm', 'gbm', 'xgboost', 'knn',
           'decision_tree', 'linear_svm', 'logistic_regression', 'mlp'

    trainY must provide a 'Label' column; testY is passed unchanged to evaluate().
    Every model except logistic regression is tuned with a 5-fold grid search scored on accuracy;
    logistic regression is fit once without a hyperparameter search.

    Returns: the output of evaluate() for the fitted model on (testX, testY)
    Raises: ValueError if model_name is not one of the supported names
    '''
    trainY1 = trainY['Label'] #Dropping the PID
    
    if(model_name == 'random_forest'): #Random Forest
        #'auto' is not searched: for classifiers it is an alias of 'sqrt' and newer scikit-learn 
        #versions reject it
        grid = {
            'n_estimators': [40,45,50],
            'max_depth' : [15,20,25,None],
            'class_weight': [None, 'balanced'],
            'max_features': ['sqrt','log2', None],
            'min_samples_leaf':[1,2,0.1,0.05]
        }
        rf_grid = RandomForestClassifier(random_state=0)
        grid_search = GridSearchCV(estimator = rf_grid, param_grid = grid, scoring='accuracy', n_jobs = 1, cv = 5)
    
    elif(model_name == 'adaboost'): #Adaboost
        ada_grid = AdaBoostClassifier(random_state=0)
        grid = {
            'n_estimators':[50, 75, 100, 125, 150],
            'learning_rate':[0.01,.1, 1, 1.5, 2]
        }
        grid_search = GridSearchCV(ada_grid, param_grid = grid, scoring='accuracy', n_jobs = 1, cv=5)
    
    elif(model_name == 'kernel_svm'): #RBF SVM
        #probability=True so that predict_proba is available for the AUC computation
        svc_grid = SVC(kernel = 'rbf', probability=True, random_state=0)
        grid = {
            'gamma':[0.0001, 0.001, 0.1, 1, 10]
        }
        grid_search = GridSearchCV(svc_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)

    elif(model_name == 'gbm'): #GBM
        gbm_grid = GradientBoostingClassifier(random_state=0)
        grid = {
            'learning_rate':[0.15,0.1,0.05],
            'n_estimators':[50, 100, 150],
            'max_depth':[2,4,7],
            'min_samples_split':[2,4],
            'min_samples_leaf':[1,3],
            'max_features':[4, 5, 6]
        }
        grid_search = GridSearchCV(gbm_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)
    
    elif(model_name=='xgboost'): #Xgboost
        xgb_grid = xgboost.XGBClassifier(random_state=0)
        grid = {
            'min_child_weight': [1, 5],
            'gamma': [0.1, 0.5, 1, 1.5, 2],
            'subsample': [0.6, 0.8, 1.0],
            'colsample_bytree': [0.6, 0.8, 1.0],
            'max_depth': [5, 7, 8]
        }
        grid_search = GridSearchCV(xgb_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)
    
    elif(model_name == 'knn'): #KNN
        knn_grid = KNeighborsClassifier()
        grid = {
            'n_neighbors': [1, 3, 4, 5, 10],
            'p': [1, 2, 3, 4, 5]
        }
        grid_search = GridSearchCV(knn_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)
        
    elif(model_name == 'decision_tree'): #Decision Tree
        dec_grid = DecisionTreeClassifier(random_state=0)
        grid = {
            'min_samples_split': range(2, 50),
        }
        grid_search = GridSearchCV(dec_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)
    
    elif(model_name == 'linear_svm'): #Linear SVM
        lsvm_grid = LinearSVC(random_state=0)
        grid = {
            'loss': ['hinge','squared_hinge'],
        }
        grid_search = GridSearchCV(lsvm_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)
    
    elif(model_name == 'logistic_regression'): #Logistic regression
        #No hyperparameter search: the plain estimator is fit directly, so it has no best_estimator_
        grid_search = LogisticRegression(random_state=0)
    
    elif(model_name == 'mlp'):
        mlp_grid = MLPClassifier(activation='relu', solver='adam', learning_rate = 'adaptive', learning_rate_init=0.001,\
                                                        shuffle=False, max_iter = 200, random_state = 0)
        #Candidate architectures. A second copy of (50, 50, 50, 5, 5, 100, 100, 150) was dropped from 
        #the list since it only repeated identical fits.
        grid = {
            'hidden_layer_sizes': [(128, 8, 8, 128, 32), (50, 50, 50, 50, 50, 50, 150, 100, 10), 
                                  (50, 50, 50, 50, 50, 60, 30, 20, 50), (50, 50, 50, 50, 50, 150, 10, 60, 150),
                                  (50, 50, 50, 50, 50, 5, 50, 10, 5), (50, 50, 50, 50, 50, 5, 50, 150, 150),
                                  (50, 50, 50, 50, 50, 5, 30, 50, 20), (50, 50, 50, 50, 10, 150, 20, 20, 30),
                                  (50, 50, 50, 50, 30, 150, 100, 20, 100), (50, 50, 50, 50, 30, 5, 100, 20, 100),
                                  (50, 50, 50, 50, 60, 50, 50, 60, 60), (50, 50, 50, 50, 20, 50, 60, 20, 20),
                                  (50, 50, 50, 10, 50, 10, 150, 60, 150), (50, 50, 50, 10, 50, 150, 30, 150, 5),
                                  (50, 50, 50, 10, 50, 20, 150, 5, 10), (50, 50, 50, 10, 150, 50, 20, 20, 100), 
                                  (50, 50, 50, 30, 100, 5, 30, 150, 30), (50, 50, 50, 50, 100, 150, 100, 200), 
                                  (50, 50, 50, 5, 5, 100, 100, 150), (50, 50, 5, 50, 200, 100, 150, 5), 
                                  (50, 50, 5, 5, 200, 100, 50, 30), (50, 50, 5, 10, 5, 200, 200, 10), 
                                  (50, 50, 5, 30, 5, 5, 50, 10), (50, 50, 5, 200, 50, 5, 5, 50), 
                                  (5, 5, 5, 5, 5, 100, 50, 5, 50, 50), 
                                  (5, 5, 5, 5, 5, 100, 20, 100, 30, 30), (5, 5, 5, 5, 5, 20, 20, 5, 30, 100), 
                                  (5, 5, 5, 5, 5, 20, 20, 100, 10, 10), (5, 5, 5, 5, 10, 10, 30, 50, 10, 10), 
                                  (5, 5, 5, 5, 10, 100, 30, 30, 30, 10), (5, 5, 5, 5, 10, 100, 50, 10, 50, 10), 
                                  (5, 5, 5, 5, 10, 100, 20, 100, 30, 5), (5, 5, 5, 5, 30, 5, 20, 30, 100, 50), 
                                  (5, 5, 5, 5, 30, 100, 20, 50, 20, 30), (5, 5, 5, 5, 50, 30, 5, 50, 10, 100), 
                                  (21, 21, 7, 84, 21, 84, 84), (21, 21, 5, 42, 42, 7, 42), (21, 84, 7, 7, 7, 84, 5), 
                                  (21, 7, 84, 5, 5, 21, 120), (42, 5, 21, 21, 21, 5, 120), (42, 5, 42, 84, 7, 120, 84), 
                                  (50, 100, 10, 5, 100, 25), (10, 10, 25, 50, 25, 5), (50, 50, 50, 50, 50, 20, 30, 100, 60)]
        }
        grid_search = GridSearchCV(mlp_grid, param_grid=grid, scoring='accuracy', n_jobs = 1, cv=5)

    else:
        #Fail early with a clear message instead of an unbound variable error at fit time
        raise ValueError('Unknown model_name: ' + repr(model_name))
        
    grid_search.fit(trainX, trainY1) #Fitting on the training set to find the optimal hyperparameters 
#     print('best score: ', grid_search.best_score_)
#     print('best_params: ', grid_search.best_params_, grid_search.best_index_)
#     print('Mean cv accuracy on test set:', grid_search.cv_results_['mean_test_score'][grid_search.best_index_])
#     print('Standard deviation on test set:' , grid_search.cv_results_['std_test_score'][grid_search.best_index_])
#     print('Mean cv accuracy on train set:', grid_search.cv_results_['mean_train_score'][grid_search.best_index_])
#     print('Standard deviation on train set:', grid_search.cv_results_['std_train_score'][grid_search.best_index_])
#     print('Test set performance:\n')
    stride_person_metrics = evaluate(grid_search, testX, testY)
    return stride_person_metrics

In [None]:
def evaluate(model, test_features, trueY):
    '''
    Compute stride-wise and person-wise classification metrics of a fitted model on the test set.

    Input:
        model: fitted classifier or fitted grid search providing predict() and either 
               predict_proba() or best_estimator_._predict_proba_lr() (linear SVM)
        test_features: features of the test strides
        trueY: dataframe with columns 'PID' and 'Label', one row per test stride
    Output:
        (person-wise score of class 1 indexed by PID,
         [stride accuracy, precision, recall, F1, AUC, person accuracy, precision, recall, F1, AUC])
    Side effects: prints both sets of metrics and, for grid searches, the best estimator.

    NOTE(review): the metrics and the person-level formulas below assume binary 0/1 labels that are
    constant for each PID - confirm before using this with more than two cohorts.
    '''
    test_labels = trueY['Label'] #Dropping the PID
    predictions = model.predict(test_features)
    try:
        prediction_prob = model.predict_proba(test_features)[:, 1] #Score of the class with greater label
    except AttributeError:
        #The model exposes no predict_proba
        prediction_prob = model.best_estimator_._predict_proba_lr(test_features)[:, 1] #For linear SVM 
    #Stride wise metrics 
    acc = accuracy_score(test_labels, predictions)
    p = precision_score(test_labels, predictions)
    r = recall_score(test_labels, predictions)
    f1 = f1_score(test_labels, predictions)
    auc = roc_auc_score(test_labels, prediction_prob)
    print('Stride-based model performance: ', acc, p, r, f1, auc)
    
    #For computing person wise metrics 
    temp = trueY.copy() #True label for the stride; copied so that the caller's dataframe is not modified
    temp['pred'] = predictions #Predicted label for the stride 
    #Correctly classified strides i.e. 1 if stride is correctly classified and 0 if otherwise
    temp['correct'] = (temp['Label']==temp['pred'])

    #Proportion of correctly classified strides
    proportion_strides_correct = temp.groupby('PID').aggregate({'correct': 'mean'})  
    #True label of each person, taken from the labels passed in (not from a notebook global)
    proportion_strides_correct['True Label'] = trueY.groupby('PID')['Label'].first() 

    #Label for the person - 0=healthy, 1=MS patient
    #A person gets the true label when more than half of the strides are correct, the other label 
    #when less than half are correct, and 0 on an exact tie
    proportion_strides_correct['Predicted Label'] = proportion_strides_correct['True Label']*\
    (proportion_strides_correct['correct']>0.5)+(1-proportion_strides_correct['True Label'])*\
    (proportion_strides_correct['correct']<0.5) 

    #Probability of class 1 - MS patient for AUC calculation
    #For true label 1 this is the proportion of correct strides, for true label 0 the proportion of 
    #incorrect strides
    proportion_strides_correct['prob_class1'] = (1-proportion_strides_correct['True Label'])*\
    (1-proportion_strides_correct['correct'])+ proportion_strides_correct['True Label']*proportion_strides_correct['correct'] 
    
    #Only grid searches carry a best_estimator_; plain estimators are skipped silently
    best_estimator = getattr(model, 'best_estimator_', None)
    if best_estimator is not None:
        print (best_estimator)
    #Person wise metrics 
    person_acc = accuracy_score(proportion_strides_correct['True Label'], proportion_strides_correct['Predicted Label'])
    person_p = precision_score(proportion_strides_correct['True Label'], proportion_strides_correct['Predicted Label'])
    person_r = recall_score(proportion_strides_correct['True Label'], proportion_strides_correct['Predicted Label'])
    person_f1 = f1_score(proportion_strides_correct['True Label'], proportion_strides_correct['Predicted Label'])
    person_auc = roc_auc_score(proportion_strides_correct['True Label'], proportion_strides_correct['prob_class1'])
    print('Person-based model performance: ', person_acc, person_p, person_r, person_f1, person_auc)
    return proportion_strides_correct['prob_class1'], [acc, p, r, f1, auc, person_acc, person_p, person_r, person_f1, person_auc]

### main() 

In [27]:
#Trial W strides form the training set and trial WT strides the test set
trialW = data.loc[data['scenario'] == 'W']
trialWT = data.loc[data['scenario'] == 'WT']

#Keep only the subjects that appear in both trials
trialW_reduced, trialWT_reduced = keep_subjects_common_across_train_test(trialW, trialWT)
print ('Number of subjects in training and test sets after reduction:',
       trialW_reduced['PID'].nunique(dropna=False), trialWT_reduced['PID'].nunique(dropna=False))

#NOTE: the modeling cell further down reads raw_trainX_norm, raw_trainY, raw_testX_norm and raw_testY,
#which are only defined by the commented-out lines below
# raw_trainX = raw_trial1.drop(['Label', 'PID', 'TrialID'], axis = 1)
# raw_trainY = raw_trial1[['PID', 'Label']]

# raw_testX = raw_trial2.drop(['Label', 'PID', 'TrialID'], axis = 1)
# raw_testY = raw_trial2[['PID', 'Label']] #PID to compute person based metrics later 

#Normalize according to z-score standardization
# norm_mean, norm_sd = normalize(raw_trainX, 'z')
# raw_trainX_norm = (raw_trainX-norm_mean)/norm_sd
# raw_testX_norm = (raw_testX-norm_mean)/norm_sd

#Total strides and imbalance of labels in the training and testing set
#Strides per cohort, counted once for each set
train_cohort_counts = trialW_reduced['cohort'].value_counts()
test_cohort_counts = trialWT_reduced['cohort'].value_counts()

#Training set 
print('Strides in training set: ', trialW_reduced.shape[0])
print ('HOA, MS and PD strides in training set:\n', train_cohort_counts)

#Test Set
print('\nStrides in test set: ', trialWT_reduced.shape[0]) 
print ('HOA, MS and PD strides in test set:\n', test_cohort_counts)
#Cohort sizes of the test set relative to the HOA cohort
print ('Imbalance ratio (controls:MS:PD)= 1:X:Y\n', test_cohort_counts/test_cohort_counts['HOA'])

Number of subjects in training and test sets: 32 26
Subjects in test set, which are not in training set
[403]
Subjects in training set, which are not in test set
[312, 102, 112, 113, 115, 123, 124]
Number of subjects in training and test sets after reduction: 25 25
Number of subjects in training and test sets after reduction: 25 25
Strides in training set:  1128
HOA, MS and PD strides in training set:
 PD     453
MS     341
HOA    334
Name: cohort, dtype: int64

Strides in test set:  1142
HOA, MS and PD strides in test set:
 PD     459
HOA    351
MS     332
Name: cohort, dtype: int64
Imbalance ratio (controls:MS:PD)= 1:X:Y
 PD     1.307692
HOA    1.000000
MS     0.945869
Name: cohort, dtype: float64


In [22]:
#Display the training set (trial W) after keeping only the subjects shared with the test set
trialW_reduced

Unnamed: 0,PID,key,cohort,trial,scenario,video,stride_number,frame_count,label,right hip-x-CoV,...,ankle-z-asymmetry,heel-x-asymmetry,heel-y-asymmetry,heel-z-asymmetry,toe 1-x-asymmetry,toe 1-y-asymmetry,toe 1-z-asymmetry,toe 2-x-asymmetry,toe 2-y-asymmetry,toe 2-z-asymmetry
0,212,GVS_212_W_T2_1,HOA,W,W,GVS_212_W_T2,1,42,0,0.076864,...,0.285971,0.195862,0.720393,0.064473,0.143085,0.677586,0.210210,0.162807,0.677672,0.242728
1,212,GVS_212_W_T2_2,HOA,W,W,GVS_212_W_T2,2,39,0,0.072637,...,0.256597,0.231516,0.654637,0.112397,0.156032,0.448596,0.204625,0.168603,0.468424,0.276266
2,212,GVS_212_W_T2_3,HOA,W,W,GVS_212_W_T2,3,37,0,0.069567,...,0.388825,0.817081,0.108625,0.307391,0.357868,0.160208,0.235414,0.272633,0.086927,0.282987
3,212,GVS_212_W_T2_4,HOA,W,W,GVS_212_W_T2,4,37,0,0.055311,...,0.024362,0.873747,0.435346,0.044738,0.413265,0.565173,0.069416,0.426990,0.557813,0.081200
4,212,GVS_212_W_T2_5,HOA,W,W,GVS_212_W_T2,5,35,0,0.041665,...,0.034621,0.538193,0.194658,0.008901,0.294742,0.269327,0.011823,0.125510,0.355002,0.052064
5,212,GVS_212_W_T2_6,HOA,W,W,GVS_212_W_T2,6,35,0,0.069650,...,0.070450,0.226940,0.065100,0.106075,0.203022,0.346439,0.086403,0.390887,0.350759,0.082130
6,212,GVS_212_W_T2_7,HOA,W,W,GVS_212_W_T2,7,36,0,0.052349,...,0.159524,0.176767,0.284222,0.109569,0.490567,0.263129,0.181062,0.574939,0.268068,0.223443
7,212,GVS_212_W_T2_8,HOA,W,W,GVS_212_W_T2,8,32,0,0.076387,...,0.047522,0.520756,0.138532,0.039676,0.693510,0.024771,0.071739,0.713533,0.040312,0.067053
8,212,GVS_212_W_T2_9,HOA,W,W,GVS_212_W_T2,9,33,0,0.027083,...,0.123473,0.038936,0.069648,0.174583,0.392992,0.110045,0.058682,0.517885,0.085042,0.109160
9,212,GVS_212_W_T2_10,HOA,W,W,GVS_212_W_T2,10,31,0,0.046696,...,0.106067,0.088762,0.045825,0.088616,0.490953,0.258826,0.098502,0.589562,0.211217,0.027978


In [10]:
#Number of non-null entries in every column for each trial W video
trialW.groupby(['video']).count()

Unnamed: 0_level_0,cohort,trial,scenario,PID,stride_number,key,frame_count,label
video,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
GVS_102_W_T1,90,90,90,90,90,90,90,90
GVS_112_W_T1,95,95,95,95,95,95,95,95
GVS_113_W_T1,57,57,57,57,57,57,57,57
GVS_115_W_T1,52,52,52,52,52,52,52,52
GVS_123_W_T1,118,118,118,118,118,118,118,118
GVS_124_W_T1,63,63,63,63,63,63,63,63
GVS_212_W_T2,44,44,44,44,44,44,44,44
GVS_213_W_T1,43,43,43,43,43,43,43,43
GVS_214_W_T1,38,38,38,38,38,38,38,38
GVS_215_W_T1,45,45,45,45,45,45,45,45


In [None]:
#Traditional ML algorithms to compare 
ml_models = ['random_forest', 'adaboost', 'kernel_svm', 'gbm', 'xgboost', 'knn', 'decision_tree',  'linear_svm', 
             'logistic_regression']
#Row names of the metrics table, in the order the metrics are returned for each model 
metric_names = ['stride_accuracy', 'stride_precision', 'stride_recall', 'stride_F1', 'stride_AUC', 'person_accuracy', 
                'person_precision', 'person_recall', 'person_F1', 'person_AUC']

#One column per ML model: the metrics on the raw data and the person-wise predicted probabilities 
#(for class 1) used to draw the ROC curves
raw_metrics = pd.DataFrame(columns=ml_models)
predicted_probs_person_raw = pd.DataFrame(columns=ml_models)

for ml_model in ml_models:
    print (ml_model)
    predict_probs_person, stride_person_metrics = models(raw_trainX_norm, raw_trainY, raw_testX_norm, raw_testY,
                                                         model_name=ml_model)
    predicted_probs_person_raw[ml_model] = predict_probs_person
    raw_metrics[ml_model] = stride_person_metrics
    print ('********************************')

raw_metrics.index = metric_names
#Save both tables in the trial_generalize folder next to the data folder
raw_metrics.to_csv(path+'..//trial_generalize//trial_generalize_results_raw_data.csv')
predicted_probs_person_raw.to_csv(path+'..//trial_generalize//trial_generalize_ROCresults_raw_data.csv')

In [None]:
#ROC curves for cohort prediction (person-based), one curve per selected ML model
ml_models = ['random_forest',  'kernel_svm',  'xgboost', 'gbm', 'mlp'] 
#, 'adaboost', 'linear_svm', 'decision_tree', 'logistic_regression',] 
#'knn', 
#Short display names for the legend
ml_model_names = {'random_forest': 'RF', 'adaboost': 'Adaboost', 'kernel_svm': 'RBF SVM', 'gbm': 'GBM', \
                  'xgboost': 'Xgboost', 'knn': 'KNN', 'decision_tree': 'DT',  'linear_svm': 'LSVM', 
             'logistic_regression': 'LR', 'mlp': 'MLP'}
#One true label per person 
person_true_labels = raw_testY.groupby('PID').first()
#Constant score of 0 for every person: ROC for majority class prediction all the time 
neutral = [0] * len(person_true_labels)

fig, axes = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(5.2, 3.5))
sns.despine(offset=0)
neutral_fpr, neutral_tpr, _ = roc_curve(person_true_labels, neutral) #roc curves
# #Raw Data 
# axes[0].plot(neutral_fpr, neutral_tpr, linestyle='--', label='Majority (AUC = 0.5)', linewidth = 2, color = 'k')
# for ml_model in ml_models:
#     model_probs = predicted_probs_person_raw[ml_model] # person-based prediction probabilities
#     fpr, tpr, _ = roc_curve(person_true_labels, model_probs)
#     axes[0].plot(fpr, tpr, label=ml_model_names[ml_model]+' (AUC = '+ str(round(raw_metrics.loc['person_AUC'][ml_model], 3))
#                  +')', linewidth = 2)
# axes[0].set_ylabel('True Positive Rate')
# axes[0].legend(loc='upper center', bbox_to_anchor=(1.27, 1), ncol=1)
# axes[0].set_title('Raw data')

#Line style and color of each model's curve, matched to ml_models by position
linestyles = ['-', '-', '-', '-.', '--', '-', '--', '-', '--']
colors = ['b', 'magenta', 'cyan', 'g',  'red', 'violet', 'lime', 'grey', 'pink']

# #SizeN Data 
# axes[0].plot(neutral_fpr, neutral_tpr, linestyle='--', label='Majority (AUC = 0.5)', linewidth = 3, color = 'k')
# for idx, ml_model in enumerate(ml_models):
#     model_probs = predicted_probs_person_sizeN[ml_model] # person-based prediction probabilities
#     fpr, tpr, _ = roc_curve(person_true_labels, model_probs)
#     axes[0].plot(fpr, tpr, label=ml_model_names[ml_model]+' (AUC = '+ str(round(sizeN_metrics.loc['person_AUC'][ml_model], 3))
#                  +')', linewidth = 3, alpha = 0.8, linestyle = linestyles[idx], color = colors[idx])
# # axes[0].legend(loc='upper center', bbox_to_anchor=(1.27, 1), ncol=1)
# axes[0].legend()
# axes[0].set_ylabel('True Positive Rate')
# axes[0].set_title('Size-N data')

#RegressN Data 
axes.plot(neutral_fpr, neutral_tpr, linestyle='--', label='Majority (AUC = 0.5)', linewidth = 3, color = 'k')
person_auc_row = regressN_metrics.loc['person_AUC'] #Person-based AUC of every model 
for ml_model, curve_linestyle, curve_color in zip(ml_models, linestyles, colors):
    model_probs = predicted_probs_person_regressN[ml_model] # person-based prediction probabilities
    fpr, tpr, _ = roc_curve(person_true_labels, model_probs)
    auc_text = str(round(person_auc_row[ml_model], 3))
    curve_label = ml_model_names[ml_model]+' (AUC = '+ auc_text +')'
    axes.plot(fpr, tpr, label=curve_label, linewidth = 3, alpha = 0.8, linestyle = curve_linestyle,
              color = curve_color)
axes.set_ylabel('True Positive Rate')
axes.set_title('Cross-task generalization: Regress-N data')
plt.legend()
# axes[1].legend(loc='upper center', bbox_to_anchor=(1.27, 1), ncol=1)

axes.set_xlabel('False Positive Rate')
plt.tight_layout()
plt.savefig(path + '..//trial_generalize//ROC_trial_generalize_onlyregressN.png', dpi = 250)
plt.show()