# Code

**Team:** Kamil Grudzień, Krystian Sztenderski, Jakub Bednarz.

**Project:** (V) Time-dependent explanations of neural networks for survival analysis.

## Introduction

**What was done?** We've:

1. Learned how to use the `pycox` package, in particular the DeepHit survival analysis model.
2. Learned how to use the `sksurv` package.
3. Wrapped various models (Cox proportional hazards, Random Survival Forest and DeepHit) into a uniform interface to make training and evaluation easier.
4. Read the [SurvSHAP(t) paper](https://arxiv.org/abs/2208.11080) and the [implementation provided by the authors](https://github.com/MI2DataLab/survshap).
5. Learned how to use various NN-specific explainability techniques from the `captum` library - in particular, we've adapted the `DeepLift`, `DeepLiftShap` and `IntegratedGradients` methods to provide explanations for DeepHit analogous to the ones given by SurvSHAP(t).
6. Wrapped all of them into a single interface to compare them "on equal ground."
7. Replicated an experiment from the SurvSHAP(t) paper to verify we're using the library correctly. Beyond that, we've also trained and evaluated DeepHit on the same dataset, obtained ground-truth explanations with SurvSHAP(t) and run DeepLift, DeepLiftShap and Integrated Gradients to see how NN-specific explanations compare with SurvSHAP(t).
8. Performed a preliminary experiment on a real-world dataset (METABRIC) in a similar fashion to the one described in (7).
9. Made a coarse-grained analysis of the results of the experiments in (7) and (8), i.e. we've plotted the SHAP values at given time points and evaluated them qualitatively.

**What are the difficulties?**

1. The NN-specific explanations do not *seem* to correlate at all with the ground truth, so further analysis is needed.
2. Although we evaluate the models quantitatively (via the concordance index), we still don't know whether the models we've trained for these datasets give "reasonable results". If a model does not perform well, its explanations are meaningless, so it would be wise to eliminate that source of uncertainty.

**What will be done next?**

1. Adding quantitative metrics for comparing the explanations given by SurvSHAP(t) and other methods.
2. Performing deeper analysis of the trained models and the explanations.
3. (Possibly) Testing NNs for survival analysis other than DeepHit.
4. Adding measurement of execution time.
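
As a starting point for (1) and (4), here is a minimal sketch of how two time-dependent explanations could be compared numerically and how a call could be timed. It is not part of the notebook's pipeline: the helpers `step_eval`, `pearson` and `attribution_agreement` are hypothetical names, the explanations are assumed to be plain `{variable: (times, values)}` step functions, and the toy data is made up.

```python
import time
import numpy as np


def step_eval(x, y, t):
    """Evaluate a right-continuous step function with knots x and values y at times t."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    idx = np.searchsorted(x, t, side="right") - 1
    return y[np.clip(idx, 0, len(x) - 1)]


def pearson(a, b):
    """Pearson correlation; returns nan if either input is constant."""
    a, b = a - a.mean(), b - b.mean()
    denom = np.sqrt((a ** 2).sum() * (b ** 2).sum())
    return float((a * b).sum() / denom) if denom > 0 else float("nan")


def attribution_agreement(expl_a, expl_b, times):
    """Compare two explanations given as {variable: (x, y)} step functions.

    Returns (per-variable correlation over time,
             rank correlation of mean |attribution| across variables)."""
    times = np.asarray(times, dtype=float)
    per_var, imp_a, imp_b = {}, [], []
    for var in expl_a:
        a = step_eval(*expl_a[var], times)
        b = step_eval(*expl_b[var], times)
        per_var[var] = pearson(a, b)
        imp_a.append(np.abs(a).mean())
        imp_b.append(np.abs(b).mean())
    # Spearman correlation = Pearson correlation of ranks (ties are not handled).
    rank = lambda v: np.argsort(np.argsort(v)).astype(float)
    return per_var, pearson(rank(imp_a), rank(imp_b))


# Toy explanations (made-up numbers): x2 has the opposite sign in `other`.
t = np.linspace(0.0, 10.0, 11)
ref = {"x1": (t, 0.1 * t), "x2": (t, -0.02 * t), "x3": (t, 0.01 * t)}
other = {"x1": (t, 0.2 * t), "x2": (t, 0.03 * t), "x3": (t, 0.02 * t)}

start = time.perf_counter()
per_var, rank_corr = attribution_agreement(ref, other, t)
elapsed = time.perf_counter() - start  # wall-clock seconds for the comparison
```

The per-variable correlation captures whether two methods agree on the shape and sign of an attribution over time, while the rank correlation only asks whether they order the variables by importance in the same way.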

In [1]:
import numpy as np
import pandas as pd
from ruamel.yaml import safe_load
import scipy.integrate
import scipy.optimize
from scipy.interpolate import interp1d

from warnings import catch_warnings, simplefilter
from dataclasses import dataclass
from typing import Callable, Optional, Any
from tqdm import tqdm

import plotly.graph_objects as go
import plotly.offline as py
import plotly.io as pio
# Static "jpeg" rendering relies on the kaleido package (pip install -U kaleido).
pio.renderers.default = "jpeg"

import torch
import torch.nn as nn
import torch.optim as optim
import torchtuples as tt

import sksurv
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import get_x_y
from sksurv.functions import StepFunction
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper 
from sklearn.model_selection import train_test_split

from pycox.datasets import metabric
from pycox.models import DeepHitSingle
from pycox.evaluation import EvalSurv

from survshap import SurvivalModelExplainer, ModelSurvSHAP, PredictSurvSHAP

from captum.attr import DeepLift, IntegratedGradients, DeepLiftShap

## DeepHit adapter

In [2]:
@dataclass
class TrainConf:
    optimizer_fn: Callable[[nn.Module], optim.Optimizer] \
        = lambda net: optim.Adam(net.parameters(), lr=1e-3)
    device: Optional[torch.device] = None
    batch_size: int = 256
    epochs: int = 1
    callbacks: Any = None
    verbose: bool = False
    num_workers: int = 0
    shuffle: bool = False
    metrics: Any = None
    val_data: Optional[tuple] = None  # (X, y) pair, as unpacked in fit()
    val_batch_size: int = 8224


class DeepHitSingle_:
    """A sksurv-like wrapper around pycox's DeepHitSingle."""

    def __init__(self, net: nn.Module, timestamps=None, alpha=0.2, sigma=0.1):
        """Create an instance.
        :param net: NN at the core of DeepHit. The number of output features
         is the number of cuts/timestamps.
        :param timestamps: (Optional) Predefined timestamps/steps to use. 
        Number of them must be equal to the dimensionality of net's output 
        space. If not provided, default cuts are used (see pycox docs for more 
        details.)
        :param alpha, sigma: Parameters for the pycox.models.DeepHitSingle 
        class."""

        self.net = net
        # Keep the user-supplied cuts; fit() falls back to inferring them if None.
        self.timestamps = timestamps
        self.alpha = alpha
        self.sigma = sigma
    
    def fit(self, X, y, conf: Optional[TrainConf] = None):
        if conf is None:
            conf = TrainConf()

        # Get the timestamps/cuts
        if self.timestamps is not None:
            cuts = self.timestamps
        else:
            # If the cuts are unspecified, we guess it by dry-running the net 
            # and checking the dimensionality of the output space.
            with torch.no_grad():
                self.net.eval()
                num_features = X.shape[1]
                dummy = torch.empty((1, num_features), dtype=torch.float32)
                res = self.net(dummy)[0]
                self.net.train()
            cuts = res.shape[0]

        print(f"cuts: {cuts}")
        
        self.tf = DeepHitSingle.label_transform(cuts)

        # Encode X and y to a form acceptable for DeepHitSingle.
        def encode(X, y, fit=False):
            input = torch.tensor(np.asarray(X), dtype=torch.float32)
            if fit:
                self.tf.fit(y["duration"], y["event"])
            durations, events = self.tf.transform(y["duration"], y["event"])
            durations = torch.tensor(durations, dtype=torch.int64)
            events = torch.tensor(events, dtype=torch.float32)
            target = (durations, events)
            return input, target
        
        input, target = encode(X, y, fit=True)
            
        self._model = DeepHitSingle(
            net=self.net,
            optimizer=conf.optimizer_fn(self.net),
            device=conf.device,
            duration_index=self.tf.cuts,
            alpha=self.alpha,
            sigma=self.sigma,
        )

        if conf.val_data is not None:
            val_X, val_y = conf.val_data
            val_data = encode(val_X, val_y)
        else:
            val_data = None

        self._log = self._model.fit(input, target, conf.batch_size, conf.epochs,
            conf.callbacks, conf.verbose, conf.num_workers, conf.shuffle,
            conf.metrics, val_data, conf.val_batch_size)
        
        self.event_times_ = self.tf.cuts
        
        return self
    
    def predict_surv_df(self, X):
        X = torch.tensor(np.asarray(X), dtype=torch.float32)
        return self._model.predict_surv_df(X).astype("float32")
    
    def predict_survival_function(self, X, return_array=False):
        """Predict survival function. See sksurv models for more details."""

        surv_df = self.predict_surv_df(X)

        event_times_ = surv_df.index.values
        sf_values = surv_df.T.values

        if return_array:
            return sf_values
        else:
            return np.array([
                StepFunction(event_times_, values)
                for values in sf_values
            ])

    def predict_cumulative_hazard_function(self, X, return_array=False):
        raise NotImplementedError
    
    def score(self, X, y):
        surv = self.predict_surv_df(X)

        with catch_warnings():
            simplefilter("ignore")
            eval = EvalSurv(
                surv=surv,
                durations=y["duration"],
                events=y["event"],
                censor_surv="km",
            )
            return eval.concordance_td(method="antolini")
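
The adapter returns survival curves, whereas the wrapped network outputs one raw score per cut. As far as we understand pycox's PMF-based models, the scores are padded with a zero column, passed through a softmax to obtain a probability mass function over the cuts, and turned into a survival curve via one minus the cumulative sum. The following is a minimal NumPy sketch of that mapping under this assumption; `deephit_logits_to_surv` is a hypothetical helper, not a pycox function.

```python
import numpy as np


def deephit_logits_to_surv(logits):
    """Map raw network outputs of shape (n_obs, n_cuts) to discrete survival curves."""
    logits = np.asarray(logits, dtype=float)
    # Append a zero column: the extra class stands for "no event within the grid".
    padded = np.concatenate([logits, np.zeros((logits.shape[0], 1))], axis=1)
    padded -= padded.max(axis=1, keepdims=True)  # numerical stability
    pmf = np.exp(padded)
    pmf /= pmf.sum(axis=1, keepdims=True)
    pmf = pmf[:, :-1]                    # drop the padding class
    return 1.0 - np.cumsum(pmf, axis=1)  # S(t_k) = 1 - sum_{j <= k} pmf_j


surv = deephit_logits_to_surv(np.array([[0.0, 0.0, 0.0], [2.0, -1.0, 0.5]]))
```

If this reading is right, the mapping from network outputs to survival values is nonlinear, which is worth keeping in mind when attributions are computed on the raw outputs of `net`.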

## Attribution methods

### SurvSHAP(t) Adapter

In [3]:
class SurvShapExplainer:
    """A bit more shap-esque Explainer wrapper for SurvSHAP."""

    def __init__(self, model, data=None, y=None, calculation_method="kernel", aggregation_method="integral", path="average", B=25, random_state=42, pbar=False):
        self.model = model
        self.calculation_method = calculation_method
        self.aggregation_method = aggregation_method
        self.path = path
        self.B = B
        self.random_state = random_state
        self.exp = SurvivalModelExplainer(self.model, data=data, y=y)
        self.pbar = pbar

    def __call__(self, observations: pd.DataFrame, timestamps=None) -> pd.DataFrame:
        """Predict SHAP values for a number of observations. In this case, we 
        deal with survival functions in the form of StepFunction, so likewise 
        the output SHAP values will be step functions.
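        :param observations: Dataframe with shape (num_obs, num_features).
        :param timestamps: (Optional) Time points at which to compute the SHAP 
        values. Defaults to the model's event_times_.
        :return: A dataframe with shape (num_obs, num_features), where each 
        "cell" contains a StepFunction being the SHAP attribution for a given 
        observation and a given variable."""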

        if timestamps is None:
            timestamps = self.model.event_times_

        skip = ["variable_str", "variable_name", "variable_value", "B", "aggregated_change", "index"]

        all_results = []
        
        idx_seq = range(len(observations))
        if self.pbar:
            idx_seq = tqdm(idx_seq)
        
        for idx in idx_seq:
            obs = observations.iloc[[idx]]
            shap = PredictSurvSHAP(
                calculation_method=self.calculation_method,
                aggregation_method=self.aggregation_method,
                path=self.path,
                B=self.B,
                random_state=self.random_state,
            )

            shap.fit(self.exp, obs, timestamps)
            obs_df = shap.result
            obs_df.insert(len(skip)-1, "index", idx)
            all_results.append(obs_df)
        
        res_df = pd.concat(all_results)
            
        g = res_df.groupby(by="variable_name")

        var_attr_values = {}
        for var in g.groups:
            grp: pd.DataFrame = g.get_group(var)
            grp = grp.sort_values(by=["index"])
            attr_values = grp.iloc[:,len(skip):].values
            var_attr_values[var] = [
                StepFunction(timestamps, attr_values_)
                for attr_values_ in attr_values
            ]
            
        res_df = pd.DataFrame(var_attr_values)
        res_df = res_df.set_index(observations.index, drop=True)
        return res_df
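
The adapter is configured with `aggregation_method="integral"`. To illustrate what such an aggregation could look like for the step functions returned above, here is a minimal sketch that integrates the absolute attribution over time and normalises by the length of the time range. The helper `integrated_importance` is hypothetical, the numbers are made up, and the exact normalisation used by the `survshap` package may differ.

```python
import numpy as np


def integrated_importance(times, values):
    """Integrate |values| over time for a right-continuous step function,
    normalised by the length of the time range."""
    times = np.asarray(times, dtype=float)
    values = np.abs(np.asarray(values, dtype=float))
    widths = np.diff(times)  # each value holds until the next knot
    return float((values[:-1] * widths).sum() / (times[-1] - times[0]))


times = np.array([0.0, 1.0, 2.0, 4.0])
attributions = {
    "x1": np.array([0.5, -0.5, 1.0, 2.0]),
    "x2": np.array([0.1, 0.1, -0.1, 0.0]),
}
importance = {var: integrated_importance(times, v) for var, v in attributions.items()}
ranking = sorted(importance, key=importance.get, reverse=True)
```

Taking the absolute value first means that an attribution which changes sign over time still counts as important instead of cancelling out.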

### DeepLift Adapter

In [4]:
class DeepLiftExplainer:
    """A shap-esque wrapper for DeepLift. See SurvShapExplainer for more
     details."""
     
    def __init__(self, model: DeepHitSingle_, data=None, y=None):
        self.model = model
        if data is not None:
            data = torch.tensor(data.values, dtype=torch.float32)
            self.baselines = data.mean(dim=0)
        else:
            self.baselines = None
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines.broadcast_to(inputs.shape)
        else:
            baselines = inputs.mean(dim=0).broadcast_to(inputs.shape)

        # Attribute in inference mode so dropout and batch-norm are deterministic.
        deep_lift = DeepLift(self.model.net.eval())

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = deep_lift.attribute(inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=np.cumsum(attr_values[obs_idx,var_idx]),
                ) for obs_idx in range(len(X))
            ]
        
        return pd.DataFrame(var_attr_values)

### DeepLiftShap Adapter

In [5]:
class DeepLiftShapExplainer:
    """A shap-esque wrapper for DeepLiftShap. See SurvShapExplainer for more
     details."""
     
    def __init__(self, model: DeepHitSingle_, data=None, y=None):
        self.model = model
        if data is not None:
            self.baselines = torch.tensor(data.values, dtype=torch.float32)
        else:
            self.baselines = None
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines
        else:
            baselines = inputs

        # Attribute in inference mode so dropout and batch-norm are deterministic.
        deep_lift_shap = DeepLiftShap(self.model.net.eval())

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = deep_lift_shap.attribute(
                    inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=np.cumsum(attr_values[obs_idx,var_idx]),
                ) for obs_idx in range(len(X))
            ]
        
        return pd.DataFrame(var_attr_values)

### Integrated Gradients (IG) Adapter

In [6]:
class IGExplainer:
    """A shap-esque wrapper for IntegratedGradients. See SurvShapExplainer for
     more details."""

    def __init__(self, model: DeepHitSingle_, data=None, y=None):
        self.model = model
        if data is not None:
            data = torch.tensor(data.values, dtype=torch.float32)
            self.baselines = data.mean(dim=0)
        else:
            self.baselines = None
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines.broadcast_to(inputs.shape)
        else:
            baselines = inputs.mean(dim=0).broadcast_to(inputs.shape)

        # Attribute in inference mode so dropout and batch-norm are deterministic.
        ig = IntegratedGradients(self.model.net.eval())

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = ig.attribute(inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=np.cumsum(attr_values[obs_idx,var_idx]),
                ) for obs_idx in range(len(X))
            ]
        
        return pd.DataFrame(var_attr_values)
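
For reference, Integrated Gradients attributes the difference between the model output at the input and at the baseline by integrating the gradient along the straight line between the two. The sketch below is a standalone NumPy illustration on a made-up toy function with an analytic gradient; it is not how `captum` is called above, and `integrated_gradients` is a hypothetical helper. It also checks the completeness property, i.e. that the attributions sum to `f(x) - f(baseline)`.

```python
import numpy as np


def integrated_gradients(grad_f, x, baseline, n_steps=200):
    """Approximate IG with a midpoint Riemann sum along the straight path."""
    x, baseline = np.asarray(x, dtype=float), np.asarray(baseline, dtype=float)
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    path = baseline + alphas[:, None] * (x - baseline)
    grads = np.stack([grad_f(p) for p in path])
    return (x - baseline) * grads.mean(axis=0)


# Toy model (an assumption for illustration): f(z) = tanh(w . z).
w = np.array([1.0, -2.0, 0.5])
f = lambda z: np.tanh(w @ z)
grad_f = lambda z: (1.0 - np.tanh(w @ z) ** 2) * w

x = np.array([0.3, 0.1, -0.4])
baseline = np.zeros(3)
attr = integrated_gradients(grad_f, x, baseline)

# Completeness: the attributions should add up to the change in output.
gap = abs(attr.sum() - (f(x) - f(baseline)))
```

A similar sanity check is available in `captum` through the `return_convergence_delta=True` argument of `attribute`, which could help tell approximation error apart from genuine disagreement with SurvSHAP(t).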

## Checking if the SurvSHAP adapter works fine

### Data - `exp1_data_complex.csv`

In [7]:
def exp1_csv():
    df = pd.read_csv("data/exp1_data_complex.csv")
    df = df.rename(columns={"time": "duration"})
    return df

exp1_df = exp1_csv()

In [8]:
df = exp1_df

X, y = get_x_y(df, attr_labels=["event", "duration"], pos_label=1)
train_X = test_X = X
train_y = test_y = y

# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_X, train_y = get_x_y(train_df, 
#     attr_labels=["event", "duration"], pos_label=1)
# test_X, test_y = get_x_y(test_df, 
#     attr_labels=["event", "duration"], pos_label=1)

In [17]:
test_y.shape

(1000,)

In [14]:
a = np.zeros((2, 2), np.dtype([('event', '?'), ('duration', '<f8')]))

In [15]:
a

array([[(False, 0.), (False, 0.)],
       [(False, 0.), (False, 0.)]],
      dtype=[('event', '?'), ('duration', '<f8')])

In [9]:
def eval_model(model, name):
    print(f"[{name}] Train score: {model.score(train_X, train_y)}")
    print(f"[{name}] Test score: {model.score(test_X, test_y)}")
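
The scores printed by `eval_model` are concordance indices. For the `sksurv` models, `score` reports Harrell's concordance index, while our `DeepHitSingle_.score` uses pycox's time-dependent (Antolini) variant. As a reminder of what the number means, here is a simplified sketch of Harrell's index; `harrell_c_index` is a hypothetical helper that ignores tied durations, and the toy data is made up.

```python
import numpy as np


def harrell_c_index(event, duration, risk):
    """Fraction of comparable pairs that the risk score orders correctly."""
    event = np.asarray(event, dtype=bool)
    duration = np.asarray(duration, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in range(len(duration)):
        if not event[i]:
            continue  # a pair is comparable only if the earlier time is an observed event
        for j in range(len(duration)):
            if duration[i] < duration[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # higher risk failed earlier: correct order
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk counts as half
    return concordant / comparable


event = [True, True, False, True]
duration = [1.0, 2.0, 3.0, 4.0]
c_good = harrell_c_index(event, duration, risk=[4, 3, 2, 1])  # perfectly ordered
c_bad = harrell_c_index(event, duration, risk=[1, 2, 3, 4])   # perfectly reversed
c_flat = harrell_c_index(event, duration, risk=[1, 1, 1, 1])  # uninformative
```

A value of 0.5 corresponds to a random ordering, which gives some context for the scores reported below.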

### Linear model (`CoxPHSurvivalAnalysis`)

In [10]:
cph = CoxPHSurvivalAnalysis()
cph = cph.fit(train_X, train_y)
eval_model(cph, "cph")

[cph] Train score: 0.622925665249838
[cph] Test score: 0.622925665249838


### Random survival forest

In [11]:
rsf = RandomSurvivalForest(
    random_state=42,
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
)

rsf = rsf.fit(train_X, train_y)
eval_model(rsf, "rsf")

[rsf] Train score: 0.8095397887950072
[rsf] Test score: 0.8095397887950072


### DeepHit

In [12]:
net = tt.practical.MLPVanilla(train_X.shape[1], [32, 32], 91, True, 0.1)
deephit_ = DeepHitSingle_(net, alpha=0.2, sigma=0.1)

x_train, x_val, y_train, y_val = \
    train_test_split(train_X, train_y, test_size=0.2, random_state=42)

conf = TrainConf(
    optimizer_fn=lambda net: torch.optim.Adam(net.parameters(), lr=1e-2),
    epochs=100,
    callbacks=[tt.callbacks.EarlyStopping()],
    batch_size=256,
    val_data=(x_val, y_val),
    verbose=True,
)

deephit_ = deephit_.fit(x_train, y_train, conf=conf)
eval_model(deephit_, "deephit")

cuts: 91
0:	[0s / 0s],		train_loss: 0.9383,	val_loss: 1.1976
1:	[0s / 0s],		train_loss: 0.8737,	val_loss: 1.0269
2:	[0s / 0s],		train_loss: 0.8639,	val_loss: 0.9942
3:	[0s / 0s],		train_loss: 0.8352,	val_loss: 0.9838
4:	[0s / 0s],		train_loss: 0.8302,	val_loss: 0.9688
5:	[0s / 0s],		train_loss: 0.8099,	val_loss: 0.9634
6:	[0s / 0s],		train_loss: 0.7988,	val_loss: 0.9671
7:	[0s / 0s],		train_loss: 0.7831,	val_loss: 0.9753
8:	[0s / 0s],		train_loss: 0.7832,	val_loss: 0.9809
9:	[0s / 0s],		train_loss: 0.7744,	val_loss: 0.9830
10:	[0s / 0s],		train_loss: 0.7620,	val_loss: 0.9803
11:	[0s / 0s],		train_loss: 0.7538,	val_loss: 0.9799
12:	[0s / 0s],		train_loss: 0.7358,	val_loss: 0.9779
13:	[0s / 0s],		train_loss: 0.7419,	val_loss: 0.9810
14:	[0s / 0s],		train_loss: 0.7320,	val_loss: 0.9835
15:	[0s / 0s],		train_loss: 0.7246,	val_loss: 0.9842
[deephit] Train score: 0.6401972087900741
[deephit] Test score: 0.6401972087900741


In [31]:
print(DeepHitSingle.label_transform(91).cuts)

None


In [20]:
test_y[0]

(False, 1.37152349)

In [33]:
train_y['event']

array([False, False, False, False,  True,  True,  True, False, False,
        True, False,  True,  True,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True, False, False,
       False,  True,  True, False, False,  True, False, False, False,
       False, False, False,  True, False,  True, False,  True, False,
        True, False, False,  True, False,  True,  True, False,  True,
        True,  True,  True, False,  True,  True,  True,  True, False,
        True,  True,  True, False,  True,  True,  True,  True,  True,
       False,  True,  True,  True,  True,  True,  True, False, False,
       False, False,  True, False,  True, False,  True,  True,  True,
        True, False,  True, False,  True,  True, False,  True, False,
        True, False,  True, False, False,  True,  True, False,  True,
       False, False,  True,  True, False,  True,  True, False, False,
        True, False,  True,  True,  True,  True, False, False,  True,
        True, False,

In [25]:
deephit_.net

MLPVanilla(
  (net): Sequential(
    (0): DenseVanillaBlock(
      (linear): Linear(in_features=5, out_features=32, bias=True)
      (activation): ReLU()
      (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): DenseVanillaBlock(
      (linear): Linear(in_features=32, out_features=32, bias=True)
      (activation): ReLU()
      (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (2): Linear(in_features=32, out_features=91, bias=True)
  )
)

### Plots 

In [18]:
def plot_expl(expl):
    fig = go.Figure()
    for var in expl.index:
        shap_f = expl[var]
        fig.add_trace(go.Scatter(
            x=shap_f.x, y=shap_f.y,
            mode="lines",
            line=dict(shape="hv"),
            name=var,
        ))
    return fig

In [None]:
cph_expl = SurvShapExplainer(cph, test_X, test_y)
cph_expl0 = cph_expl(test_X.iloc[[690]]).iloc[0]
plot_expl(cph_expl0)

In [None]:
rsf_expl = SurvShapExplainer(rsf, test_X, test_y)
rsf_expl0 = rsf_expl(test_X.iloc[[690]]).iloc[0]
plot_expl(rsf_expl0)

In [None]:
dh_expl = SurvShapExplainer(deephit_, test_X, test_y)
dh_expl0 = dh_expl(test_X.iloc[[690]]).iloc[0]
plot_expl(dh_expl0)

In [44]:
test_y[:10]

array([(False,  1.37152349), (False,  1.03176018), (False, 13.38743933),
       (False,  7.87972083), ( True,  1.93900302), ( True,  2.57213241),
       ( True,  1.15393586), (False,  4.75450825), (False,  2.89922113),
       ( True, 12.35290962)], dtype=[('event', '?'), ('duration', '<f8')])

In [35]:
test_X.iloc[[690]]

|     | x1 | x2 | x3        | x4        | x5        |
|-----|----|----|-----------|-----------|-----------|
| 690 | 1  | 1  | 11.622619 | 26.222602 | -0.441441 |


In [None]:
lift_expl = DeepLiftExplainer(deephit_, test_X, test_y)
lift_expl0 = lift_expl(test_X.iloc[[690]]).iloc[0]
plot_expl(lift_expl0)

In [None]:
dls_expl = DeepLiftShapExplainer(deephit_, test_X, test_y)
dls_expl0 = dls_expl(test_X.iloc[[690]]).iloc[0]
plot_expl(dls_expl0)

In [None]:
ig_expl = IGExplainer(deephit_, test_X, test_y)
ig_expl0 = ig_expl(test_X.iloc[[690]]).iloc[0]
plot_expl(ig_expl0)

## Real-world case: METABRIC dataset

Note: the data and training protocol are taken from the [example pycox notebook](https://nbviewer.org/github/havakv/pycox/blob/master/examples/deephit.ipynb).

In [8]:
np.random.seed(1234)
_ = torch.manual_seed(123)

### Data

In [9]:
df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

Dataset 'metabric' not locally available. Downloading...
Done


In [10]:
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)

def pycox_get_x_y(df):
    values = x_mapper.transform(df).astype("float32")
    X = pd.DataFrame(values, columns=x_mapper.transformed_names_)
    _, y = get_x_y(df, ["event", "duration"], 1)
    return X, y

In [11]:
x_mapper.fit(df_train)

x_train, y_train = pycox_get_x_y(df_train)
x_val, y_val = pycox_get_x_y(df_val)
x_test, y_test = pycox_get_x_y(df_test)
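
The cells above split the data into disjoint train, validation and test parts (two successive 20% samples, i.e. roughly 64/16/20) and fit the scaler on the training part only. The following sketch reproduces that protocol with plain NumPy on synthetic data, purely to make the order of operations explicit; the array `X` and all names in it are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))  # synthetic stand-in for the features

# Disjoint 64/16/20 split, mirroring the two successive 20% samples above.
perm = rng.permutation(len(X))
test_idx, val_idx, train_idx = perm[:20], perm[20:36], perm[36:]

# Statistics come from the training rows only and are then reused unchanged,
# so no information from validation or test data leaks into preprocessing.
mean, std = X[train_idx].mean(axis=0), X[train_idx].std(axis=0)
standardize = lambda a: (a - mean) / std
X_train, X_val, X_test = (standardize(X[i]) for i in (train_idx, val_idx, test_idx))
```

Only the training part is guaranteed to have zero mean and unit variance after this step; the other two parts are merely on a comparable scale.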

In [12]:
net = tt.practical.MLPVanilla(x_train.shape[1], [32, 32], 91, True, 0.1)
deephit_ = DeepHitSingle_(net, alpha=0.2, sigma=0.1)

In [13]:
conf = TrainConf(
    optimizer_fn=lambda net: torch.optim.Adam(net.parameters(), lr=1e-2),
    epochs=100,
    callbacks=[tt.callbacks.EarlyStopping()],
    batch_size=256,
    val_data=(x_val, y_val),
    verbose=True,
)

deephit_ = deephit_.fit(x_train, y_train, conf=conf)

0:	[0s / 0s],		train_loss: 0.8914,	val_loss: 0.7831
1:	[0s / 0s],		train_loss: 0.8337,	val_loss: 0.7821
2:	[0s / 0s],		train_loss: 0.8177,	val_loss: 0.7813
3:	[0s / 0s],		train_loss: 0.8022,	val_loss: 0.7857
4:	[0s / 0s],		train_loss: 0.7879,	val_loss: 0.7898
5:	[0s / 0s],		train_loss: 0.7792,	val_loss: 0.7930
6:	[0s / 0s],		train_loss: 0.7725,	val_loss: 0.7970
7:	[0s / 0s],		train_loss: 0.7632,	val_loss: 0.7998
8:	[0s / 0s],		train_loss: 0.7524,	val_loss: 0.8071
9:	[0s / 0s],		train_loss: 0.7495,	val_loss: 0.8192
10:	[0s / 0s],		train_loss: 0.7335,	val_loss: 0.8220
11:	[0s / 0s],		train_loss: 0.7284,	val_loss: 0.8210
12:	[0s / 0s],		train_loss: 0.7248,	val_loss: 0.8239


In [14]:
deephit_.score(x_test, y_test)

0.685340624867523

### Linear model

In [15]:
cph = CoxPHSurvivalAnalysis()
cph.fit(x_train, y_train)
cph.score(x_test, y_test)

0.6503878926618339

### Random survival forest

In [16]:
rsf = RandomSurvivalForest(
    random_state=42,
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
)
rsf = rsf.fit(x_train, y_train)
rsf.score(x_test, y_test)

0.6550722794522871

### Explanations

In [19]:
cph_expl = SurvShapExplainer(cph, x_train, y_train)
cph_expl0 = cph_expl(x_test.iloc[[0]]).iloc[0]
plot_expl(cph_expl0)

KeyboardInterrupt: 

In [None]:
rsf_expl = SurvShapExplainer(rsf, x_train, y_train)
rsf_expl0 = rsf_expl(x_test.iloc[[0]]).iloc[0]
plot_expl(rsf_expl0)

In [None]:
dh_expl = SurvShapExplainer(deephit_, x_train, y_train)
dh_expl0 = dh_expl(x_test.iloc[[0]]).iloc[0]
plot_expl(dh_expl0)

In [21]:
lift_expl = DeepLiftExplainer(deephit_, x_train, y_train)
lift_expl0 = lift_expl(x_test.iloc[[0]]).iloc[0]
plot_expl(lift_expl0)

ValueError: 
Image export using the "kaleido" engine requires the kaleido package,
which can be installed using pip:
    $ pip install -U kaleido


[Output: truncated plotly `Figure` repr; the first trace is `x0` (`mode='lines'`, line shape `hv`), plotted over an evenly spaced time grid starting at 0 with a step of about 3.95. The image itself was not rendered.]

In [None]:
dls_expl = DeepLiftShapExplainer(deephit_, x_train, y_train)
dls_expl0 = dls_expl(x_test.iloc[[0]]).iloc[0]
plot_expl(dls_expl0)

In [None]:
ig_expl = IGExplainer(deephit_, x_train, y_train)
ig_expl0 = ig_expl(x_test.iloc[[0]]).iloc[0]
plot_expl(ig_expl0)