# Metagenomic-based Diagnostic for Sepsis (Karius | Drop Causal Features)

We asked if we could discriminate sepsis without "confirmed" (i.e. culture-positive) pathogens. The ability to do so using a stringently decontaminated feature space would provide evidence for a polymicrobial theory of sepsis.

In [1]:
# Import Statements
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Project directories, resolved relative to the notebook's working directory.
cwd = Path.cwd()
datasets = cwd / 'datasets'
results = cwd / 'results/drop_causal_features'

# Figures and CSV files are written into `results` further down; create the
# directory up front so those cells do not fail on a fresh checkout.
results.mkdir(parents=True, exist_ok=True)

## Data Preprocessing
Since we are using stratified k-fold cross-validation, a separate validation split is not necessary.

### Load data

In [2]:
# Read the genus-level table and show it.
raw_df = pd.read_csv(datasets / 'karius_genus_raw.csv')
display(raw_df)

# Split by position: column 1 holds the class label ('septic' / 'healthy'),
# columns 2 onwards hold the per-genus features. Copies are taken so later
# edits never write back into raw_df.
y = raw_df.iloc[:, 1].copy()
X = raw_df.iloc[:, 2:].copy()

Unnamed: 0,pathogen,y,Bradyrhizobium,Rhodopseudomonas,Bosea,Afipia,Oligotropha,Variibacter,Methylobacterium,Methylorubrum,...,Rubinisphaera,Dictyoglomus,Pakpunavirus,Marinilactibacillus,Paludibacter,Nonagvirus,Halovivax,Phifelvirus,Planktothrix,Denitrobacterium
0,none,healthy,13825.0,108.0,130.0,0.0,12.0,0.0,2883.0,82.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,none,healthy,20476.0,234.0,60.0,53.0,30.0,0.0,8183.0,480.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,none,healthy,9677.0,72.0,17.0,8.0,33.0,0.0,1944.0,436.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,none,healthy,15211.0,158.0,56.0,68.0,11.0,6.0,7081.0,2041.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,none,healthy,56586.0,294.0,181.0,76.0,29.0,12.0,13850.0,1856.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
279,Escherichia coli,septic,9392.0,54.0,23.0,0.0,0.0,19.0,4187.0,199.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
280,Cryptococcus neoformans,septic,3466.0,40.0,80.0,0.0,14.0,11.0,481.0,245.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
281,Streptococcus oralis,septic,17287.0,142.0,25.0,59.0,0.0,0.0,408.0,258.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
282,Escherichia coli,septic,4006.0,39.0,0.0,16.0,0.0,0.0,4456.0,111.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [3]:
# Get features of "causal/confirmed pathogens"
# The genus is taken as the first word of each confirmed pathogen name.
genus_of_pathogen = raw_df['pathogen'].str.strip().str.split(' ', expand=True)[0]

# Entries that are not used as genus features:
#   'Human'                   -> Human herpesvirus (presumably why the two
#                                herpesvirus genera are appended below)
#   'Cryptococcus', 'Candida' -> fungi ("If no fungi")
#   'none'                    -> no confirmed pathogen
not_genus_features = ['Human', 'Cryptococcus', 'Candida', 'none']
pathogens = genus_of_pathogen.unique()
pathogens = pathogens[~np.isin(pathogens, not_genus_features)]
pathogens = np.append(pathogens, ['Lymphocryptovirus', 'Simplexvirus'])

display(pathogens)

# Both feature spaces are rebuilt from raw_df instead of from X. The previous
# `X = X.drop(...)` consumed its own input, so running this cell a second time
# raised a KeyError because the columns were already gone.
all_features = raw_df.iloc[:, 2:]

# Feature space with only these features
X_only = all_features[pathogens].copy()

# Drop these features from feature space
X = all_features.drop(columns=pathogens)

array(['Escherichia', 'Streptococcus', 'Mycobacterium', 'Cytomegalovirus',
       'Staphylococcus', 'Proteus', 'Klebsiella', 'Pseudomonas',
       'Moraxella', 'Enterococcus', 'Enterobacter', 'Citrobacter',
       'Haemophilus', 'Fusobacterium', 'Salmonella', 'Serratia',
       'Aerococcus', 'Campylobacter', 'Lymphocryptovirus', 'Simplexvirus'],
      dtype=object)

