# Hold-out validation

In [1]:
import xgboost
import shap
print(xgboost.__version__, shap.__version__)

0.90 0.34.0


In [2]:
# Import Statements
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

cwd = Path.cwd()
datasets = cwd / 'datasets'
results = cwd / 'results/pooled'

## Data Preprocessing
Since we are using stratified kfold, a validation split is not necesssary.

### Load data

In [30]:
raw_df = pd.read_csv(datasets / 'kapusta_grumaz_karius_genus_raw.csv')

# Remove NTCs
raw_df = raw_df.loc[raw_df.y != 'ntc', :]
display(raw_df)

# Binary encode y
raw_df.loc[raw_df.y == 'septic', 'y'] = 1
raw_df.loc[raw_df.y == 'healthy', 'y'] = 0
raw_df = raw_df.astype({'y': 'int'})

# Get hold out set
hold_out = 'kapusta'
holdout_df = raw_df.loc[raw_df.dataset == hold_out, :]
other_df = raw_df.loc[raw_df.dataset != hold_out, :]

holdout_X = holdout_df.drop(['y', 'dataset'], axis=1).copy()
holdout_y = holdout_df.y.copy()

other_X = other_df.drop(['y', 'dataset'], axis=1).copy()
other_y = other_df.y.copy()

# Relative abundance
other_X_RA = other_X.apply(func=lambda x: x / x.sum(), axis=1)
holdout_X_RA = holdout_X.apply(func=lambda x: x / x.sum(), axis=1)

display(raw_df)
display(holdout_df)
display(other_df)

Unnamed: 0,dataset,y,Bifidobacterium,Alloscardovia,Arthrobacter,Kocuria,Glutamicibacter,Citricoccus,Micrococcus,Pseudarthrobacter,...,Slackia,Mumia,Thermomonospora,Ilumatobacter,Marinibacterium,Maricaulis,Stella,Eoetvoesia,Bilophila,Paeniclostridium
0,kapusta,healthy,42891.0,1.0,14.0,1.0,3.0,0.0,15.0,1.0,...,0.0,0.0,0.0,7.0,1.0,0.0,0.0,0.0,0.0,0.0
1,kapusta,septic,730.0,0.0,48.0,9.0,10.0,1.0,20.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,kapusta,healthy,36074.0,3.0,6.0,56.0,3.0,0.0,21.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,kapusta,healthy,44094.0,0.0,59.0,10.0,21.0,3.0,66.0,40.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,kapusta,healthy,26958.0,0.0,6.0,1.0,0.0,0.0,16.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
436,karius,septic,4.0,0.0,3.0,2.0,0.0,0.0,1.0,0.0,...,14.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,109.0,0.0
437,karius,septic,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
438,karius,septic,0.0,0.0,0.0,4.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
439,karius,septic,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,dataset,y,Bifidobacterium,Alloscardovia,Arthrobacter,Kocuria,Glutamicibacter,Citricoccus,Micrococcus,Pseudarthrobacter,...,Slackia,Mumia,Thermomonospora,Ilumatobacter,Marinibacterium,Maricaulis,Stella,Eoetvoesia,Bilophila,Paeniclostridium
0,kapusta,0,42891.0,1.0,14.0,1.0,3.0,0.0,15.0,1.0,...,0.0,0.0,0.0,7.0,1.0,0.0,0.0,0.0,0.0,0.0
1,kapusta,1,730.0,0.0,48.0,9.0,10.0,1.0,20.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,kapusta,0,36074.0,3.0,6.0,56.0,3.0,0.0,21.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,kapusta,0,44094.0,0.0,59.0,10.0,21.0,3.0,66.0,40.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,kapusta,0,26958.0,0.0,6.0,1.0,0.0,0.0,16.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
436,karius,1,4.0,0.0,3.0,2.0,0.0,0.0,1.0,0.0,...,14.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,109.0,0.0
437,karius,1,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
438,karius,1,0.0,0.0,0.0,4.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
439,karius,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,dataset,y,Bifidobacterium,Alloscardovia,Arthrobacter,Kocuria,Glutamicibacter,Citricoccus,Micrococcus,Pseudarthrobacter,...,Slackia,Mumia,Thermomonospora,Ilumatobacter,Marinibacterium,Maricaulis,Stella,Eoetvoesia,Bilophila,Paeniclostridium
0,kapusta,0,42891.0,1.0,14.0,1.0,3.0,0.0,15.0,1.0,...,0.0,0.0,0.0,7.0,1.0,0.0,0.0,0.0,0.0,0.0
1,kapusta,1,730.0,0.0,48.0,9.0,10.0,1.0,20.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,kapusta,0,36074.0,3.0,6.0,56.0,3.0,0.0,21.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,kapusta,0,44094.0,0.0,59.0,10.0,21.0,3.0,66.0,40.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,kapusta,0,26958.0,0.0,6.0,1.0,0.0,0.0,16.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
74,kapusta,1,59.0,0.0,586.0,139.0,111.0,4.0,183.0,38.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75,kapusta,1,31.0,0.0,1199.0,350.0,246.0,4.0,771.0,91.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
76,kapusta,1,136.0,0.0,664.0,165.0,114.0,1.0,139.0,50.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
77,kapusta,1,3.0,0.0,696.0,165.0,135.0,3.0,182.0,52.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,dataset,y,Bifidobacterium,Alloscardovia,Arthrobacter,Kocuria,Glutamicibacter,Citricoccus,Micrococcus,Pseudarthrobacter,...,Slackia,Mumia,Thermomonospora,Ilumatobacter,Marinibacterium,Maricaulis,Stella,Eoetvoesia,Bilophila,Paeniclostridium
82,grumaz,0,6.0,0.0,9.0,5.0,2.0,0.0,548.0,0.0,...,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
83,grumaz,0,7.0,0.0,13.0,10.0,2.0,0.0,218.0,1.0,...,0.0,1.0,0.0,3.0,0.0,0.0,1.0,2.0,0.0,0.0
84,grumaz,0,8.0,0.0,8.0,13.0,1.0,0.0,141.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
85,grumaz,0,11.0,0.0,16.0,21.0,5.0,2.0,336.0,0.0,...,0.0,0.0,1.0,0.0,2.0,2.0,0.0,0.0,1.0,1.0
86,grumaz,0,6.0,0.0,3.0,4.0,0.0,1.0,71.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
436,karius,1,4.0,0.0,3.0,2.0,0.0,0.0,1.0,0.0,...,14.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,109.0,0.0
437,karius,1,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
438,karius,1,0.0,0.0,0.0,4.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
439,karius,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Compute train test split size

