# Functional status (mRS) prediction

## 6. Testing

### Favorable functional status (mRS 0-2)
1. with MT data
2. without MT data
### Mortality (mRS 6)
3. with MT data
4. without MT data
### Death/severe disability (mRS 4-6)
5. with MT data
6. without MT data

In [155]:
import pandas as pd
import numpy as np
import os
import pickle

from sklearn.utils import resample
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, accuracy_score
from sklearn.compose import make_column_selector as selector

import shap

import plotly.graph_objects as go

In [7]:
# helper function for loading/testing model on bootstrapped samples and saving results

def test_final_model(X, y, model_fname, n_iter, result_fname, random_state = None):
    """Evaluate a pickled classifier on bootstrap resamples of a test set.

    The model is loaded from ``model_fname``, the test set is resampled with
    replacement ``n_iter`` times, and AUC, precision, recall, F1 and accuracy
    are computed on every resample. The per-iteration metrics are pickled to
    ``result_fname`` and returned.

    Parameters
    ----------
    X : test features accepted by the model's ``predict`` / ``predict_proba``.
    y : true binary labels, aligned row-wise with ``X``.
    model_fname : path to the pickled model.
    n_iter : number of bootstrap iterations.
    result_fname : path the results DataFrame is pickled to.
    random_state : None, int or numpy.random.RandomState, optional
        Seed (or generator) for the bootstrap draws. ``None`` (default) uses
        numpy's global random state, as before.

    Returns
    -------
    pandas.DataFrame with columns 'auc', 'prec', 'recall', 'f1', 'acc' and one
    row per bootstrap iteration.
    """

    # load model for testing
    # NOTE: unpickling executes arbitrary code -- only load model files from a trusted source.
    with open(model_fname, 'rb') as f:
        final_model = pickle.load(f)

    # make sure the output folder exists before doing any expensive work
    result_dir = os.path.dirname(result_fname)
    if result_dir:
        os.makedirs(result_dir, exist_ok = True)

    # The fitted model does not change between iterations and scores each row
    # independently, so predict once and resample the predictions together
    # with the labels instead of re-predicting on every resampled X.
    y_score_all = final_model.predict_proba(X)[:, 1]
    y_pred_all = final_model.predict(X)

    # One generator shared by all iterations; passing a bare int to resample()
    # inside the loop would draw the same sample every time.
    if random_state is None or isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # create empty results dict
    results = {'auc': [], 'prec': [], 'recall': [], 'f1': [], 'acc': []}

    # bootstrap resample
    for i in range(n_iter):

        if i % 50 == 0:
            print('evaluating iteration no. {}...'.format(i))

        y_true, y_score, y_pred = resample(y, y_score_all, y_pred_all, random_state = rng)

        results['auc'].append(roc_auc_score(y_true, y_score))
        # zero_division = 0 on all three keeps the values identical to the
        # defaults but silences the undefined-metric warnings consistently
        results['prec'].append(precision_score(y_true, y_pred, zero_division = 0))
        results['recall'].append(recall_score(y_true, y_pred, zero_division = 0))
        results['f1'].append(f1_score(y_true, y_pred, zero_division = 0))
        results['acc'].append(accuracy_score(y_true, y_pred))

    # create data frame for results and save
    results_df = pd.DataFrame(results)
    results_df.to_pickle(result_fname)

    return results_df

In [148]:
# helper function to get SHAP values and save them