In [4]:
# Drop features from simple decontam feature space
decontaminated_pathogens = pd.read_csv(datasets / 'simple_decontam_pathogens.csv', header=None)[0]
to_keep = list(set(decontaminated_pathogens) - set(pathogens))
X_simple = X.loc[:, to_keep].copy()
print(X_simple.columns)

Index(['Rothia', 'Psychrobacter', 'Anaerococcus', 'Nocardiopsis', 'Pantoea',
       'Stenotrophomonas', 'Cellulosimicrobium', 'Parabacteroides', 'Kluyvera',
       'Actinomadura', 'Kerstersia', 'Laribacter', 'Actinomyces', 'Weissella',
       'Dermabacter', 'Agrobacterium', 'Filifactor', 'Rahnella', 'Aureimonas',
       'Anaerostipes', 'Bacteroides', 'Propionibacterium', 'Aeromonas',
       'Cellulomonas', 'Odoribacter', 'Cronobacter', 'Pandoraea', 'Bacillus',
       'Leptotrichia', 'Treponema', 'Yersinia', 'Desulfovibrio', 'Raoultella',
       'Roseomonas', 'Prevotella', 'Helicobacter', 'Lachnoclostridium',
       'Gryllotalpicola', 'Mastadenovirus', 'Exiguobacterium', 'Tannerella',
       'Lactobacillus', 'Legionella', 'Cardiobacterium', 'Morganella',
       'Paracoccus', 'Gemella', 'Chromobacterium', 'Atlantibacter',
       'Collinsella', 'Mycoplasma', 'Bifidobacterium', 'Actinobacillus',
       'Neisseria', 'Brevibacterium', 'Eikenella', 'Dyella', 'Peptoniphilus',
       'Lelliotti

In [5]:
# Binary encode y
y.loc[y == 'septic'] = 1
y.loc[y == 'healthy'] = 0
y = y.astype('int')

# Relative abundance
X_RA = X.apply(func=lambda x: x / x.sum(), axis=1)
X_only_RA = X_only.apply(func=lambda x: x / x.sum(), axis=1)
X_simple_RA = X_simple.apply(func=lambda x: x / x.sum(), axis=1)

In [6]:
n_splits = 10

# Class counts (1 = septic, 0 = healthy).
pos = int((y == 1).sum())
neg = int((y == 0).sum())

# Approximate number of samples of each class in one test fold and in the
# remaining training folds.
pos_test = pos // n_splits
neg_test = neg // n_splits
split_sizes = pd.DataFrame(
    {'Septic': [pos - pos_test, pos_test],
     'Healthy': [neg - neg_test, neg_test]},
    index=['Train fold', 'Test fold'],
)

display(split_sizes)

# Get negative to positive ratio
ratio = neg / pos

Unnamed: 0,Septic,Healthy
Train fold,106,151
Test fold,11,16


## Nested CV for hyperparameter optimisation

In [7]:
from xgboost import XGBClassifier
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.metrics import make_scorer, precision_score, recall_score, f1_score, average_precision_score

In [8]:
def optimise_evaluate(X, y, n_iter=1000, n_jobs=10, seed=66):
    """Tune an XGBoost classifier by randomised search and score it with nested CV.

    The number of folds for both the inner and outer stratified k-fold splits
    is read from the notebook-level ``n_splits``.

    Parameters
    ----------
    X : feature table passed to ``fit`` / ``cross_validate``.
    y : binary labels; ``y == 1`` is the positive class and ``y == 0`` the
        negative class (their ratio is used as ``scale_pos_weight``).
    n_iter : number of parameter settings sampled by ``RandomizedSearchCV``.
    n_jobs : number of parallel jobs used by the search.
    seed : value passed to ``np.random.seed`` before anything random happens.

    Returns
    -------
    outer_results : pandas Series with the outer-fold means of
        ``test_precision``, ``test_recall``, ``test_F1`` and ``test_AUROC``.
    best_params : dict of the best parameters found by the search fitted on
        all of ``X`` and ``y``.
    """
    # The CV splitters and the randomised search are created without an
    # explicit random_state, so they draw from NumPy's global generator.
    np.random.seed(seed)
    ratio = sum(y == 0) / sum(y == 1)

    # Hyperparameter optimisation using randomised search (scored by AUROC)
    param_grid = dict(max_depth=range(1, 10, 1),
                      n_estimators=range(100, 500, 10),
                      colsample_bytree=np.linspace(0.1, 1, 20),
                      gamma=np.linspace(0.1, 3, 10),
                      subsample=[0.6, 0.7, 0.8, 0.9, 1.0],
                      scale_pos_weight=[ratio])

    inner_cv = StratifiedKFold(n_splits=n_splits, shuffle=True)
    outer_cv = StratifiedKFold(n_splits=n_splits, shuffle=True)

    # Inner CV
    search = RandomizedSearchCV(XGBClassifier(),
                                param_grid,
                                scoring='roc_auc',
                                n_iter=n_iter,
                                n_jobs=n_jobs,
                                cv=inner_cv,
                                verbose=0)

    # Fit on the full data only to report one set of best parameters.
    search.fit(X, y)
    best_params = search.best_params_

    # Custom metrics
    scoring = {'precision': make_scorer(precision_score, average='binary'),
               'recall': make_scorer(recall_score, average='binary'),
               'AUROC': 'roc_auc',
               'F1': make_scorer(f1_score, average='binary')}

    # Outer CV: cross_validate works on clones of the search object, so the
    # search is repeated inside each outer training fold.
    outer_results = cross_validate(search, X=X, y=y, cv=outer_cv, scoring=scoring)
    outer_results = pd.DataFrame(outer_results).mean()[['test_precision', 'test_recall', 'test_F1', 'test_AUROC']]

    return outer_results, best_params


### Optimise and evaluate models trained on dirty data

In [None]:
# Raw counts, causal genera removed
raw_results, raw_params = optimise_evaluate(X=X, y=y)

# Relative abundance, causal genera removed
RA_results, RA_params = optimise_evaluate(X=X_RA, y=y)

# Causal genera only (raw counts, then relative abundance)
only_results, only_params = optimise_evaluate(X=X_only, y=y)
only_RA_results, only_RA_params = optimise_evaluate(X=X_only_RA, y=y)

# Simple-decontamination genera minus causal genera (raw, then relative abundance)
simple_results, simple_params = optimise_evaluate(X=X_simple, y=y)
simple_RA_results, simple_RA_params = optimise_evaluate(X=X_simple_RA, y=y)

## Estimates of test error

In [None]:
# One column per feature space; transposed so that each feature space is a row
# and each metric a column.
results_by_feature_space = {
    'Raw': raw_results,
    'RA': RA_results,
    'Only Causal': only_results,
    'Only Causal (RA)': only_RA_results,
    'Simple Decontam': simple_results,
    'Simple Decontam (RA)': simple_RA_results,
}
metric_df = pd.DataFrame(results_by_feature_space).round(3).T
display(metric_df)

## Train dirty models

In [None]:
# Refit on the full data set using the parameters selected by the search.
# Raw counts:
raw_model = XGBClassifier(**raw_params)
raw_model.fit(X=X, y=y)

# Relative abundance:
RA_model = XGBClassifier(**RA_params)
RA_model.fit(X=X_RA, y=y)

### Remove Contaminants based on SHAP values

In [None]:
import math
from scipy.stats import spearmanr
import shap


def decontam(X_train, y_train, params):
    """Fit an XGBoost model and return the columns that pass the SHAP-based filter.

    A column is dropped when any of the following holds:

    * it is non-zero in too few septic samples (fewer than 10% of the total
      number of rows in ``X_train``);
    * its values are negatively Spearman-correlated with its SHAP values
      (``rho < 0`` and ``p < 0.05``);
    * all of its SHAP values are zero.

    Parameters
    ----------
    X_train : DataFrame of features (samples x genera).
    y_train : binary labels aligned with ``X_train``; 1 marks a septic sample.
    params : dict of keyword arguments for ``XGBClassifier``.

    Returns
    -------
    pandas Index of the retained column names (also printed).
    """
    X_train = X_train.copy()
    y_train = y_train.copy()
    params = params.copy()

    model = XGBClassifier(**params)
    model.fit(X=X_train, y=y_train)

    # The keyword was previously misspelt ('feature_pertubation'), so the
    # option was not actually being passed to TreeExplainer.
    explainer = shap.TreeExplainer(model, feature_perturbation='interventional', model_output='probability', data=X_train)
    shap_val = explainer.shap_values(X_train)

    # Boolean mask of septic rows. The previous `X_train.iloc[y_train, i]`
    # used the 0/1 labels as row *positions*, so only the first two samples
    # were ever inspected.
    # NOTE(review): the intent is read as "count within septic samples" -
    # confirm against the published method.
    is_septic = np.asarray(y_train) == 1
    min_nonzero = X_train.shape[0] * 0.1

    to_retain = np.ones(X_train.shape[1], dtype=bool)

    for i in range(X_train.shape[1]):
        values = X_train.iloc[:, i]
        contributions = shap_val[:, i]

        # Prevalence filter
        if (values[is_septic] != 0).sum() < min_nonzero:
            to_retain[i] = False
            continue

        # One call gives both the coefficient and the p-value.
        rho, p = spearmanr(values, contributions)

        if rho < 0 and p < 0.05:
            to_retain[i] = False
        elif not contributions.any():
            to_retain[i] = False

    to_retain = X_train.columns[to_retain]
    print(to_retain.shape, to_retain)

    return to_retain

In [None]:
# Decontam using raw_params
genera_new = X.columns

for _ in range(10):
    genera_new = decontam(X.loc[:, genera_new], y, raw_params)

### Remove non-human associated pathogens

In [None]:
# Retrieve known human pathogens
meta = pd.read_csv(datasets / 'pathogen_list.csv', encoding= 'unicode_escape')
meta = meta['Genus'].unique()

to_retain = list(set(genera_new).intersection(set(meta)))
print(to_retain)

In [None]:
# Decontam + pathogens
raw_CR = X[to_retain]

# Normalise RA
RA_CR = raw_CR.apply(func=lambda x: x / x.sum(), axis=1)

In [None]:
# Get SHAP summary before removing Cellulomonas and Agrobacterium
pre_model = XGBClassifier(**raw_params)
pre_model.fit(X=raw_CR, y=y)

pre_explainer = shap.TreeExplainer(pre_model, feature_pertubation='interventional', model_output='probability', data=raw_CR)
shap_pre = pre_explainer.shap_values(raw_CR)

shap.summary_plot(shap_pre, raw_CR, show=False, plot_size=(4, 5), color_bar_label='Unique k-mer Count', max_display=25)
fig, ax = plt.gcf(), plt.gca()
ax.set_xlabel('SHAP Value')
plt.savefig(results / 'pre_shap.png', dpi=600, format='png', bbox_inches='tight')


### Number of Features

In [None]:
# (rows, columns) of the feature space before and after decontamination.
for label, frame in (('Neat', X), ('CR', raw_CR)):
    print(label, frame.shape)

## Optimise and evaluate decontaminated models

In [None]:
raw_CR_results, raw_CR_params = optimise_evaluate(raw_CR, y)
# NOTE(review): the recorded parameters below contain values (e.g. gamma 0.0,
# n_estimators 426) that are not in the current search grid, so they presumably
# come from an older configuration.
# raw_CR_params = {'subsample': 0.7631578947368421, 'scale_pos_weight': 1.4273504273504274, 'n_estimators': 426, 'max_depth': 1, 'gamma': 0.0, 'colsample_bytree': 0.1}

RA_CR_results, RA_CR_params = optimise_evaluate(RA_CR, y)
# RA_CR_params = {'subsample': 0.38421052631578945, 'scale_pos_weight': 1.4273504273504274, 'n_estimators': 101, 'max_depth': 5, 'gamma': 2.894736842105263, 'colsample_bytree': 0.19473684210526315}

# DataFrame.append was removed in pandas 2.0, so the rows are added with
# pd.concat. Any existing 'Raw CR' / 'RA CR' rows are dropped first so that
# re-running this cell does not duplicate them.
cr_metrics = pd.DataFrame({'Raw CR': raw_CR_results, 'RA CR': RA_CR_results}).round(3).T
metric_df = pd.concat([metric_df.drop(index=cr_metrics.index, errors='ignore'), cr_metrics])
display(metric_df)

### Fit clean models

In [None]:
# Refit on the full decontaminated data using the parameters selected above.
# Raw counts:
raw_CR_model = XGBClassifier(**raw_CR_params)
raw_CR_model.fit(X=raw_CR, y=y)

# Relative abundance:
RA_CR_model = XGBClassifier(**RA_CR_params)
RA_CR_model.fit(X=RA_CR, y=y)

## Interpreting model using SHAP values

### Plot of SHAP values per Feature

In [None]:
import matplotlib.pyplot as plt
# SHAP explainers for the decontaminated and the original raw-count models.
# The keyword was previously misspelt ('feature_pertubation'), so the option
# was not actually being passed to TreeExplainer.
explainer_CR = shap.TreeExplainer(raw_CR_model, feature_perturbation='interventional', model_output='probability', data=raw_CR)
shap_CR = explainer_CR.shap_values(raw_CR)

explainer_raw = shap.TreeExplainer(raw_model, feature_perturbation='interventional', model_output='probability', data=X)
shap_raw = explainer_raw.shap_values(X)

#### Dirty raw

In [None]:
# SHAP summary for the model trained on all non-causal genera (top 23 shown).
# show=False keeps the figure open for relabelling and saving.
shap.summary_plot(
    shap_raw,
    X,
    max_display=23,
    plot_size=(4, 5),
    color_bar_label='Unique k-mer Count',
    show=False,
)
fig = plt.gcf()
ax = plt.gca()
ax.set_xlabel('SHAP Value')
plt.savefig(results / 'raw_shap.png', format='png', dpi=600, bbox_inches='tight')

In [None]:
# SHAP summary for the model trained on the decontaminated genera (top 35 shown).
# show=False keeps the figure open for relabelling and saving.
shap.summary_plot(
    shap_CR,
    raw_CR,
    max_display=35,
    plot_size=(4, 5),
    color_bar_label='Unique k-mer Count',
    show=False,
)
fig = plt.gcf()
ax = plt.gca()
ax.set_xlabel('SHAP Value')
plt.savefig(results / 'raw_CR_shap.png', format='png', dpi=600, bbox_inches='tight')

* Features are ranked by importance from top to bottom
* Feature values are the k-mer counts for each genus
* SHAP values are the average marginal contributions to probability

### Force plot for healthy patient

In [None]:
# Row position of the sample to explain.
j = 201

# y is looked up by position as well, so that it refers to the same row as
# shap_CR[j, :] and raw_CR.iloc[j, :] even if the index is not 0..n-1
# (the previous y[j] was a label lookup).
print(f'Actual Classification {y.iloc[j]}')
print(raw_CR.index[j])

shap.force_plot(explainer_CR.expected_value, 
                shap_CR[j,:], 
                raw_CR.iloc[j,:],
                show=False,
                matplotlib=True)

plt.savefig(results / 'CR_force_plot.png', dpi=600, format='png', bbox_inches='tight')

### Final Results

In [None]:
# Show the full metrics table, then write it out with both row and column labels.
display(metric_df)

metrics_path = results / 'karius_drop_causal_features_results.csv'
metric_df.to_csv(metrics_path, header=True, index=True)

In [None]:
from sklearn.metrics import roc_auc_score
from sklearn.utils import resample
n_boot = 1000


def bootstrap_feature_aurocs(features, labels, n_boot):
    """Return per-feature AUROCs over stratified bootstrap resamples.

    For each of ``n_boot`` resamples of (``features``, ``labels``), every
    column is used directly as a score and its AUROC against the resampled
    labels is computed.

    Returns a DataFrame with one row per resample and one column per feature.
    """
    rows = []
    for i in range(n_boot):
        print(f'\riteration {i}', end='')
        X_boot, y_boot = resample(features, labels, stratify=labels)
        rows.append(X_boot.apply(axis=0, func=lambda x: roc_auc_score(y_score=x, y_true=y_boot)))

    # Build the frame once: DataFrame.append was removed in pandas 2.0 and
    # growing a frame row by row copies it on every iteration.
    return pd.DataFrame(rows).reset_index(drop=True)


raw_aurocs = bootstrap_feature_aurocs(raw_CR, y, n_boot)
simple_aurocs = bootstrap_feature_aurocs(X_simple, y, n_boot)

display(raw_aurocs)
display(simple_aurocs)

raw_aurocs.to_csv(results / 'raw_CR_no_causal_AUROCS.csv', index=False, header=True)
simple_aurocs.to_csv(results / 'simple_decontam_no_causal_AUROCS.csv', index=False, header=True)