In [4]:
n_splits = 10

pos = len(other_y[other_y == 1])
neg = len(other_y[other_y == 0])
test_pos = len(holdout_y[holdout_y == 1])
test_neg = len(holdout_y[holdout_y == 0])
split_sizes = pd.DataFrame({'Septic': [pos - int(pos / n_splits), int(pos / n_splits), test_pos], 
                           'Healthy': [neg - int(neg / n_splits), int(neg / n_splits), test_neg]}, 
                           index=['Train fold', 'Validation fold', 'Test fold'])

## Nested CV for hyperparameter optimisation

In [5]:
from xgboost import XGBClassifier
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import make_scorer, precision_score, recall_score, f1_score, average_precision_score, roc_auc_score

In [6]:
def optimise_evaluate(X, y):
    np.random.seed(66)
    ratio = sum(y == 0) / sum(y == 1)
    
    # Hyperparemeter Optimisation using grid search (F1)
    model = XGBClassifier()
    n_estimators = range(100, 300, 10)
    max_depth = range(1, 5, 1)
    gamma = np.linspace(0.1, 3, 10)
    subsample = [0.6, 0.7, 0.8, 0.9, 1.0]
    colsample_bytree = np.linspace(0.1, 1, 20)
    
    param_grid = dict(max_depth=max_depth, 
                      n_estimators=n_estimators, 
                      colsample_bytree=colsample_bytree,
                      gamma=gamma,
                      subsample=subsample,
                      scale_pos_weight=[ratio])
    
    inner_cv = StratifiedKFold(n_splits=n_splits, shuffle=True)
    outer_cv = StratifiedKFold(n_splits=n_splits, shuffle=True)

    # Inner CV
    model = RandomizedSearchCV(model, 
                               param_grid, 
                               scoring='roc_auc',
                               n_iter=1000,
                               n_jobs=70, 
                               cv=inner_cv, 
                               verbose=0)

    model.fit(X, y)
    best_params = model.best_params_
    print(best_params)

    # Custom metrics
    precision = make_scorer(precision_score, average='binary')
    recall = make_scorer(recall_score, average='binary')
    f1 = make_scorer(f1_score, average='binary')
    auprc = make_scorer(average_precision_score, average=None)
    
    scoring = {'precision': precision, 
               'recall': recall, 
               'AUROC': 'roc_auc',
               'F1': f1}

    # Outer CV
    outer_results = cross_validate(model, X=X, y=y, cv=outer_cv, scoring=scoring)
    outer_results = pd.DataFrame(outer_results).mean()[['test_precision', 'test_recall', 'test_F1', 'test_AUROC']]

    return outer_results, best_params


