In [1]:
#Step 7 trains and tests both hospitals on the training data using random forest and 10 fold nested cross validation. 

In [2]:
!python3 -m pip install scikit-optimize

[0mCollecting scikit-optimize
  Downloading scikit_optimize-0.10.2-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting pyaml>=16.9 (from scikit-optimize)
  Downloading pyaml-25.7.0-py3-none-any.whl.metadata (12 kB)
Downloading scikit_optimize-0.10.2-py2.py3-none-any.whl (107 kB)
Downloading pyaml-25.7.0-py3-none-any.whl (26 kB)
[0mInstalling collected packages: pyaml, scikit-optimize
[0mSuccessfully installed pyaml-25.7.0 scikit-optimize-0.10.2


In [25]:
#Import all necessary modules. 
import numpy as np
import pickle
from sklearn.ensemble import RandomForestClassifier
from skopt import BayesSearchCV
import pandas as pd
from sklearn.metrics import precision_recall_curve, auc, roc_curve, confusion_matrix, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from sklearn.model_selection import GroupKFold
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d



In [30]:
#Load in the training data feature matrix. 
matrix = pd.read_csv('train_data_.csv')
matrix = matrix.sample(random_state = 2023, frac = 1, ignore_index=True)  

# rop the annotation (annot) column to create X_data.
X_data = matrix.drop('annot', axis=1).reset_index(drop = True)

#Extract the 'BDSPPatientID' and 'annot' columns to create y_data_pre.
y_data_pre = matrix[['BDSPPatientID', 'annot']].reset_index(drop = True)
assert len(X_data) == len(y_data_pre), "DataFrames must have the same length"
X_data['annot']=y_data_pre['annot']
print(X_data)
y_data = X_data[['BDSPPatientID', 'annot']]
print(y_data)
y_data_pre=y_data
print(y_data_pre)
y = y_data_pre['annot']
print(y)
X_data=X_data.drop(['annot'], axis=1)
X=X_data.drop(['BDSPPatientID', 'ContactDate', 'hospital', 'Unnamed: 0', 'NoteFileName', 'Site'], axis=1) 
print(X)

      Unnamed: 0  BDSPPatientID ContactDate  \
0           1328      118166468  2019-02-10   
1           2377      151027760  2018-07-07   
2           2821      151147664  2018-08-25   
3            982      118696581  2022-05-10   
4           1150      114871134  2018-02-14   
...          ...            ...         ...   
1495         732      115864242  2021-04-05   
1496        2556      150772328  2012-07-18   
1497         640      119972331  2019-08-19   
1498        2001      151202679  2015-11-25   
1499        2218      151038873  2015-08-31   

                                   NoteFileName   Site  CT  MRI  acut sdh_pos  \
0     Notes_13376418603_2401628103_20190210.txt    MGB   1    1             0   
1      Notes_1130886747_2345800771_20180707.txt  BIDMC   0    0             0   
2     Notes_1131006147_26220317926_20180825.txt  BIDMC   0    0             0   
3     Notes_13620842071_8922329630_20220510.txt    MGB   1    1             0   
4     Notes_13361437832_163574

In [31]:
# get training from old (<=2018) (>=2020)
# roughly COVID as the split point
# gap = 2 years since shifted day is within +/- 1 year

matrix['ContactDate'] = pd.to_datetime(matrix.ContactDate)
ids1 = matrix.ContactDate.dt.year<=2018
print(ids1.sum(), ids1.mean()*100)
print(matrix.annot[ids1].sum(), matrix.annot[ids1].mean()*100)

ids2 = matrix.ContactDate.dt.year>=2020
print(ids2.sum(), ids2.mean()*100)
print(matrix.annot[ids2].sum(), matrix.annot[ids2].mean()*100)

matrix = matrix[ids1].reset_index(drop=True)
X_data = X_data[ids1].reset_index(drop=True)
X = X[ids1].reset_index(drop=True)
y_data = y_data[ids1].reset_index(drop=True)
y = y[ids1].reset_index(drop=True)
y_data_pre = y_data_pre[ids1].reset_index(drop=True)

963 64.2
244 25.33748701973001
316 21.066666666666666
75 23.734177215189874


In [32]:
from tqdm import tqdm

#Conduct the 10 fold nested cross validation using random forest. 

# Initialize variables for storing results
all_predictions = []
all_true_labels = []
all_row_numbers = []
auc_cv = []
auc_pr = []
f1_cv = []
cf_cv = []
predictions = []
roc_curves = []
pr_curves = []
feature_importances_dict = {feature: [] for feature in X.columns}

#Create a dictionary to store patient IDs for each fold.
fold_patient_ids = {f'fold_{i+1}': {'train': [], 'test': []} for i in range(10)}

#Initialize GroupKFold.
gkf = GroupKFold(n_splits=10)

for cvi, (train_index, test_index) in enumerate(tqdm(gkf.split(X, y, groups=y_data_pre['BDSPPatientID']), total=10)):
    Xtr, Xte = X.loc[train_index], X.loc[test_index]
    ytr, yte = y.loc[train_index], y.loc[test_index]
        
    fold_patient_ids[f'fold_{cvi+1}']['train'].extend(y_data_pre.loc[train_index, 'BDSPPatientID'].tolist())
    fold_patient_ids[f'fold_{cvi+1}']['test'].extend(y_data_pre.loc[test_index, 'BDSPPatientID'].tolist())


    model = RandomForestClassifier(
        n_estimators=100,
        random_state=2023,
        n_jobs=1
    )
    
    search_spaces = {
        'n_estimators': (50, 500),
        'max_depth': (5, 50),
        'min_samples_split': (2, 20),
        'min_samples_leaf': (1, 20),
    }
    
    model_cv = BayesSearchCV(
        model,
        search_spaces,
        n_iter=50,
        scoring='roc_auc',
        n_jobs=4,
        cv=10,
        random_state=2023
    )
    
    model_cv.fit(Xtr, ytr)

    model = model_cv.best_estimator_

    ytr_pred = model.predict_proba(Xtr)[:, 1]
    yte_pred = model.predict_proba(Xte)[:, 1]

    fpr, tpr, cutoffs = roc_curve(ytr, ytr_pred)
    best_cutoff = cutoffs[np.argmax(tpr - fpr)]
    yte_pred_bin = (yte_pred > best_cutoff).astype(int)

    auc_cv.append(roc_auc_score(yte, yte_pred))
    f1_cv.append(f1_score(yte, yte_pred_bin))
    cf_cv.append(confusion_matrix(yte, yte_pred_bin))
    predictions.append(yte_pred_bin)

    model_filename = f'RF_model_train_allhospitals_Notes+ICD+Med_fold{cvi+1}_past.pickle'
    with open(model_filename, 'wb') as f:
        pickle.dump({'model':model, 'cutoff':best_cutoff}, f)


    fpr, tpr, cutoffs = roc_curve(yte, yte_pred)
    all_predictions.extend(yte_pred_bin)
    all_true_labels.extend(yte)
    all_row_numbers.extend(y_data.iloc[test_index]['BDSPPatientID'])

    roc_curves.append((fpr, tpr, roc_auc_score(yte, yte_pred)))
    precision, recall, thresholds = precision_recall_curve(yte, yte_pred)
    pr_curves.append((recall, precision, auc(recall, precision)))
    auc_pr_loop = auc(recall, precision)
    auc_pr.append(auc_pr_loop)

    feature_importances = model.feature_importances_
    for feature, importance in zip(X.columns, feature_importances):
        feature_importances_dict[feature].append(importance)

    print("Feature importances for this fold:")
    print(pd.DataFrame({
        'Feature': X.columns,
        'Importance': feature_importances
    }).sort_values(by='Importance', ascending=False))

    print(auc_cv)
    print(auc_pr)

fold_patient_ids_list = []
for fold, ids in fold_patient_ids.items():
    for train_id in ids['train']:
        fold_patient_ids_list.append({'fold': fold, 'type': 'train', 'patient_id': train_id})
    for test_id in ids['test']:
        fold_patient_ids_list.append({'fold': fold, 'type': 'test', 'patient_id': test_id})
fold_patient_ids_df = pd.DataFrame(fold_patient_ids_list)
fold_patient_ids_df.to_csv('RF_both_hospitals_fold_patient_ids_.csv', index=False)

 10%|████████████                                                                                                             | 1/10 [01:54<17:06, 114.08s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.294785
25        subdur_pos    0.093028
68         ICD_S06.5    0.090251
21  neurosurgeri_pos    0.076549
69         ICD_432.1    0.065381
..               ...         ...
37     burr hole_neg    0.000000
39   craniectomi_neg    0.000000
52           mvc_neg    0.000000
51           mva_neg    0.000000
60         thick_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394]
[0.9199763500663062]


 20%|████████████████████████▏                                                                                                | 2/10 [03:47<15:10, 113.86s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.249795
25        subdur_pos    0.076641
68         ICD_S06.5    0.073919
21  neurosurgeri_pos    0.058464
69         ICD_432.1    0.054112
..               ...         ...
46   chronic sdh_neg    0.000000
37     burr hole_neg    0.000000
51           mva_neg    0.000000
38     stabl sdh_neg    0.000000
58           tbi_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687]
[0.9199763500663062, 0.8667715756397729]


 30%|████████████████████████████████████▎                                                                                    | 3/10 [05:32<12:48, 109.83s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.280754
25        subdur_pos    0.083615
21  neurosurgeri_pos    0.074439
68         ICD_S06.5    0.069301
69         ICD_432.1    0.055585
..               ...         ...
59     tentorium_neg    0.000000
58           tbi_neg    0.000000
38     stabl sdh_neg    0.000000
51           mva_neg    0.000000
43         evacu_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948]


 40%|████████████████████████████████████████████████▍                                                                        | 4/10 [07:28<11:13, 112.25s/it]

Feature importances for this fold:
                         Feature  Importance
24                       sdh_pos    0.320434
21              neurosurgeri_pos    0.090482
25                    subdur_pos    0.088573
68                     ICD_S06.5    0.086041
69                     ICD_432.1    0.051031
..                           ...         ...
58                       tbi_neg    0.000000
49  intraparenchym hemorrhag_neg    0.000000
51                       mva_neg    0.000000
38                 stabl sdh_neg    0.000000
37                 burr hole_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025]


 50%|████████████████████████████████████████████████████████████▌                                                            | 5/10 [09:05<08:53, 106.71s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.297019
25        subdur_pos    0.078974
21  neurosurgeri_pos    0.070376
68         ICD_S06.5    0.069998
69         ICD_432.1    0.049549
..               ...         ...
37     burr hole_neg    0.000000
52           mvc_neg    0.000000
51           mva_neg    0.000000
38     stabl sdh_neg    0.000000
59     tentorium_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397, 0.9947478991596639]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025, 0.9875585309658882]


 60%|████████████████████████████████████████████████████████████████████████▌                                                | 6/10 [10:55<07:10, 107.67s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.294014
25        subdur_pos    0.089473
68         ICD_S06.5    0.081900
21  neurosurgeri_pos    0.079805
69         ICD_432.1    0.049655
..               ...         ...
58           tbi_neg    0.000000
39   craniectomi_neg    0.000000
51           mva_neg    0.000000
37     burr hole_neg    0.000000
43         evacu_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397, 0.9947478991596639, 0.9917460317460318]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025, 0.9875585309658882, 0.9714706201095209]


 70%|████████████████████████████████████████████████████████████████████████████████████▋                                    | 7/10 [12:43<05:23, 107.75s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.289550
25        subdur_pos    0.098741
68         ICD_S06.5    0.084839
21  neurosurgeri_pos    0.070697
69         ICD_432.1    0.063461
..               ...         ...
37     burr hole_neg    0.000000
46   chronic sdh_neg    0.000000
43         evacu_neg    0.000000
38     stabl sdh_neg    0.000000
35  brain injuri_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397, 0.9947478991596639, 0.9917460317460318, 0.9815724815724816]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025, 0.9875585309658882, 0.9714706201095209, 0.8374180800607565]


 80%|████████████████████████████████████████████████████████████████████████████████████████████████▊                        | 8/10 [14:23<03:30, 105.31s/it]

