In [1]:
import os
import sys
# Put the parent directory on the import path so the `helpers` import below
# resolves (presumably helpers.py sits one level above this notebook — confirm).
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from helpers import (get_training_observations, 
                     get_training_labels, 
                     get_protein_proportions)

import shap
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gc

### Decision Tree

In [6]:
# Load the fitted decision tree model and the training observations that the
# SHAP computations below (decision tree and XGBoost) are run against.
# NOTE(review): joblib.load unpickles the file — only load artefacts from a
# trusted source.
from joblib import load
model_dt = load('../model_joblibs/final_decision_tree.joblib')
# Helper imported from the local `helpers` module; per the cell output it reads
# from 'metagenome_classification.db'.
x_train = get_training_observations()
# Trailing expression: the cell displays the value returned by gc.collect().
gc.collect()

Getting all training observations from 'metagenome_classification.db'...


0

In [38]:
# Show the class labels stored on the decision tree. The per-class export loops
# below index this array by position to name each class's CSV and PNG files.
model_dt.classes_

array(['Aerosol (non-saline)', 'Animal corpus', 'Animal proximal gut',
       'Hypersaline (saline)', 'Plant corpus', 'Plant rhizosphere',
       'Plant surface', 'Sediment (non-saline)', 'Sediment (saline)',
       'Soil (non-saline)', 'Subsurface (non-saline)',
       'Surface (non-saline)', 'Surface (saline)', 'Water (non-saline)',
       'Water (saline)'], dtype=object)

In [15]:
# Build a tree explainer for the decision tree and compute SHAP values for the
# training observations.
# The per-class loop below indexes `shap_values` by class position, so it is
# expected to hold one entry per class — TODO confirm for the installed shap
# version.
explainer = shap.TreeExplainer(model_dt)
shap_values = explainer.shap_values(x_train)
# Trailing expression: the cell displays the value returned by gc.collect().
gc.collect()

22481

In [18]:
# Overall SHAP summary plot for the decision tree, written to disk rather than
# shown inline (show=False leaves the figure open so it can be saved).
dt_images_dir = 'shap_decision_tree/shap_images'
# savefig does not create missing folders; make sure the target exists so the
# cell also works on a fresh checkout.
os.makedirs(dt_images_dir, exist_ok=True)

shap.summary_plot(shap_values, x_train, class_names=model_dt.classes_, show=False)
plt.legend(loc="lower right")
plt.savefig(f'{dt_images_dir}/overall.png', bbox_inches="tight")
# Release the figure so it is neither rendered inline nor kept in memory.
plt.close('all')

In [20]:
# Per-class SHAP export for the decision tree: for every class, write a CSV of
# features ranked by mean |SHAP| and save that class's summary plot.
dt_data_dir = 'shap_decision_tree/shap_data'
dt_images_dir = 'shap_decision_tree/shap_images'
# to_csv / savefig do not create missing folders, so create them up front.
for out_dir in (dt_data_dir, dt_images_dir):
    os.makedirs(out_dir, exist_ok=True)

# Assumes `shap_values` holds one array per class, in the same order as
# model_dt.classes_ — TODO confirm for the installed shap version.
for i, class_shap in enumerate(shap_values):
    cat = model_dt.classes_[i]
    # Mean absolute SHAP value per feature (average taken over axis 0).
    vals = np.abs(class_shap).mean(0)
    # Built from (column, value) pairs exactly as before so the CSV content and
    # number formatting stay unchanged; sorted without in-place mutation.
    feature_importance = pd.DataFrame(
        list(zip(x_train.columns, vals)),
        columns=['pfam', 'feature_importance_vals'],
    ).sort_values(by=['feature_importance_vals'], ascending=False)
    feature_importance.to_csv(f'{dt_data_dir}/{cat}.csv', index=False)
    shap.summary_plot(class_shap, x_train, class_names=model_dt.classes_, show=False)
    plt.savefig(f'{dt_images_dir}/{cat}.png', bbox_inches="tight")
    # Close the figure and collect garbage each iteration to keep memory flat.
    plt.close('all')
    gc.collect()

### XGBoost

In [7]:
# Load XGBoost model
# NOTE(review): joblib.load unpickles the file — only load artefacts from a
# trusted source. The import repeats the one in the decision-tree cell; consider
# moving it to the imports cell at the top of the notebook.
from joblib import load
model_xgb = load('../model_joblibs/xgb_delaney.joblib')
# Trailing expression: the cell displays the value returned by gc.collect().
gc.collect()

16

In [21]:
# Compute SHAP values for the XGBoost model on the same training observations.
# This rebinds `explainer` and `shap_values`, replacing the decision-tree
# results, so the decision-tree plotting cells must not be re-run after this
# cell without recomputing.
explainer = shap.TreeExplainer(model_xgb)
shap_values = explainer.shap_values(x_train)
# Trailing expression: the cell displays the value returned by gc.collect().
gc.collect()

pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.


174511

In [14]:
# Overall SHAP summary plot for the XGBoost model, written to disk rather than
# shown inline (show=False leaves the figure open so it can be saved).
xgb_images_dir = 'shap_xgboost/shap_images'
# savefig does not create missing folders; make sure the target exists so the
# cell also works on a fresh checkout.
os.makedirs(xgb_images_dir, exist_ok=True)

# NOTE(review): class names are taken from model_dt, not model_xgb — presumably
# both models were trained with the same label order; verify.
shap.summary_plot(shap_values, x_train, class_names=model_dt.classes_, show=False)
plt.legend(loc="lower right")
plt.savefig(f'{xgb_images_dir}/overall.png', bbox_inches="tight")
# Release the figure so it is neither rendered inline nor kept in memory.
plt.close('all')
# Trailing expression: the cell displays the value returned by gc.collect().
gc.collect()

344

In [23]:
# Per-class SHAP export for the XGBoost model: for every class, write a CSV of
# features ranked by mean |SHAP| and save that class's summary plot.
xgb_data_dir = 'shap_xgboost/shap_data'
xgb_images_dir = 'shap_xgboost/shap_images'
# to_csv / savefig do not create missing folders, so create them up front.
for out_dir in (xgb_data_dir, xgb_images_dir):
    os.makedirs(out_dir, exist_ok=True)

# Assumes `shap_values` holds one array per class. Class names come from
# model_dt, not model_xgb — presumably both models share the same label order;
# TODO confirm.
for i, class_shap in enumerate(shap_values):
    cat = model_dt.classes_[i]
    # Mean absolute SHAP value per feature (average taken over axis 0).
    vals = np.abs(class_shap).mean(0)
    # Built from (column, value) pairs exactly as before so the CSV content and
    # number formatting stay unchanged; sorted without in-place mutation.
    feature_importance = pd.DataFrame(
        list(zip(x_train.columns, vals)),
        columns=['pfam', 'feature_importance_vals'],
    ).sort_values(by=['feature_importance_vals'], ascending=False)
    feature_importance.to_csv(f'{xgb_data_dir}/{cat}.csv', index=False)
    shap.summary_plot(class_shap, x_train, class_names=model_dt.classes_, show=False)
    plt.savefig(f'{xgb_images_dir}/{cat}.png', bbox_inches="tight")
    # Close the figure and collect garbage each iteration to keep memory flat.
    plt.close('all')
    gc.collect()