In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

import argparse
from typing import Union
from functools import partial

import numpy as np
import pandas as pd
from tqdm import tqdm

from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest

from survLime import survlime_explainer
from survLime.datasets.load_datasets import Loader


def main(args):
    """Compute SurvLIME weights for every requested model / dataset / seed.

    For each combination a survival model is fitted on the training split,
    a ``SurvLimeExplainer`` is built from the same split, the weights for
    every test point are computed with ``compute_weights`` and the resulting
    table is written to a CSV file.

    Parameters
    ----------
    args : argparse.Namespace (or any object with the same attributes)
        ``dataset``      name of one dataset, or ``"all"`` for
                         veterans, udca, lung and pbc.
        ``model``        ``"cox"``, ``"rsf"`` or ``"all"``.
        ``repetitions``  number of repetitions; the repetition index is
                         passed as ``random_seed`` to the dataset split.
        ``num_neigh``    forwarded to ``compute_weights`` as ``num_neighbors``.
        ``output_dir``   optional; directory for the CSV files. Defaults to
                         ``computed_weights_csv/exp3`` relative to the
                         current working directory.

    Raises
    ------
    AssertionError
        If a requested model name is neither ``"cox"`` nor ``"rsf"``.

    Notes
    -----
    ``args`` is not modified by this function.
    """
    from pathlib import Path

    datasets = (
        ["veterans", "udca", "lung", "pbc"] if args.dataset == "all" else [args.dataset]
    )
    model_names = ["cox", "rsf"] if args.model == "all" else [args.model]

    # Validate up front so a typo fails before any data is loaded or fitted.
    for model_name in model_names:
        if model_name not in ("cox", "rsf"):
            raise AssertionError(
                f"Unknown model {model_name!r}; expected 'cox', 'rsf' or 'all'"
            )

    # Results go to a configurable, relative directory instead of a
    # machine-specific absolute path.
    output_dir = Path(
        getattr(args, "output_dir", None) or Path("computed_weights_csv") / "exp3"
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    for model_name in model_names:
        for i in tqdm(range(args.repetitions)):
            for dataset in datasets:
                loader = Loader(dataset_name=dataset)
                x, events, times = loader.load_data()

                # The repetition index doubles as the seed of the split.
                train, _, test = loader.preprocess_datasets(
                    x, events, times, random_seed=i
                )

                # A fresh estimator per run; kept in its own name so the
                # model *name* is still available for the output file.
                if model_name == "cox":
                    estimator = CoxPHSurvivalAnalysis(alpha=0.0001)
                else:
                    estimator = RandomSurvivalForest()
                estimator.fit(train[0], train[1])

                explainer = survlime_explainer.SurvLimeExplainer(
                    train[0], train[1], model_output_times=estimator.event_times_
                )

                computation_exp = compute_weights(
                    explainer, test[0], estimator, num_neighbors=args.num_neigh
                )

                save_path = output_dir / (
                    f"{model_name}_exp_{dataset}_surv_weights_na_rand_seed_{i}.csv"
                )
                computation_exp.to_csv(save_path, index=False)


def compute_weights(
    explainer: "survlime_explainer.SurvLimeExplainer",
    x_test: Union[pd.DataFrame, np.ndarray],
    model: Union["CoxPHSurvivalAnalysis", "RandomSurvivalForest"],
    num_neighbors: int = 1000,
) -> pd.DataFrame:
    """Compute the SurvLIME weights of every row in ``x_test``.

    Parameters
    ----------
    explainer
        Explainer whose ``explain_instance`` is called once per row.
    x_test
        Points to explain, one per row. Both a ``DataFrame`` and a 2-D
        ``ndarray`` are accepted.
    model
        Fitted survival model. Its ``predict_cumulative_hazard_function``
        (called with ``return_array=True``) is handed to the explainer.
    num_neighbors
        Forwarded to ``explain_instance`` as ``num_samples``.

    Returns
    -------
    pd.DataFrame
        One row per test point. Columns are ``model.feature_names_in_`` when
        the model exposes it, otherwise the columns of ``x_test`` when that is
        a ``DataFrame``, otherwise the default integer labels. Rows for which
        the explainer raised are filled with ``None``.
    """
    import warnings

    predict_chf = partial(model.predict_cumulative_hazard_function, return_array=True)

    # The signature promises ndarray support, so do not assume ``.to_numpy``.
    if isinstance(x_test, pd.DataFrame):
        points = x_test.to_numpy()
    else:
        points = np.asarray(x_test)

    # ``feature_names_in_`` may be missing from the model; fall back to the
    # labels of the input frame when there is one.
    columns = getattr(model, "feature_names_in_", None)
    if columns is None and isinstance(x_test, pd.DataFrame):
        columns = x_test.columns

    compt_weights = []
    for idx, test_point in enumerate(tqdm(points)):
        try:
            b, _ = explainer.explain_instance(
                test_point, predict_chf, verbose=False, num_samples=num_neighbors
            )
            weights = [entry[0] for entry in b]
        except Exception as err:
            # Best effort: keep going, but leave a trace and let
            # KeyboardInterrupt / SystemExit propagate.
            warnings.warn(
                f"explain_instance failed for test point {idx}: {err!r}",
                RuntimeWarning,
            )
            weights = [None] * len(test_point)
        compt_weights.append(weights)

    return pd.DataFrame(compt_weights, columns=columns)

In [43]:
# Load the veterans dataset and split it with a fixed seed.
loader = Loader(dataset_name='veterans')
x, events, times = loader.load_data()
train, _, test = loader.preprocess_datasets(x, events, times, random_seed=0)

# NOTE: the Cox model created here is replaced on the next line, so the
# model that actually gets fitted is the random survival forest.
model = CoxPHSurvivalAnalysis(alpha=0.0001)
model = RandomSurvivalForest()
model.fit(train[0], train[1])

# Sorted distinct values of the second field of each training label.
times_to_fill = sorted({record[1] for record in train[1]})

compt_weights = []
num_pat = 1000
predict_chf = partial(model.predict_cumulative_hazard_function, return_array=True)

explainer = survlime_explainer.SurvLimeExplainer(
    train[0], train[1], model_output_times=model.event_times_
)

# Skip the first test point, explain the second one, then stop.
for i, test_point in tqdm(enumerate(test[0].to_numpy())):
    if i == 0:
        continue
    b, _ = explainer.explain_instance(
        test_point, predict_chf, verbose=False, num_samples=num_pat
    )
    break
      

1it [00:00,  1.42it/s]


In [44]:
# Show matplotlib figures inline in the notebook output.
%matplotlib inline
import plotly.io as pio
# Use plotly's "iframe" renderer as the default for plotly figures.
pio.renderers.default = "iframe"
# Plot from the explainer built in an earlier cell — presumably the result of
# its most recent explain_instance call; confirm against SurvLimeExplainer.plot.
explainer.plot()