Feature importances for this fold:
             Feature  Importance
24           sdh_pos    0.274319
68         ICD_S06.5    0.102868
25        subdur_pos    0.091890
21  neurosurgeri_pos    0.088144
69         ICD_432.1    0.055263
..               ...         ...
37     burr hole_neg    0.000000
46   chronic sdh_neg    0.000000
43         evacu_neg    0.000000
38     stabl sdh_neg    0.000000
32    resolv sdh_pos    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397, 0.9947478991596639, 0.9917460317460318, 0.9815724815724816, 0.9560563380281689]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025, 0.9875585309658882, 0.9714706201095209, 0.8374180800607565, 0.8633080290700075]


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▉            | 9/10 [16:02<01:43, 103.39s/it]

Feature importances for this fold:
                         Feature  Importance
24                       sdh_pos    0.296514
25                    subdur_pos    0.109850
21              neurosurgeri_pos    0.097280
68                     ICD_S06.5    0.078166
69                     ICD_432.1    0.056850
..                           ...         ...
49  intraparenchym hemorrhag_neg    0.000000
38                 stabl sdh_neg    0.000000
36                 brain mri_neg    0.000000
37                 burr hole_neg    0.000000
34                  acut sdh_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397, 0.9947478991596639, 0.9917460317460318, 0.9815724815724816, 0.9560563380281689, 0.984225352112676]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025, 0.9875585309658882, 0.9714706201095209, 0.8374180800607565, 0.8633080290700075, 0.945891377714846]


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [17:38<00:00, 105.87s/it]

Feature importances for this fold:
                         Feature  Importance
24                       sdh_pos    0.325585
25                    subdur_pos    0.096107
21              neurosurgeri_pos    0.093330
68                     ICD_S06.5    0.074135
69                     ICD_432.1    0.055193
..                           ...         ...
39               craniectomi_neg    0.000000
43                     evacu_neg    0.000000
46               chronic sdh_neg    0.000000
49  intraparenchym hemorrhag_neg    0.000000
38                 stabl sdh_neg    0.000000

[71 rows x 2 columns]
[0.9793939393939394, 0.9687830687830687, 0.9582881906825569, 0.9796825396825397, 0.9947478991596639, 0.9917460317460318, 0.9815724815724816, 0.9560563380281689, 0.984225352112676, 0.951154052603328]
[0.9199763500663062, 0.8667715756397729, 0.7758455439988948, 0.9451414503101025, 0.9875585309658882, 0.9714706201095209, 0.8374180800607565, 0.8633080290700075, 0.945891377714846, 0.812741283092627]





In [33]:
#Print the AUROC and AUPRC.
print(np.mean(auc_cv)) 
print(np.mean(auc_pr))

0.9745649893764456
0.8926122841028723