## Analysis

### 1. Optimise and evaluate clean models trained on Grumaz + Kapusta

In [7]:
raw_results, raw_params = optimise_evaluate(other_X, other_y)
# raw_params = {'subsample': 0.5263157894736842, 'scale_pos_weight': 1.4273504273504274, 'n_estimators': 96, 'max_depth': 2, 'gamma': 1.8421052631578947, 'colsample_bytree': 0.19473684210526315}

RA_results, RA_params = optimise_evaluate(other_X_RA, other_y)
# RA_params = {'subsample': 0.4789473684210527, 'scale_pos_weight': 1.4273504273504274, 'n_estimators': 255, 'max_depth': 1, 'gamma': 0.7894736842105263, 'colsample_bytree': 0.4}


{'subsample': 0.6, 'scale_pos_weight': 1.0632183908045978, 'n_estimators': 230, 'max_depth': 4, 'gamma': 0.7444444444444444, 'colsample_bytree': 0.19473684210526315}
{'subsample': 0.6, 'scale_pos_weight': 1.0632183908045978, 'n_estimators': 200, 'max_depth': 4, 'gamma': 0.7444444444444444, 'colsample_bytree': 0.2894736842105263}


### 2. Decontamination

#### Fit dirty model

In [8]:
raw_model = XGBClassifier(**raw_params)
raw_model.fit(other_X, other_y)

XGBClassifier(colsample_bytree=0.19473684210526315, gamma=0.7444444444444444,
              max_depth=4, n_estimators=230,
              scale_pos_weight=1.0632183908045978, subsample=0.6)

#### Remove non-human associated pathogens

In [9]:
# Retrieve known human pathogens
meta = pd.read_csv(datasets / 'pathogen_list.csv', encoding= 'unicode_escape')
meta = meta['Genus'].unique()

# Remove non-human pathogens
genera_new = other_X.columns
genera_new = list(set(genera_new).intersection(set(meta)))

#### Remove contaminants based on SHAP values

In [10]:
import math
from scipy.stats import spearmanr
import shap


def decontam(X_train, y_train, params):
    X_train = X_train.copy()
    y_train = y_train.copy()
    params = params.copy()
    
    model = XGBClassifier(**params)
    model.fit(X=X_train, y=y_train)

    explainer = shap.TreeExplainer(model, feature_pertubation='interventional', model_output='probability', data=X_train)
    shap_val = explainer.shap_values(X_train)

    to_retain = np.array([True] * X_train.shape[1])
    corrs = np.zeros(X_train.shape[1])
    
    for i in range(X_train.shape[1]):
        if sum(X_train.iloc[:, i] != 0) >= X_train.shape[0] * 0.1:
            rho = spearmanr(X_train.iloc[:, i], shap_val[:, i])[0]
            p = spearmanr(X_train.iloc[:, i], shap_val[:, i])[1]
            print(f'rho={rho}, p={p}, genus={X_train.columns[i]}')
            
            if rho < 0:
                to_retain[i] = False
                
            elif not shap_val[:, i].any(0):
                to_retain[i] = False
        else:
            to_retain[i] = False

    to_retain = X_train.columns[to_retain]
    print(to_retain.shape, to_retain)
    
    return to_retain

In [11]:
# Decontam using raw_params
for _ in range(10):
    genera_new = decontam(other_X.loc[:, genera_new], other_y, raw_params)

rho=-0.4723519030911272, p=2.3739224687114554e-21, genus=Microbacterium
rho=0.8052807953942254, p=4.6925581047623864e-83, genus=Pseudomonas
rho=nan, p=nan, genus=Yersinia
rho=0.14668214674079078, p=0.005359143973223529, genus=Moraxella
rho=0.6345691879047827, p=7.486149714128986e-42, genus=Blautia
rho=-0.640134219309172, p=8.797614338413505e-43, genus=Paenibacillus
rho=nan, p=nan, genus=Leptotrichia
rho=-0.6578392029353188, p=7.145444627176603e-46, genus=Azospirillum
rho=nan, p=nan, genus=Alistipes
rho=0.8883751156443258, p=9.626786333570196e-123, genus=Enterococcus
rho=-0.5883139072784312, p=8.32988189468042e-35, genus=Streptomyces
rho=-0.2554663085412573, p=9.335882399319957e-07, genus=Ochrobactrum
rho=0.6767240225227972, p=2.0831521605363504e-49, genus=Atopobium
rho=0.6514688635360114, p=9.767191628991769e-45, genus=Aureimonas
rho=0.1796557537780522, p=0.0006262432191998217, genus=Clostridioides
rho=0.6740869742884371, p=6.732524913534433e-49, genus=Citrobacter
rho=0.897274647315280

