# Metagenomic-based Diagnostic for Sepsis

In [1]:
# Import Statements
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Input and output folders are resolved relative to the current working directory.
cwd = Path.cwd()
datasets, results = (cwd / folder_name for folder_name in ('datasets', 'results'))

## Data Preprocessing
Since we are using stratified k-fold cross-validation, a separate validation split is not necessary.

### Load data

In [2]:
# Read the two genus-level tables from the datasets folder.
raw_df, pathogens_df = (
    pd.read_csv(datasets / csv_name)
    for csv_name in ('karius_genus_raw.csv', 'karius_genus_pathogens.csv')
)
display(raw_df)
display(pathogens_df)

# Columns are selected by position: column 1 is taken as the label and
# everything from column 2 onwards as the feature matrix.
y = raw_df.iloc[:, 1].copy()
X = raw_df.iloc[:, 2:].copy()
X_pathogens = pathogens_df.iloc[:, 2:].copy()

Unnamed: 0,pathogen,y,Bradyrhizobium,Rhodopseudomonas,Bosea,Afipia,Oligotropha,Variibacter,Methylobacterium,Methylorubrum,...,Rubinisphaera,Dictyoglomus,Pakpunavirus,Marinilactibacillus,Paludibacter,Nonagvirus,Halovivax,Phifelvirus,Planktothrix,Denitrobacterium
0,none,healthy,13825.0,108.0,130.0,0.0,12.0,0.0,2883.0,82.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,none,healthy,20476.0,234.0,60.0,53.0,30.0,0.0,8183.0,480.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,none,healthy,9677.0,72.0,17.0,8.0,33.0,0.0,1944.0,436.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,none,healthy,15211.0,158.0,56.0,68.0,11.0,6.0,7081.0,2041.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,none,healthy,56586.0,294.0,181.0,76.0,29.0,12.0,13850.0,1856.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
279,Escherichia coli,septic,9392.0,54.0,23.0,0.0,0.0,19.0,4187.0,199.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
280,Cryptococcus neoformans,septic,3466.0,40.0,80.0,0.0,14.0,11.0,481.0,245.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
281,Streptococcus oralis,septic,17287.0,142.0,25.0,59.0,0.0,0.0,408.0,258.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
282,Escherichia coli,septic,4006.0,39.0,0.0,16.0,0.0,0.0,4456.0,111.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,pathogen,y,Saccharopolyspora,Lactococcus,Kocuria,Brevundimonas,Acidaminococcus,Propionibacterium,Oligella,Streptococcus,...,Brevibacterium,Alistipes,Buttiauxella,Bosea,Escherichia,Selenomonas,Lymphocryptovirus,Raoultella,Atopobium,Stenotrophomonas
0,none,healthy,0.0,133.0,30.0,50.0,0.0,0.0,0.0,227.0,...,0.0,0.0,0.0,130.0,0.0,0.0,0.0,0.0,0.0,353.0
1,none,healthy,0.0,78.0,14.0,27.0,0.0,0.0,0.0,147.0,...,18.0,15.0,0.0,60.0,11.0,0.0,0.0,0.0,0.0,987.0
2,none,healthy,0.0,0.0,1.0,15.0,0.0,0.0,0.0,115.0,...,0.0,0.0,0.0,17.0,0.0,0.0,0.0,0.0,0.0,630.0
3,none,healthy,0.0,0.0,9.0,39.0,0.0,0.0,0.0,337.0,...,34.0,0.0,0.0,56.0,0.0,0.0,0.0,0.0,0.0,916.0
4,none,healthy,0.0,45.0,125.0,50.0,0.0,6.0,0.0,367.0,...,3.0,48.0,0.0,181.0,87.0,0.0,187.0,0.0,0.0,1416.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
279,Escherichia coli,septic,0.0,84.0,57.0,48.0,0.0,0.0,0.0,737.0,...,0.0,190.0,0.0,23.0,2228.0,0.0,0.0,0.0,14.0,2115.0
280,Cryptococcus neoformans,septic,0.0,20.0,0.0,40.0,0.0,0.0,0.0,614.0,...,0.0,0.0,0.0,80.0,2.0,27.0,0.0,0.0,0.0,1131.0
281,Streptococcus oralis,septic,0.0,136.0,59.0,127.0,0.0,0.0,0.0,26874.0,...,0.0,0.0,0.0,25.0,0.0,0.0,0.0,0.0,0.0,651.0
282,Escherichia coli,septic,2.0,0.0,0.0,2.0,0.0,0.0,0.0,94.0,...,0.0,0.0,0.0,0.0,293.0,0.0,0.0,0.0,0.0,1010.0


