In [1]:
import pandas as pd, numpy as np, pickle
from interactiontransformer.InteractionTransformer import InteractionTransformer, run_shap, XGBoostSurvival
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from imblearn.ensemble import BalancedRandomForestClassifier
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib as mpl
import shap
import scipy
from sklearn.metrics import roc_auc_score
import warnings
# Notebook-wide settings: silence all warnings, render figures at 300 dpi,
# and use seaborn's white style with a compact (0.5x) font scale.
warnings.filterwarnings(action="ignore")
mpl.rcParams.update({'figure.dpi': 300})
sns.set(font_scale=0.5, style='white')

# Survival Model from NHANES
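This notebook derives SHAP-based interaction terms for a survival model on the NHANES I dataset. For SHAP's own survival-model walkthrough on the same data, see https://slundberg.github.io/shap/notebooks/NHANES%20I%20Survival%20Model.html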

In [3]:
X,y = shap.datasets.nhanesi()

In [4]:
X=X.iloc[:,1:]
y=pd.Series(y,index=X.index)

# Train/test split


In [5]:
# Hold out 20% of rows for evaluation; the fixed seed makes the split reproducible.
split_parts = train_test_split(X, y, random_state=42, test_size=0.2)
X_train, X_test, y_train, y_test = split_parts

# Fit Transformer

In [6]:
# XGBoost survival model used to score pairwise SHAP interactions.
survival_booster = XGBoostSurvival(eta=0.002, max_depth=3, subsample=0.5)
# Number of interaction terms to keep: int(sqrt(#features)). Here that is
# int(sqrt(18)) = 4, matching the four interaction columns in the transformed data.
n_interactions = int(np.sqrt(X_train.shape[1]))
transformer = InteractionTransformer(
    untrained_model=survival_booster,
    max_train_test_samples=100,  # appears to match the 100 rows in the printed SHAP interaction shape
    mode_interaction_extract=n_interactions,
    cv_scoring="survival",
    use_background_data=False,
    tree_limit=100,
    num_workers=8,
    compute_interaction_dask=False,
)
transformer.fit(X_train, y_train)



Shap Interaction Size: (100, 18, 18)


<interactiontransformer.InteractionTransformer.InteractionTransformer at 0x13aa438e0>

# Transform Data
The transformed terms can be passed to R's `coxph` or `survfit`. Alternatively, in Python, use the scikit-survival package: https://nbviewer.jupyter.org/github/sebp/scikit-survival/blob/master/examples/00-introduction.ipynb

In [7]:
# Apply the fitted transformer to each split. It appends interaction columns
# such as `Age:Sex`, which the preview below shows are element-wise products
# of the paired features (e.g. Age 71 x Systolic BP 156 = 11076).
X_train2 = transformer.transform(X_train)
X_test2 = transformer.transform(X_test)

In [8]:
# Pairwise SHAP interaction scores computed during fit. The top-scoring pairs
# (Age:Sex, Age:Systolic BP, Age:Diastolic BP, Age:Sedimentation rate) are the
# ones added as interaction columns by transform.
interaction_scores = transformer.all_interaction_shap_scores
interaction_scores

Unnamed: 0,feature_1,feature_2,shap_interaction_score
10,Age,Sex,0.014132
11,Age,Systolic BP,0.002474
0,Age,Diastolic BP,0.002468
4,Age,Sedimentation rate,0.002197
8,Age,Serum Magnesium,0.001609
...,...,...,...
68,Red blood cells,Sex,0.000000
67,Red blood cells,Serum Protein,0.000000
66,Red blood cells,Serum Magnesium,0.000000
65,Red blood cells,Serum Iron,0.000000


In [9]:
# Preview the first five rows: original features followed by interaction columns.
X_train2.iloc[:5]

Unnamed: 0,Age,Diastolic BP,Poverty index,Race,Red blood cells,Sedimentation rate,Serum Albumin,Serum Cholesterol,Serum Iron,Serum Magnesium,...,Systolic BP,TIBC,TS,White blood cells,BMI,Pulse pressure,Age:Sex,Age:Systolic BP,Age:Diastolic BP,Age:Sedimentation rate
1,71.0,78.0,210.0,2.0,77.7,37.0,4.0,298.0,89.0,1.38,...,156.0,331.0,26.9,5.3,32.362572,78.0,142.0,11076.0,5538.0,2627.0
2,74.0,86.0,999.0,2.0,77.7,31.0,3.8,222.0,115.0,1.37,...,170.0,299.0,38.5,8.1,25.388497,84.0,148.0,12580.0,6364.0,2294.0
4,32.0,70.0,183.0,2.0,77.7,18.0,5.0,203.0,192.0,1.35,...,128.0,386.0,49.7,8.1,20.354684,58.0,32.0,4096.0,2240.0,576.0
5,40.0,78.0,297.0,2.0,77.7,24.0,4.0,173.0,121.0,1.71,...,118.0,370.0,32.7,10.7,27.217201,40.0,80.0,4720.0,3120.0,960.0
6,53.0,76.0,461.0,1.0,77.7,2.0,4.3,276.0,135.0,1.74,...,124.0,334.0,40.4,6.0,23.091823,48.0,53.0,6572.0,4028.0,106.0