def get_shap_values(model_fname, X, analysis_dir):
    """Compute SHAP values for a pickled tree-model pipeline and save them.

    The pipeline is expected to expose a 'preprocessor' step (a column
    transformer whose first transformer holds a 'variance_threshold' step and
    whose second holds 'one hot encoder' and 'selector' steps) and a
    'classifier' step usable with ``shap.TreeExplainer``.

    Parameters
    ----------
    model_fname : path to the pickled pipeline.
    X : pandas.DataFrame of untransformed features; object-dtype columns are
        treated as categorical, all others as continuous.
    analysis_dir : folder that receives 'shap_values.pkl',
        'shap_values_for_beeswarm.pkl' and 'feature_names.pkl'.

    Returns
    -------
    (shap_values, shap_values_for_beeswarm, feature_names) where
    ``shap_values`` comes from ``explainer.shap_values(...)``,
    ``shap_values_for_beeswarm`` from calling the explainer directly, and
    ``feature_names`` lists the columns of the preprocessed data.
    """

    cont_columns = selector(dtype_exclude = object)(X)
    cat_columns = selector(dtype_include = object)(X)

    # NOTE: unpickling executes arbitrary code -- only load model files from a trusted source.
    with open(model_fname, 'rb') as f:
        model = pickle.load(f)

    # make sure the output folder exists before the expensive SHAP computation
    if analysis_dir:
        os.makedirs(analysis_dir, exist_ok = True)

    preprocessor = model['preprocessor']
    cont_steps = preprocessor.transformers_[0][1].named_steps
    cat_steps = preprocessor.transformers_[1][1].named_steps

    # get names of the continuous features selected in the pipeline
    final_cont_columns = cont_steps['variance_threshold'].get_feature_names_out(cont_columns).tolist()

    # get names of the categorical features selected in the pipeline
    # (get_feature_names is deprecated since scikit-learn 1.0 and removed in 1.2;
    # keep it only as a fallback for older versions)
    encoder = cat_steps['one hot encoder']
    if hasattr(encoder, 'get_feature_names_out'):
        all_cat_columns = encoder.get_feature_names_out(cat_columns).tolist()
    else:
        all_cat_columns = encoder.get_feature_names(cat_columns).tolist()
    cat_indices = cat_steps['selector'].get_support(indices = True)
    final_cat_columns = [all_cat_columns[i] for i in cat_indices]
    feature_names = final_cont_columns + final_cat_columns

    # get SHAP values
    explainer = shap.TreeExplainer(model['classifier'])
    observations = pd.DataFrame(preprocessor.transform(X), columns = feature_names)

    shap_values = explainer.shap_values(observations)
    shap_values_for_beeswarm = explainer(observations) # beeswarm plot uses different output

    # save SHAP values and feature names
    outputs = {
        'shap_values.pkl': shap_values,
        'shap_values_for_beeswarm.pkl': shap_values_for_beeswarm,
        'feature_names.pkl': feature_names,
    }
    for fname, obj in outputs.items():
        with open(os.path.join(analysis_dir, fname), 'wb') as f:
            pickle.dump(obj, f)

    return shap_values, shap_values_for_beeswarm, feature_names

### 1. Favorable functional status prediction - with MT data

In [18]:
# Load the transformed test features (MT-data variant) and the test labels
# for the favorable-functional-status outcome.
X_test_mt = pd.read_pickle('transformed_datasets/fav_functional_status/mt_data/X_test_trans_mt.pkl')
y_test = np.load('transformed_datasets/fav_functional_status/y_test_trans.npy')