In [3]:
# Binary encode y: 'septic' -> 1, 'healthy' -> 0, then cast to int
for class_name, class_code in (('septic', 1), ('healthy', 0)):
    y.loc[y == class_name] = class_code
y = y.astype('int')


# Relative abundance: divide every row by its own total
def _row_to_fraction(row):
    """Scale one sample (row) so that its values sum to 1."""
    return row / row.sum()


RA = X.apply(func=_row_to_fraction, axis=1)

### Train Test Split

In [None]:
from sklearn.model_selection import train_test_split

# Test data: split the row labels (not the frames) so the same partition
# can be reused for every feature table derived from X.
split_options = dict(test_size=0.3, shuffle=True, random_state=66, stratify=y)
X_idx, X_test_idx, y_train, y_test = train_test_split(X.index, y, **split_options)


def get_datasets(df, train_idx, test_idx):
    """Select the train and test rows of ``df`` by index label.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to partition.
    train_idx, test_idx : index labels
        Row labels passed to ``df.loc`` for each partition.

    Returns
    -------
    tuple
        ``(train, test)`` frames containing all columns of ``df``.
    """
    return tuple(df.loc[row_labels, :] for row_labels in (train_idx, test_idx))

# Apply the same row partition to the raw counts and to the relative abundances.
(raw_train, raw_test), (RA_train, RA_test) = (
    get_datasets(feature_table, X_idx, X_test_idx) for feature_table in (X, RA)
)

Here I print the number of examples for each split of data.

In [None]:
def get_metadata(y, df):
    """Return ``df`` with one extra row holding the class counts of ``y``.

    Parameters
    ----------
    y : pandas.Series
        Binary labels; entries equal to 1 are counted as septic and entries
        equal to 0 as healthy.
    df : pandas.DataFrame
        Table the new row is added to. It is not modified in place.

    Returns
    -------
    pandas.DataFrame
        ``df`` followed by a single row with columns 'Septic' and 'Healthy'.
        The existing index is kept as-is (the new row carries index 0).
    """
    pos = len(y[y == 1])
    neg = len(y[y == 0])
    row = pd.DataFrame({'Septic': [pos], 'Healthy': [neg]})
    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported
    # way to stack the new row under the existing ones.
    return pd.concat([df, row])


# Collect the class counts of each split, one row per split.
metadata = pd.DataFrame(columns=['Septic', 'Healthy'])
for split_labels in (y_train, y_test):
    metadata = get_metadata(split_labels, metadata)
metadata.index = ['train', 'test']
display(metadata)

# Get negative to positive ratio (computed over all labels in y)
n_negative = sum(y == 0)
n_positive = sum(y == 1)
ratio = n_negative / n_positive

## Optimising model hyperparameters
We use grid search with stratified 5-fold cross-validation to optimise the model hyperparameters (n_estimators, max_depth, gamma, subsample, colsample_bytree).
* n_estimators: number of boosted trees
* max_depth: maximum depth of each tree

In [None]:
def optimise(X, y):
    """Grid-search XGBoost hyperparameters and return the best combination.

    Candidates are scored by AUROC using stratified 5-fold cross-validation
    (shuffled, random_state=66). The best score and parameters are printed.

    Parameters
    ----------
    X : array-like
        Training features.
    y : array-like
        Training labels.

    Returns
    -------
    dict
        ``best_params_`` of the fitted ``GridSearchCV``.

    Notes
    -----
    ``scale_pos_weight`` is fixed to the notebook-level variable ``ratio``,
    which must be defined before this function is called.
    The grid is exhaustive (25 x 4 x 10 x 5 x 5 candidates, each fitted on
    5 folds), so a call is expensive. Per-candidate means and standard
    deviations remain available from ``cv_results_`` of the grid search.
    """
    # Hyperparameter optimisation using grid search (scored by AUROC)
    param_grid = dict(max_depth=range(1, 5, 1),
                      n_estimators=range(50, 300, 10),
                      colsample_bytree=np.linspace(0.1, 1, 10),
                      gamma=[0.5, 1, 1.5, 2, 5],
                      subsample=[0.6, 0.7, 0.8, 0.9, 1.0],
                      scale_pos_weight=[ratio])
    kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=66)
    grid_search = GridSearchCV(XGBClassifier(), param_grid, scoring="roc_auc", n_jobs=-1, cv=kfold, verbose=1)
    grid_result = grid_search.fit(X, y)

    # Summarize results
    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

    return grid_result.best_params_