An input array is constant; the correlation coefficent is not defined.


rho=-0.6359976794387318, p=4.338935940132544e-42, genus=Achromobacter
rho=0.2665247635287886, p=2.9704168065089736e-07, genus=Rothia
rho=-0.9105525218427809, p=5.099551085404403e-139, genus=Haemophilus
rho=-0.6008607378970153, p=1.325062035879071e-36, genus=Nocardia
rho=0.8063951845932101, p=1.87809798868656e-83, genus=Pantoea
rho=-0.7861125071409336, p=1.3579612962587204e-76, genus=Mycobacterium
rho=-0.2375599952566932, p=5.345963634543431e-06, genus=Saccharopolyspora
rho=-0.6693283147340716, p=5.423189237129315e-48, genus=Anaerococcus
rho=-0.7859687221517515, p=1.5094589440987974e-76, genus=Corynebacterium
rho=0.7255789061566136, p=6.515141648603231e-60, genus=Dermacoccus
rho=0.7258256091365741, p=5.690328141158989e-60, genus=Salmonella
rho=-0.9024859546807988, p=1.1938703410562928e-132, genus=Ralstonia
rho=0.13916131657548925, p=0.008280960728434325, genus=Acinetobacter
rho=nan, p=nan, genus=Selenomonas
rho=-0.79738552504968, p=2.6126621394988575e-80, genus=Lactobacillus
rho=-0.8795

rho=0.6346762512125808, p=7.187009489369333e-42, genus=Blautia
rho=0.11939607109604047, p=0.0236701206828362, genus=Enterococcus
rho=-0.08919841746279106, p=0.09149957644080314, genus=Atopobium
rho=0.21192382511466662, p=5.177976361679461e-05, genus=Citrobacter
rho=0.26510160935284033, p=3.4522074621417426e-07, genus=Bacteroides
rho=0.15624163339373795, p=0.002994693134728119, genus=Treponema
rho=-0.3180959463606326, p=6.962060360053664e-10, genus=Shewanella
rho=-0.03249037708785384, p=0.5394665422382741, genus=Rhodococcus
rho=0.23045925177676801, p=1.0295807830097008e-05, genus=Cronobacter
rho=0.039064668076622844, p=0.4605933489149069, genus=Prevotella
rho=0.6925909749668728, p=1.372377932014538e-52, genus=Oerskovia
rho=0.5475377734844464, p=1.805164551130277e-29, genus=Stenotrophomonas
rho=-0.06438823497558087, p=0.22360748153940627, genus=Bacillus
rho=0.846909864200253, p=5.896894106680401e-100, genus=Klebsiella
rho=0.4588390029185028, p=4.2984338079526505e-20, genus=Alloprevotella

rho=0.6345240197413055, p=7.616020107329359e-42, genus=Blautia
rho=0.2163150373792929, p=3.57542532130643e-05, genus=Enterococcus
rho=0.1572613568147991, p=0.0028090263407455564, genus=Citrobacter
rho=0.29277062608051774, p=1.5841741982573883e-08, genus=Bacteroides
rho=0.23060838624507612, p=1.015719624543516e-05, genus=Cronobacter
rho=0.6925031630978219, p=1.4310250033607991e-52, genus=Oerskovia
rho=0.5355922107842652, p=4.845273948793799e-28, genus=Stenotrophomonas
rho=0.7654092408676897, p=2.529041464105741e-70, genus=Klebsiella
rho=0.5807431713785637, p=9.298705463222951e-34, genus=Alloprevotella
rho=0.2705922106705296, p=1.9237517950905623e-07, genus=Enterobacter
rho=0.5604402495263396, p=4.441023520172602e-31, genus=Megasphaera
rho=0.06650906719064033, p=0.20869407591019432, genus=Campylobacter
rho=0.06241410833040086, p=0.23815747840381454, genus=Pandoraea
rho=0.40049949780674554, p=2.9075344191337204e-15, genus=Cellulomonas
rho=0.030701604159986937, p=0.5620357735098236, genus=