In [19]:
# Bootstrap-evaluate the final RF model (MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
mt_results_rf = test_final_model(
    X = X_test_mt,
    y = y_test,
    model_fname = 'models/fav_functional_status/mt_data/final_rf_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/fav_functional_status/mt_data/mt_results_rf.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [20]:
# Summary statistics of each metric across the bootstrap iterations (RF, MT data).
mt_results_rf.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.796958,0.710343,0.583529,0.636509,0.734292
std,0.057158,0.097029,0.091478,0.080186,0.052441
min,0.60625,0.357143,0.208333,0.263158,0.569444
25%,0.760861,0.642857,0.52,0.581818,0.694444
50%,0.801924,0.714286,0.583333,0.641509,0.736111
75%,0.835584,0.777778,0.647059,0.692308,0.763889
max,0.951587,0.965517,0.928571,0.852459,0.875


In [149]:
# Compute SHAP values for the final RF model (favorable functional status, MT data)
# on the test features; the results are also pickled into analysis_dir.
shap_values, shap_values_for_beeswarm, feature_names = get_shap_values(
    model_fname = 'models/fav_functional_status/mt_data/final_rf_model.pkl', 
    X = X_test_mt, 
    analysis_dir = 'analysis/fav_functional_status/mt_data')

Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.


In [35]:
# Bootstrap-evaluate the final LR model (MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
mt_results_lr = test_final_model(
    X = X_test_mt,
    y = y_test,
    model_fname = 'models/fav_functional_status/mt_data/final_lr_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/fav_functional_status/mt_data/mt_results_lr.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [36]:
# Summary statistics of each metric across the bootstrap iterations (LR, MT data).
mt_results_lr.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.771555,0.653905,0.584808,0.612988,0.707528
std,0.061547,0.093861,0.092771,0.077849,0.053805
min,0.561688,0.35,0.32,0.347826,0.527778
25%,0.730826,0.592593,0.52,0.561404,0.666667
50%,0.772891,0.653846,0.586207,0.615385,0.708333
75%,0.815151,0.72,0.645161,0.666667,0.75
max,0.925926,0.92,0.857143,0.823529,0.875


### 2. Favorable functional status - no MT data

In [15]:
# Load the transformed test features (no-MT variant) for the favorable-functional-status
# outcome; the labels come from the same file used in the MT-data section.
X_test_nomt = pd.read_pickle('transformed_datasets/fav_functional_status/no_mt_data/X_test_trans_nomt.pkl')
y_test = np.load('transformed_datasets/fav_functional_status/y_test_trans.npy')

In [21]:
# Bootstrap-evaluate the final RF model (no MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
nomt_results_rf = test_final_model(
    X = X_test_nomt,
    y = y_test,
    model_fname = 'models/fav_functional_status/no_mt_data/final_rf_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/fav_functional_status/no_mt_data/nomt_results_rf.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [22]:
# Summary statistics of each metric across the bootstrap iterations (RF, no MT data).
nomt_results_rf.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.73456,0.633983,0.486501,0.54612,0.681431
std,0.060801,0.103315,0.090844,0.084016,0.055065
min,0.528125,0.304348,0.217391,0.277778,0.5
25%,0.696301,0.565217,0.423077,0.490566,0.638889
50%,0.732102,0.636364,0.483871,0.551724,0.680556
75%,0.778231,0.704248,0.548387,0.603774,0.722222
max,0.903906,0.9,0.791667,0.785714,0.847222


In [150]:
# Compute SHAP values for the final RF model (favorable functional status, no MT data)
# on the test features; the results are also pickled into analysis_dir.
shap_values, shap_values_for_beeswarm, feature_names = get_shap_values(
    model_fname = 'models/fav_functional_status/no_mt_data/final_rf_model.pkl', 
    X = X_test_nomt, 
    analysis_dir = 'analysis/fav_functional_status/no_mt_data')

Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.


In [37]:
# Bootstrap-evaluate the final LR model (no MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
nomt_results_lr = test_final_model(
    X = X_test_nomt,
    y = y_test,
    model_fname = 'models/fav_functional_status/no_mt_data/final_lr_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/fav_functional_status/no_mt_data/nomt_results_lr.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [38]:
# Summary statistics of each metric across the bootstrap iterations (LR, no MT data).
nomt_results_lr.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.745547,0.639589,0.548636,0.58584,0.695056
std,0.06305,0.095314,0.090327,0.07709,0.05302
min,0.530041,0.304348,0.21875,0.266667,0.541667
25%,0.706375,0.576923,0.484848,0.535714,0.663194
50%,0.746753,0.641429,0.545455,0.588235,0.694444
75%,0.789372,0.708333,0.608696,0.638298,0.736111
max,0.919732,0.925926,0.88,0.846154,0.888889


### 3. Mortality - with MT data

In [23]:
# Load the transformed test features (MT-data variant) and the test labels
# for the mortality outcome.
X_test_mort_mt = pd.read_pickle('transformed_datasets/mortality/mt_data/X_test_trans_mt.pkl')
y_test_mort = np.load('transformed_datasets/mortality/y_test_trans.npy')

In [24]:
# Bootstrap-evaluate the final RF mortality model (MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
mt_mort_results_rf = test_final_model(
    X = X_test_mort_mt,
    y = y_test_mort,
    model_fname = 'models/mortality/mt_data/final_rf_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/mortality/mt_data/mt_results_rf.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [25]:
# Summary statistics of each metric across the bootstrap iterations (mortality, RF, MT data).
mt_mort_results_rf.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.70818,0.449709,0.531306,0.480319,0.735417
std,0.074585,0.112484,0.127385,0.104053,0.052715
min,0.435463,0.071429,0.0625,0.066667,0.541667
25%,0.65843,0.375,0.444444,0.413793,0.708333
50%,0.712035,0.444444,0.529412,0.484848,0.736111
75%,0.760054,0.526316,0.619048,0.553191,0.777778
max,0.929012,0.823529,0.916667,0.789474,0.888889


In [151]:
# Compute SHAP values for the final RF mortality model (MT data) on the test
# features; the results are also pickled into analysis_dir.
shap_values, shap_values_for_beeswarm, feature_names = get_shap_values(
    model_fname = 'models/mortality/mt_data/final_rf_model.pkl', 
    X = X_test_mort_mt, 
    analysis_dir = 'analysis/mortality/mt_data')

Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.


In [39]:
# Bootstrap-evaluate the final LR mortality model (MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
mt_mort_results_lr = test_final_model(
    X = X_test_mort_mt,
    y = y_test_mort,
    model_fname = 'models/mortality/mt_data/final_lr_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/mortality/mt_data/mt_results_lr.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [40]:
# Summary statistics of each metric across the bootstrap iterations (mortality, LR, MT data).
mt_mort_results_lr.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.677568,0.396676,0.530624,0.448014,0.698514
std,0.081401,0.103043,0.119822,0.09759,0.052865
min,0.438272,0.083333,0.125,0.1,0.555556
25%,0.622234,0.32,0.45,0.387097,0.666667
50%,0.678412,0.392857,0.533333,0.45,0.694444
75%,0.735645,0.461538,0.615385,0.514286,0.736111
max,0.904094,0.722222,0.933333,0.723404,0.847222


### 4. Mortality - without MT data

In [26]:
# Load the transformed test features (no-MT variant) for the mortality outcome;
# the labels come from the same file used in the MT-data mortality section.
X_test_mort_nomt = pd.read_pickle('transformed_datasets/mortality/no_mt_data/X_test_trans_nomt.pkl')
y_test_mort = np.load('transformed_datasets/mortality/y_test_trans.npy')

In [27]:
# Bootstrap-evaluate the final RF mortality model (no MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
nomt_mort_results_rf = test_final_model(
    X = X_test_mort_nomt,
    y = y_test_mort,
    model_fname = 'models/mortality/no_mt_data/final_rf_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/mortality/no_mt_data/nomt_results_rf.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [28]:
# Summary statistics of each metric across the bootstrap iterations (mortality, RF, no MT data).
nomt_mort_results_rf.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.770519,0.410698,0.412208,0.403998,0.722431
std,0.061203,0.122479,0.12527,0.109454,0.054195
min,0.53202,0.047619,0.071429,0.068966,0.5
25%,0.731837,0.333333,0.333333,0.333333,0.694444
50%,0.775447,0.411765,0.411765,0.410256,0.722222
75%,0.811536,0.5,0.5,0.482759,0.763889
max,0.928655,0.785714,0.888889,0.736842,0.861111


In [152]:
# Compute SHAP values for the final RF mortality model (no MT data) on the test
# features; the results are also pickled into analysis_dir.
shap_values, shap_values_for_beeswarm, feature_names = get_shap_values(
    model_fname = 'models/mortality/no_mt_data/final_rf_model.pkl', 
    X = X_test_mort_nomt, 
    analysis_dir = 'analysis/mortality/no_mt_data')

Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.


In [41]:
# Bootstrap-evaluate the final LR mortality model (no MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
nomt_mort_results_lr = test_final_model(
    X = X_test_mort_nomt,
    y = y_test_mort,
    model_fname = 'models/mortality/no_mt_data/final_lr_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/mortality/no_mt_data/nomt_results_lr.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [42]:
# Summary statistics of each metric across the bootstrap iterations (mortality, LR, no MT data).
nomt_mort_results_lr.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.739846,0.456116,0.65298,0.53045,0.737042
std,0.071363,0.107123,0.118976,0.09836,0.052394
min,0.416667,0.04,0.083333,0.060606,0.472222
25%,0.693048,0.391304,0.578947,0.470588,0.708333
50%,0.744457,0.454545,0.666667,0.533333,0.736111
75%,0.788364,0.521739,0.728147,0.595745,0.777778
max,0.925806,0.833333,1.0,0.851064,0.902778


### 5. Death/severe disability - with MT data

In [29]:
# Load the transformed test features (MT-data variant) and the test labels
# for the death/severe disability (dsd) outcome.
X_test_dsd_mt = pd.read_pickle('transformed_datasets/dsd/mt_data/X_test_trans_mt.pkl')
y_test_dsd = np.load('transformed_datasets/dsd/y_test_trans.npy')

In [30]:
# Bootstrap-evaluate the final RF dsd model (MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
mt_dsd_results_rf = test_final_model(
    X = X_test_dsd_mt,
    y = y_test_dsd,
    model_fname = 'models/dsd/mt_data/final_rf_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/dsd/mt_data/mt_results_rf.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [31]:
# Summary statistics of each metric across the bootstrap iterations (dsd, RF, MT data).
mt_dsd_results_rf.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.812146,0.745586,0.841461,0.788639,0.764875
std,0.054874,0.066424,0.061001,0.05173,0.051091
min,0.630469,0.512195,0.6,0.597015,0.597222
25%,0.775926,0.7,0.804878,0.756098,0.736111
50%,0.815625,0.75,0.844444,0.790123,0.763889
75%,0.850816,0.791667,0.882695,0.825,0.805556
max,0.96517,0.911765,1.0,0.915663,0.902778


In [153]:
# Compute SHAP values for the final RF dsd model (MT data) on the test
# features; the results are also pickled into analysis_dir.
shap_values, shap_values_for_beeswarm, feature_names = get_shap_values(
    model_fname = 'models/dsd/mt_data/final_rf_model.pkl', 
    X = X_test_dsd_mt, 
    analysis_dir = 'analysis/dsd/mt_data')

Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.


In [43]:
# Bootstrap-evaluate the final LR dsd model (MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
mt_dsd_results_lr = test_final_model(
    X = X_test_dsd_mt,
    y = y_test_dsd,
    model_fname = 'models/dsd/mt_data/final_lr_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/dsd/mt_data/mt_results_lr.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [44]:
# Summary statistics of each metric across the bootstrap iterations (dsd, LR, MT data).
mt_dsd_results_lr.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.778395,0.759211,0.734921,0.744517,0.736069
std,0.054367,0.072921,0.066748,0.056509,0.051228
min,0.516049,0.53125,0.422222,0.520548,0.513889
25%,0.742823,0.714286,0.690476,0.705882,0.708333
50%,0.780591,0.763158,0.738095,0.746988,0.736111
75%,0.815627,0.80722,0.78125,0.785714,0.777778
max,0.932337,0.948718,0.928571,0.891892,0.888889


### 6. Death/severe disability - without MT data

In [32]:
# Load the transformed test features (no-MT variant) for the death/severe disability
# (dsd) outcome; the labels come from the same file used in the MT-data dsd section.
X_test_dsd_nomt = pd.read_pickle('transformed_datasets/dsd/no_mt_data/X_test_trans_nomt.pkl')
y_test_dsd = np.load('transformed_datasets/dsd/y_test_trans.npy')

In [33]:
# Bootstrap-evaluate the final RF dsd model (no MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
# NOTE(review): the other no-MT sections save to 'nomt_results_rf.pkl', this one to
# 'mt_results_rf.pkl' (inside the no_mt_data folder) -- confirm downstream readers
# before renaming for consistency.
nomt_dsd_results_rf = test_final_model(
    X = X_test_dsd_nomt,
    y = y_test_dsd,
    model_fname = 'models/dsd/no_mt_data/final_rf_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/dsd/no_mt_data/mt_results_rf.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [34]:
# Summary statistics of each metric across the bootstrap iterations (dsd, RF, no MT data).
nomt_dsd_results_rf.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.763576,0.67702,0.71228,0.6917,0.669125
std,0.054535,0.073595,0.069459,0.059085,0.056064
min,0.533333,0.4375,0.5,0.466667,0.486111
25%,0.727404,0.627907,0.666667,0.650452,0.638889
50%,0.767244,0.681818,0.714286,0.693333,0.666667
75%,0.802333,0.727273,0.757576,0.733333,0.708333
max,0.923077,0.894737,0.925,0.866667,0.861111


In [154]:
# Compute SHAP values for the final RF dsd model (no MT data) on the test
# features; the results are also pickled into analysis_dir.
shap_values, shap_values_for_beeswarm, feature_names = get_shap_values(
    model_fname = 'models/dsd/no_mt_data/final_rf_model.pkl', 
    X = X_test_dsd_nomt, 
    analysis_dir = 'analysis/dsd/no_mt_data')

Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.


In [45]:
# Bootstrap-evaluate the final LR dsd model (no MT data) on the test set, 1000 resamples;
# the per-iteration metrics are also pickled to result_fname.
# NOTE(review): the other no-MT sections save to 'nomt_results_lr.pkl', this one to
# 'mt_results_lr.pkl' (inside the no_mt_data folder) -- confirm downstream readers
# before renaming for consistency.
nomt_dsd_results_lr = test_final_model(
    X = X_test_dsd_nomt,
    y = y_test_dsd,
    model_fname = 'models/dsd/no_mt_data/final_lr_model.pkl',
    n_iter = 1000,
    result_fname = 'test_results/dsd/no_mt_data/mt_results_lr.pkl')

evaluating iteration no. 0...
evaluating iteration no. 50...
evaluating iteration no. 100...
evaluating iteration no. 150...
evaluating iteration no. 200...
evaluating iteration no. 250...
evaluating iteration no. 300...
evaluating iteration no. 350...
evaluating iteration no. 400...
evaluating iteration no. 450...
evaluating iteration no. 500...
evaluating iteration no. 550...
evaluating iteration no. 600...
evaluating iteration no. 650...
evaluating iteration no. 700...
evaluating iteration no. 750...
evaluating iteration no. 800...
evaluating iteration no. 850...
evaluating iteration no. 900...
evaluating iteration no. 950...


In [46]:
# Summary statistics of each metric across the bootstrap iterations (dsd, LR, no MT data).
nomt_dsd_results_lr.describe()

Unnamed: 0,auc,prec,recall,f1,acc
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,0.753553,0.683031,0.786913,0.728897,0.693125
std,0.057676,0.07175,0.064536,0.055466,0.053455
min,0.585227,0.463415,0.59375,0.550725,0.555556
25%,0.712842,0.634146,0.742627,0.69207,0.652778
50%,0.754496,0.682927,0.787879,0.734177,0.694444
75%,0.794182,0.734694,0.833333,0.767442,0.722222
max,0.922601,0.893617,0.972222,0.897959,0.861111