### Optimise Model using Neat Data

In [None]:
# Hard-coded parameter sets. The commented-out calls show the grid search
# that can be re-run on the corresponding training table instead.

# raw_params = optimise(raw_train, y_train)
raw_params = dict(colsample_bytree=0.1,
                  gamma=0.5,
                  max_depth=2,
                  n_estimators=140,
                  scale_pos_weight=1.4273504273504274,
                  subsample=0.6)

# RA_params = optimise(RA_train, y_train)
RA_params = dict(colsample_bytree=0.1,
                 gamma=1,
                 max_depth=1,
                 n_estimators=290,
                 scale_pos_weight=1.4273504273504274,
                 subsample=0.6)

### Model Training

In [None]:
# One classifier per feature representation, each built from its own parameter set.
raw_model, RA_model = (XGBClassifier(**param_set) for param_set in (raw_params, RA_params))

# Fit on the training rows of the matching feature table.
raw_model.fit(X=raw_train, y=y_train)
RA_model.fit(X=RA_train, y=y_train)

### Evaluate Neat Data

In [None]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from imblearn.metrics import sensitivity_specificity_support

def evaluate(model, X_Test, y_Test):
    """Score a fitted classifier on a labelled dataset.

    Parameters
    ----------
    model : fitted classifier
        Must provide ``predict`` and ``predict_proba``.
    X_Test : array-like
        Features to predict on.
    y_Test : array-like
        True binary labels.

    Returns
    -------
    pandas.DataFrame
        One row with columns 'F1', 'Sensitivity', 'Specificity' and 'AUROC'.
        The first three are computed from the hard predictions; AUROC is
        computed from the second column of ``predict_proba``.
    """
    predicted_labels = model.predict(X_Test)
    positive_scores = model.predict_proba(X_Test)[:, 1]

    # Only the F-score (third element) of the returned tuple is reported.
    f1 = precision_recall_fscore_support(y_true=y_Test, y_pred=predicted_labels, average='binary')[2]
    sensitivity, specificity = sensitivity_specificity_support(
        y_true=y_Test, y_pred=predicted_labels, average='binary')[:2]
    auc = roc_auc_score(y_true=y_Test, y_score=positive_scores)

    scores = {'F1': f1, 'Sensitivity': sensitivity, 'Specificity': specificity, 'AUROC': auc}
    return pd.DataFrame({metric: [value] for metric, value in scores.items()})


# Evaluate on test data: each model is scored on its own feature representation.
raw_metric, RA_metric = (
    evaluate(fitted_model, test_table, y_test)
    for fitted_model, test_table in ((raw_model, raw_test), (RA_model, RA_test))
)

### Remove Contaminants based on SHAP values

In [None]:
import math
from scipy.stats import spearmanr
import shap


def decontam(X_train, y_train, params):
    """Run one round of SHAP-based feature filtering.

    An XGBoost classifier is fitted on ``X_train``, SHAP values (probability
    scale, interventional, with ``X_train`` as background data) are computed
    for the same rows, and each feature's values are rank-correlated with its
    SHAP values using Spearman's rho.

    Parameters
    ----------
    X_train : pandas.DataFrame
        Training features; columns are the candidate features.
    y_train : array-like
        Training labels.
    params : dict
        Keyword arguments for ``XGBClassifier``.

    Returns
    -------
    pandas.Index
        Names of the columns whose rho is strictly positive. Features with a
        non-positive or undefined (NaN) rho are dropped. The retained names
        are also printed together with their count.

    Notes
    -----
    NOTE(review): an earlier version of this function also built a mask from
    ``rho < 0 and p < 0.05``, but that mask was overwritten before use, so the
    p-value never influenced the result. The effective rule (rho > 0) is kept
    here unchanged; confirm whether the significance filter was intended.
    """
    model = XGBClassifier(**params)
    model.fit(X=X_train, y=y_train)

    # The keyword is spelled `feature_perturbation`; the previous misspelling
    # meant the option was never actually passed under its real name.
    explainer = shap.TreeExplainer(model, feature_perturbation='interventional', model_output='probability', data=X_train)
    shap_val = explainer.shap_values(X_train)

    n_features = X_train.shape[1]
    keep_mask = np.zeros(n_features, dtype=bool)

    for i in range(n_features):
        # A single call per feature; only the correlation coefficient is used.
        rho = spearmanr(X_train.iloc[:, i], shap_val[:, i])[0]
        # rho is NaN when the correlation is undefined; such features are dropped.
        keep_mask[i] = (not math.isnan(rho)) and rho > 0

    to_retain = X_train.columns[keep_mask]
    print(to_retain.shape, to_retain)
    return to_retain