In [12]:
# Decontam + pathogens
other_X_raw_CR = other_X[genera_new]

# Normalise RA
other_X_RA_CR = other_X_raw_CR.apply(func=lambda x: x / x.sum(), axis=1)

#### Number of features before and after decontamination

In [13]:
print('Neat', other_X.shape)
print('CR', other_X_raw_CR.shape)

Neat (359, 685)
CR (359, 16)


### 3. Optimise and evaluate decontaminated models

In [14]:
raw_CR_results, raw_CR_params = optimise_evaluate(other_X_raw_CR, other_y)

RA_CR_results, RA_CR_params = optimise_evaluate(other_X_RA_CR, other_y)

{'subsample': 0.8, 'scale_pos_weight': 1.0632183908045978, 'n_estimators': 210, 'max_depth': 4, 'gamma': 3.0, 'colsample_bytree': 0.5736842105263158}
{'subsample': 0.8, 'scale_pos_weight': 1.0632183908045978, 'n_estimators': 210, 'max_depth': 2, 'gamma': 0.1, 'colsample_bytree': 0.19473684210526315}


## Estimate test error (hold-out)

In [24]:
def estimate_error(param_dict, x_train, y_train, x_test, y_test):
    model = XGBClassifier(**param_dict)
    model.fit(x_train, y_train)
    
    y_pred = model.predict(x_test)
    y_score = model.predict_proba(x_test)[:, 1]
    
    precision = precision_score(y_true=y_test, y_pred=y_pred)
    recall = recall_score(y_true=y_test, y_pred=y_pred)
    F1 = f1_score(y_true=y_test, y_pred=y_pred)
    auroc = roc_auc_score(y_true=y_test, y_score=y_score)
    auprc = average_precision_score(y_true=y_test, y_score=y_score)
    
    return pd.Series({'external_test_precision': precision,
                      'external_test_recall': recall,
                      'external_test_F1': F1,
                      'external_test_AUROC': auroc,
                      'external_test_AUPRC': auprc})

### Train on Grumaz + Kapusta, test on holdout

#### Preprocess test dataset

In [22]:
holdout_X_raw_CR = holdout_X.loc[:, other_X_raw_CR.columns]
holdout_X_RA_CR = holdout_X_raw_CR.apply(func=lambda x: x / x.sum(), axis=1)

#### Before decontamination

In [25]:
external_error_raw = estimate_error(raw_params, other_X, other_y, holdout_X, holdout_y)
external_error_RA = estimate_error(RA_params, other_X_RA, other_y, holdout_X_RA, holdout_y)

#### After decontamination

In [26]:
external_error_raw_CR = estimate_error(raw_CR_params, other_X_raw_CR, other_y, holdout_X_raw_CR, holdout_y)
external_error_RA_CR = estimate_error(RA_CR_params, other_X_RA_CR, other_y, holdout_X_RA_CR, holdout_y)

In [27]:
#### Combine all results
metric_df = pd.DataFrame({'Raw': raw_results, 'RA': RA_results, 
             'Raw CR': raw_CR_results, 'RA CR': RA_CR_results}).round(3).T

ext_metric_df = pd.DataFrame({'Raw': external_error_raw, 'RA': external_error_RA, 
                              'Raw CR': external_error_raw_CR, 'RA CR': external_error_RA_CR}).round(3).T

final_results = pd.concat([metric_df, ext_metric_df], axis=1)

### Final Results

In [28]:
display(split_sizes)
display(final_results)
final_results.to_csv(results / f'{hold_out}_hold_out_model_results.csv', index=True, header=True)

Unnamed: 0,Septic,Healthy
Train fold,157,167
Validation fold,17,18
Test fold,56,23


Unnamed: 0,test_precision,test_recall,test_F1,test_AUROC,external_test_precision,external_test_recall,external_test_F1,external_test_AUROC,external_test_AUPRC
Raw,0.94,0.925,0.931,0.981,0.4,0.071,0.121,0.248,0.568
RA,0.934,0.931,0.932,0.982,0.521,0.446,0.481,0.043,0.506
Raw CR,0.84,0.799,0.812,0.871,0.733,0.982,0.84,0.703,0.858
RA CR,0.791,0.719,0.743,0.84,0.739,0.911,0.816,0.602,0.752