In [None]:
# Apply the filter repeatedly, each round starting from the previous survivors.
n_decontam_rounds = 10
genera_new = raw_train.columns
for _round in range(n_decontam_rounds):
    surviving_features = raw_train[genera_new]
    genera_new = decontam(surviving_features, y_train, raw_params)

In [None]:
# Keep only the surviving genera that are also columns of the pathogen table.
# The order of `genera_new` is preserved: building the list from a set
# intersection gives an order that can change between kernel sessions, and the
# column order of the resulting table is what the downstream models are fitted on.
pathogen_genera = set(X_pathogens.columns)
to_retain = [genus for genus in genera_new if genus in pathogen_genera]
print(to_retain)

In [None]:
# Decontam + pathogens: restrict the pathogen table to the retained genera
raw_SS = X_pathogens.loc[:, to_retain]


In [None]:
# Get SHAP summary before removing Cellulomonas and Agrobacterium.
# This cell must run before the drop below, because it reads raw_SS as it is now.
pre_SS_train, pre_SS_test = get_datasets(raw_SS, X_idx, X_test_idx)
pre_model = XGBClassifier(**raw_params)
pre_model.fit(X=pre_SS_train, y=y_train)

# `feature_perturbation` is the correct spelling of the keyword.
pre_explainer = shap.TreeExplainer(pre_model, feature_perturbation='interventional', model_output='probability', data=pre_SS_train)
shap_pre = pre_explainer.shap_values(pre_SS_train)

# show=False keeps the figure open so the axis label can be set before saving.
shap.summary_plot(shap_pre, pre_SS_train, show=False, plot_size=(4, 5), color_bar_label='Unique k-mer Count', max_display=25)
fig, ax = plt.gcf(), plt.gca()
ax.set_xlabel('SHAP Value')
plt.savefig(results / 'pre_shap.png', dpi=1200, format='png', bbox_inches='tight')


### Drop features

In [None]:
# Remove the two named genera from the feature table.
# Re-running this cell on its own fails, because the columns are then already gone.
genera_to_drop = ['Cellulomonas', 'Agrobacterium']
raw_SS = raw_SS.drop(columns=genera_to_drop)
# The count below compares genera_new with to_retain; it does not include the two genera dropped above.
print(f'Removed {len(genera_new) - len(to_retain)} genera')
# print(f"Genera from list not removed:\n {set(meta) - set(to_remove)}")
print(f'New shape  = {raw_SS.shape}')

In [None]:
# Normalise Datasets: divide every row of raw_SS by its own total
def _ss_row_to_fraction(row):
    """Scale one sample (row) so that its values sum to 1."""
    return row / row.sum()


RA_SS = raw_SS.apply(func=_ss_row_to_fraction, axis=1)

# Get train test split (same row partition as for the unfiltered tables)
(raw_SS_train, raw_SS_test), (RA_SS_train, RA_SS_test) = (
    get_datasets(feature_table, X_idx, X_test_idx) for feature_table in (raw_SS, RA_SS)
)

### Number of Features

In [None]:
# Compare table shapes before and after feature filtering.
for table_label, table in (('Neat', X), ('Pathogens', raw_SS)):
    print(table_label, table.shape)

### Optimise SS Models

#### Pathogens

In [None]:
# Hard-coded parameter set; the commented-out call is the grid search that can be re-run instead.
# raw_SS_params = optimise(raw_SS_train, y_train)
raw_SS_params = dict(colsample_bytree=0.1,
                     gamma=1,
                     max_depth=3,
                     n_estimators=270,
                     scale_pos_weight=1.4273504273504274,
                     subsample=1.0)

In [None]:
# Hard-coded parameter set; the commented-out call is the grid search that can be re-run instead.
# RA_SS_params = optimise(RA_SS_train, y_train)
RA_SS_params = dict(colsample_bytree=0.2,
                    gamma=1,
                    max_depth=4,
                    n_estimators=70,
                    scale_pos_weight=1.4273504273504274,
                    subsample=0.6)

## Fit optimised models

In [None]:
# Fit optimised model on all training data

# Decontam
# One classifier per filtered feature representation, each with its own parameter set.
raw_SS_model, RA_SS_model = (XGBClassifier(**param_set) for param_set in (raw_SS_params, RA_SS_params))

# Fit on the training rows of the matching filtered table.
raw_SS_model.fit(X=raw_SS_train, y=y_train)
RA_SS_model.fit(X=RA_SS_train, y=y_train)

## Evaluate model

In [None]:
# Score the models fitted on the filtered (SS) tables on the test split.
# (The two calls used to be repeated verbatim; once is sufficient.)
raw_SS_metric = evaluate(raw_SS_model, raw_SS_test, y_test)
RA_SS_metric = evaluate(RA_SS_model, RA_SS_test, y_test)

# Stack the one-row metric frames and label each row with its model.
metric_df = pd.concat([raw_metric,
                       RA_metric,
                       raw_SS_metric,
                       RA_SS_metric], axis=0)
metric_df.index = ['Raw', 'RA', 'Raw SS', 'RA SS']
display(metric_df.round(3))

### Confidence Intervals (non-parametric bootstrap estimates)

Bootstrap with 1001 iterations, 95% CI

In [None]:
def get_percentiles(x, alpha=0.05):
    """Return the lower and upper percentile bounds of a two-sided interval.

    Parameters
    ----------
    x : array-like
        Sample values.
    alpha : float, default 0.05
        Total tail mass; half of it is cut from each end.

    Returns
    -------
    tuple
        ``(low, high)``: the ``alpha/2`` and ``1 - alpha/2`` percentiles of ``x``.
    """
    lower_pct = alpha / 2 * 100
    upper_pct = (1 - alpha / 2) * 100
    return tuple(np.percentile(x, pct) for pct in (lower_pct, upper_pct))


# Seed NumPy's global random state. The bootstrap below calls `resample`
# without a `random_state`, so it presumably draws from this global state;
# re-run this line before re-running the bootstrap to get the same intervals.
np.random.seed(66)
from sklearn.utils import resample


def get_confint(model, X_test, y_test, n_iter=1001):
    """Display bootstrap percentile confidence intervals for the test metrics.

    For each iteration the test set is resampled with replacement (stratified
    by label, same size as the original), the fitted ``model`` predicts on the
    resample, and F1, sensitivity, specificity and AUROC are recorded. The
    bounds returned by ``get_percentiles`` for each metric are then displayed
    as a table with rows '2.5%' and '97.5%'.

    Parameters
    ----------
    model : fitted classifier
        Must provide ``predict`` and ``predict_proba``.
    X_test : array-like
        Test features.
    y_test : array-like
        Test labels.
    n_iter : int, default 1001
        Number of bootstrap resamples.

    Returns
    -------
    None
        The interval table is shown with ``display``; nothing is returned.

    Notes
    -----
    ``resample`` is called without a ``random_state``, so reproducibility
    depends on the random state set outside this function.
    """
    metric_names = ['F1', 'Sensitivity', 'Specificity', 'AUROC']
    # Collect one record per resample and build the frame once at the end,
    # instead of re-concatenating a growing DataFrame on every iteration.
    records = []

    for _iteration in range(n_iter):
        boot_X, boot_y = resample(X_test, y_test, n_samples=len(y_test), replace=True, stratify=y_test)
        y_pred = model.predict(boot_X)
        y_score = model.predict_proba(boot_X)[:, 1]

        sensitivity, specificity, _ = sensitivity_specificity_support(y_true=boot_y, y_pred=y_pred, average='binary')
        precision, recall, f1, _ = precision_recall_fscore_support(y_true=boot_y, y_pred=y_pred, average='binary')
        auc = roc_auc_score(y_true=boot_y, y_score=y_score)
        records.append({'F1': f1, 'Sensitivity': sensitivity,
                        'Specificity': specificity, 'AUROC': auc})

    boot_df = pd.DataFrame(records, columns=metric_names)

    confints = [get_percentiles(boot_df[col]) for col in boot_df.columns]
    display(pd.DataFrame(confints,
                         columns=['2.5%', '97.5%'],
                         index=boot_df.columns).transpose().round(3))
    

In [None]:
# Print a label followed by the interval table for each model, in a fixed order.
confint_jobs = (
    ('Raw:', raw_model, raw_test),
    ('RA:', RA_model, RA_test),
    ('Raw SS:', raw_SS_model, raw_SS_test),
    ('RA SS:', RA_SS_model, RA_SS_test),
)
for job_label, fitted_model, test_table in confint_jobs:
    print(job_label, end='')
    get_confint(fitted_model, test_table, y_test)

## Interpreting model using SHAP values

### Feature importance
This is a plot of mean absolute SHAP values per feature

### Plot of SHAP values per Feature

In [None]:
import matplotlib.pyplot as plt
# Explain the filtered-feature model and the unfiltered model on the test rows,
# using each model's training table as background data.
# `feature_perturbation` is the correct spelling of the keyword.
explainer_SS = shap.TreeExplainer(raw_SS_model, feature_perturbation='interventional', model_output='probability', data=raw_SS_train)
shap_SS = explainer_SS.shap_values(raw_SS_test)

explainer_raw = shap.TreeExplainer(raw_model, feature_perturbation='interventional', model_output='probability', data=raw_train)
shap_raw = explainer_raw.shap_values(raw_test)

In [None]:
# SHAP summary for the filtered-feature model on the test rows.
# show=False keeps the figure open so the axis label can be set before saving.
ss_summary_options = dict(show=False, plot_size=(4, 5), color_bar_label='Unique k-mer Count', max_display=35)
shap.summary_plot(shap_SS, raw_SS_test, **ss_summary_options)
fig = plt.gcf()
ax = plt.gca()
ax.set_xlabel('SHAP Value')
plt.savefig(results / 'SS_shap.png', dpi=1200, format='png', bbox_inches='tight')

In [None]:
# SHAP summary for the unfiltered model on the test rows.
# show=False keeps the figure open so the axis label can be set before saving.
raw_summary_options = dict(show=False, plot_size=(4, 5), color_bar_label='Unique k-mer Count', max_display=23)
shap.summary_plot(shap_raw, raw_test, **raw_summary_options)
fig = plt.gcf()
ax = plt.gca()
ax.set_xlabel('SHAP Value')
plt.savefig(results / 'raw_shap.png', dpi=1200, format='png', bbox_inches='tight')

* Features are ranked by importance from top to bottom
* feature values are the kmer counts for each genus
* SHAP values are the average marginal contributions to probability

### Force plot for healthy patient

In [None]:
# j is a row *position* within the test split: it indexes shap_SS and
# raw_SS_test positionally below.
j = 72
# y_test keeps the original (shuffled) row labels, so y_test[j] would look up
# the row *labelled* 72 rather than the j-th test row. Use .iloc so the label
# printed here belongs to the same sample as the plotted SHAP values.
print(f'Actual Classification {y_test.iloc[j]}')
print(raw_SS_test.index[j])
shap.force_plot(explainer_SS.expected_value, 
                shap_SS[j,:], 
                raw_SS_test.iloc[j,:],
                show=False,
                matplotlib=True)
plt.savefig(results / 'SS_force_plot.png', dpi=1200, format='png', bbox_inches='tight')

## How much does Escherichia drive predictions?

In [None]:
# Column position of 'Escherichia' in the filtered test table; used below to
# pick the matching column of shap_SS.
escherichia_idx = raw_SS_test.columns.get_loc('Escherichia')

In [None]:
# Compare AUROC of the predicted probabilities with AUROC after subtracting
# the Escherichia column of the SHAP values from those probabilities.
y_score = raw_SS_model.predict_proba(raw_SS_test)[:, 1]
escherichia_contribution = shap_SS[:, escherichia_idx]
old_auc, new_auc = (
    roc_auc_score(y_true=y_test, y_score=candidate_score)
    for candidate_score in (y_score, y_score - escherichia_contribution)
)
print(f"Before Removing Escherichia = {old_auc}\nAfter Removing Escherichia = {new_auc}")