# **This notebook studies interpretability methods for different survival machine-learning models**

# Library imports

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Displays for output
from colorama import Fore, Style
from IPython.display import clear_output
from tqdm import tqdm

# Standard ML import

import sklearn 
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import learning_curve
from sklearn.svm import SVC
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.base import BaseEstimator, TransformerMixin
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index
import optuna
import shap
import lime
from itertools import product
from tools import concordance_censored, generate_param_grid, get_param_combinations, count_combinations, generate_random_numbers, random_number_dict


# Survival Analysis tools
import sksurv
#import survlimepy
from sksurv.ensemble import GradientBoostingSurvivalAnalysis, RandomSurvivalForest
from sksurv.svm import FastKernelSurvivalSVM
from sksurv.metrics import concordance_index_ipcw, cumulative_dynamic_auc, integrated_brier_score
#from survlimepy import SurvLimeExplainer

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


## **Data preprocessing pipeline**

In [2]:
def create_pipeline(dataframe):
    """Build an unfitted preprocessing pipeline for the columns of `dataframe`.

    Object-dtype columns are imputed with their most frequent value and one-hot encoded
    (unknown categories are ignored at transform time); every other column is
    mean-imputed. The returned Pipeline has a single 'preprocessor' step: a
    ColumnTransformer whose 'cat' branch precedes its 'num' branch.
    """
    object_cols = list(dataframe.select_dtypes(include=['object']).columns)
    other_cols = list(dataframe.select_dtypes(exclude=['object']).columns)

    categorical_steps = [
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore')),
    ]
    numeric_steps = [('imputer', SimpleImputer(strategy='mean'))]

    column_transformer = ColumnTransformer(transformers=[
        ('cat', Pipeline(steps=categorical_steps), object_cols),
        ('num', Pipeline(steps=numeric_steps), other_cols),
    ])

    return Pipeline(steps=[('preprocessor', column_transformer)])

In [3]:
def apply_pipeline(pipeline, df):
    """Fit the pipeline's 'preprocessor' step on `df` and return the result as a DataFrame.

    Column names are the one-hot feature names of the object-dtype columns, followed by
    the remaining columns. This matches the 'cat'-then-'num' order of the ColumnTransformer
    built by `create_pipeline`. The returned frame has a fresh RangeIndex.

    NOTE(review): assumes every non-object column survives imputation. A column that is
    entirely NaN may be dropped by SimpleImputer, which would misalign the names — confirm.
    """
    preprocessor = pipeline.named_steps['preprocessor']
    transformed_data = preprocessor.fit_transform(df)

    # ColumnTransformer returns a scipy sparse matrix when the one-hot output is sparse
    # enough; pandas cannot build a regular DataFrame from it, so densify first.
    if hasattr(transformed_data, 'toarray'):
        transformed_data = transformed_data.toarray()

    object_columns = df.select_dtypes(include=['object']).columns.tolist()
    onehot = preprocessor.named_transformers_['cat'].named_steps['onehot']
    onehot_names = onehot.get_feature_names_out(input_features=object_columns)
    column_names = list(onehot_names) + df.select_dtypes(exclude=['object']).columns.tolist()

    return pd.DataFrame(transformed_data, columns=column_names)

In [4]:
def encoded_data(df, categorial_columns):
    """One-hot encode `categorial_columns` of `df`.

    The listed columns are removed and their indicator columns are appended at the end.
    The encoded block reuses `df`'s index, so the join aligns row by row even when `df`
    does not have a default RangeIndex (otherwise rows would be mismatched or NaN).

    Parameters
    ----------
    df : pandas.DataFrame
    categorial_columns : list of column names to encode.

    Returns
    -------
    pandas.DataFrame
    """
    encoder = OneHotEncoder()
    # A single fit_transform; fitting beforehand as well was redundant.
    cat_encoded = encoder.fit_transform(df[categorial_columns]).toarray()
    new_columns = encoder.get_feature_names_out(categorial_columns)
    cat_encoded_df = pd.DataFrame(cat_encoded, columns=new_columns, index=df.index)
    return df.drop(columns=categorial_columns).join(cat_encoded_df)

# **Creation of the `Model` class**

This section defines a `Model` class that can fit, predict, optimize, score and interpret gradient-boosting (GB), random-forest (RF) and SVM models, both for survival analysis and for classification.

In [3]:
# Scoring functions dispatched by name from `Model.score`
class Scorer:
    """Scoring functions sharing the signature ``(estimator, x_test, y_test, y_train)``.

    ``Model.score`` looks them up with ``getattr(Scorer, name)``. All of them are oriented
    so that a higher value is better (the integrated Brier score is negated for this).
    Survival targets are read through their ``'time'`` field.
    """

    @staticmethod
    def concordance_censored(estimator, x_test, y_test, y_train = None):
        """Censored concordance index (delegates to the project's ``concordance_censored``)."""
        return concordance_censored(estimator, x_test, y_test)

    @staticmethod
    def concordance_index_ipcw(estimator, x_test, y_test, y_train):
        """Uno's IPCW concordance index; `y_train` is used to estimate censoring."""
        return concordance_index_ipcw(y_train, y_test, estimator.predict(x_test))[0]

    @staticmethod
    def ibs(estimator, x_test, y_test, y_train):
        """Negated integrated Brier score, so that optuna can maximise it.

        Returns the string 'no ibs score for SSVM' for FastKernelSurvivalSVM, which has no
        ``predict_survival_function``.
        """
        if isinstance(estimator, FastKernelSurvivalSVM):
            return 'no ibs score for SSVM'
        # np.unique sorts and de-duplicates the evaluation times. sksurv sorts `times`
        # internally, so the columns of `preds` must be built in ascending order too
        # (an unordered set would misalign predictions and times).
        times = np.unique(np.percentile(y_test['time'], np.linspace(10, 90, 15)))
        survs = estimator.predict_survival_function(x_test)
        preds = np.asarray([[fn(t) for t in times] for fn in survs])
        # NOTE(review): y_test is also passed as the "training" data for the censoring
        # estimate (first argument); confirm this is intended rather than y_train.
        return -integrated_brier_score(y_test, y_test, preds, times)

    @staticmethod
    def cumulative_dynamic_auc(estimator, x_test, y_test, y_train):
        """Mean time-dependent AUC over all test times except the largest one."""
        times = np.delete(sorted(y_test['time']), -1)
        return cumulative_dynamic_auc(y_train, y_test, estimator.predict(x_test), times)[1]

    @staticmethod
    def accuracy(estimator, x_test, y_test, y_train = None):
        """Classification accuracy of ``estimator.predict``."""
        return accuracy_score(y_test, estimator.predict(x_test))

    @staticmethod
    def roc_auc(estimator, x_test, y_test, y_train = None):
        """ROC AUC computed from continuous scores rather than hard class labels.

        Uses the positive-class column of ``predict_proba`` when available, otherwise
        ``decision_function``, and falls back to ``predict`` for other estimators.
        """
        if hasattr(estimator, 'predict_proba'):
            y_score = estimator.predict_proba(x_test)[:, 1]
        elif hasattr(estimator, 'decision_function'):
            y_score = estimator.decision_function(x_test)
        else:
            y_score = estimator.predict(x_test)
        return roc_auc_score(y_test, y_score)


In [2]:
class Model:

    """
---------------   Initialisation and basic function   ---------------
    """

    def __init__(self, classifier):
        """Wrap `classifier` and choose the default scoring metric from its type.

        Survival estimators (GB, RSF, FastKernelSurvivalSVM) are scored with
        'concordance_censored'; the sklearn classifiers (GB, RF, SVC) with 'roc_auc'.
        For any other type, ``self.scorer`` is left unset.
        """
        self.model = classifier
        self.random_state = 42
        self.params = {'random_state': self.random_state}
        self.best_params = {}
        # Field names of the structured survival target; refreshed by `fit`
        self.event = 'event'
        self.time = 'time'

        survival_models = (GradientBoostingSurvivalAnalysis, RandomSurvivalForest, FastKernelSurvivalSVM)
        classification_models = (GradientBoostingClassifier, RandomForestClassifier, SVC)
        # Metric used by cross-validation and optimisation
        if isinstance(classifier, survival_models):
            self.scorer = 'concordance_censored'
        elif isinstance(classifier, classification_models):
            self.scorer = 'roc_auc'

    def get_params(self):
        self.params['random_state'] = self.random_state
        return self.params

    def fit(self, x, Y):
        """Apply ``self.params`` (with `random_state`) to the estimator and fit it on (x, Y).

        For survival estimators, `Y` is a structured array whose two field names are stored
        in ``self.event`` and ``self.time``, in that order.
        """
        survival_models = (GradientBoostingSurvivalAnalysis, RandomSurvivalForest, FastKernelSurvivalSVM)
        if isinstance(self.model, survival_models):
            self.event, self.time = Y.dtype.names

        self.params['random_state'] = self.random_state
        estimator = self.model
        estimator.set_params(**self.params)
        estimator.fit(x, Y)
    
    def predict(self, x):
        return self.model.predict(x)    
    
    def score(self, x_test, Y_test, Y_train = None, metrics = None):
        """Evaluate the wrapped estimator with one or more ``Scorer`` metrics.

        Parameters
        ----------
        x_test, Y_test : evaluation features and target.
        Y_train : training target, required by the IPCW-based metrics (may be None otherwise).
        metrics : list of ``Scorer`` method names; defaults to ``[self.scorer]``.

        Returns
        -------
        dict mapping each metric name to its score.

        Raises
        ------
        ValueError
            If a metric name is not defined on ``Scorer``.
        """
        # `is None`: `== None` would be ambiguous if an array of metric names were passed
        if metrics is None:
            metrics = [self.scorer]
        estimator = self.model
        scores = {}
        for metric in metrics:
            if not hasattr(Scorer, metric):
                raise ValueError(f"Metric '{metric}' is not a valid scoring metric.\n Metric possible are : concordance_censored, concordance_index_ipcw, cumulative_dynamic_auc, ibs, accuracy and roc_auc")
            scores[metric] = getattr(Scorer, metric)(estimator, x_test, Y_test, Y_train)
        return scores



    """
---------------   Cross-validation score of the model   ---------------
    """


    def k_fold_cross_validation(self, x, Y, k=5):
        """
        Performs k-fold cross-validation for a given model and dataset.

        Parameters:
            model: The machine learning model to evaluate.
            X (numpy.ndarray): The feature matrix.
            y (numpy.ndarray): The target vector.
            k (int): Number of folds for cross-validation.

        Returns:
            float: The average accuracy across all folds.
        """
        n = x.shape[0]
        fold_size = n // k
        scores = []

        for i in range(k):
            # Splitting data into training and validation sets
            validation_X = pd.DataFrame(x[i * fold_size: (i + 1) * fold_size], columns=x.columns)
            validation_y = Y[i * fold_size: (i + 1) * fold_size]
            train_X = pd.DataFrame(np.concatenate([x[:i * fold_size], x[(i + 1) * fold_size:]]), columns=x.columns)
            train_y = np.concatenate([Y[:i * fold_size], Y[(i + 1) * fold_size:]])

            # Fitting the model
            self.fit(train_X, train_y)

            # Calculating accuracy
            score = self.score(validation_X, validation_y, train_y, metrics = [self.scorer])[self.scorer]
            scores.append(score)

        # Returning the average accuracy
        return sum(scores) / k

    """
---------------   Optimization of the model   ---------------
    """

    def optimize_with_optuna(self, x, Y, n_trials = 50, plot = False):
        """Search hyper-parameters with optuna, maximising the 5-fold CV score.

        The search space depends on the type of the wrapped estimator. Each trial's
        parameters (plus `random_state`, and `probability=True` for SVC) are stored in
        ``self.params`` and set on the estimator. Storing them in ``self.params`` matters
        because ``self.fit`` re-applies ``self.params`` before every fit; otherwise stale
        parameters would override the trial's.

        Parameters
        ----------
        x, Y : training features and target passed to ``k_fold_cross_validation``.
        n_trials : int, number of optuna trials.
        plot : bool, plot the running best score after each trial.

        Returns
        -------
        best_params : dict, optuna's best trial parameters (also set on the estimator and
            stored in ``self.params``).
        list_score : list, best score seen so far after each trial.
        """
        list_score = []

        def objective(trial):
            # Search space for the hyper-parameters
            params = {}
            if isinstance(self.model, (GradientBoostingSurvivalAnalysis, GradientBoostingClassifier)):
                # suggest_float(log=True) replaces the deprecated suggest_loguniform
                params['learning_rate'] = trial.suggest_float('learning_rate', 0.001, 0.1, log=True)
                params['n_estimators'] = trial.suggest_int('n_estimators', 50, 400)
                params['max_depth'] = trial.suggest_int('max_depth', 3, 10)
                params['subsample'] = trial.suggest_float('subsample', 0.5, 1.0)
                params['min_samples_split'] = trial.suggest_int('min_samples_split', 2, 20)
                params['min_samples_leaf'] = trial.suggest_int('min_samples_leaf', 1, 20)
            elif isinstance(self.model, (RandomSurvivalForest, RandomForestClassifier)):
                params['n_estimators'] = trial.suggest_int('n_estimators', 50, 400)
                params['max_depth'] = trial.suggest_int('max_depth', 3, 10)
                params['max_samples'] = trial.suggest_float('max_samples', 0.1, 1.0)
                params['min_samples_split'] = trial.suggest_int('min_samples_split', 2, 20)
                params['min_samples_leaf'] = trial.suggest_int('min_samples_leaf', 1, 20)
            elif isinstance(self.model, SVC):
                params['max_iter'] = 1000
                # C must be strictly positive and spans ten orders of magnitude: sample it
                # on a log scale (suggest_int over [1e-5, 1e5] could return C=0)
                params['C'] = trial.suggest_float('C', 1e-5, 1e5, log=True)
                params['degree'] = trial.suggest_int('degree', 2, 5)
                params['gamma'] = trial.suggest_float('gamma', 1e-5, 1e3, log=True)
                params['kernel'] = trial.suggest_categorical('kernel', ['linear', 'poly'])
                # Enables predict_proba (used e.g. by the LIME explanation)
                params['probability'] = True
            elif isinstance(self.model, FastKernelSurvivalSVM):
                params['max_iter'] = 1000
                params['alpha'] = trial.suggest_float('alpha', 0.01, 100)
                params['degree'] = trial.suggest_int('degree', 2, 5)
                params['gamma'] = trial.suggest_float('gamma', 1e-5, 1e3, log=True)
                params['kernel'] = trial.suggest_categorical('kernel', ['linear', 'poly'])

            params['random_state'] = self.random_state
            # fit() re-applies self.params, so the trial's parameters must live there
            self.params = params
            self.model = self.model.set_params(**params)

            score = self.k_fold_cross_validation(x, Y, k=5)
            # Running best score, for the optional convergence plot
            list_score.append(score if not list_score else max(score, list_score[-1]))
            return score

        study = optuna.create_study(direction='maximize')
        study.optimize(objective, n_trials=n_trials)

        best_params = study.best_params
        self.model.set_params(**best_params)
        self.params = best_params

        clear_output(wait=True)
        print(f'Best hyperparameters with optuna : {best_params}')

        if plot:
            plt.plot(list_score)
            plt.xlabel('Numbers of iterations')
            plt.ylabel('Model score')
            plt.title('Evolution of the model score during optuna optimisation')
            plt.show()

        return best_params, list_score
    
    def optimize(self, x, Y, num_samples = 3, n_trials = 50, plot = False):
        """Two-stage hyper-parameter search: optuna, then a grid search around its optimum.

        1. ``optimize_with_optuna`` produces a first optimum.
        2. ``generate_param_grid`` builds `num_samples` values per parameter around it, and
           every combination is scored with 5-fold cross-validation.

        The best combination (plus `random_state`) is set on the estimator and stored in
        both ``self.params`` and ``self.best_params``.

        Parameters
        ----------
        x, Y : training features and target.
        num_samples : int, grid values per hyper-parameter.
        n_trials : int, number of optuna trials.
        plot : bool, plot optuna's running best score.

        Raises
        ------
        ValueError
            If the grid search yields no usable candidate.
        """
        optimal_params, list_score = self.optimize_with_optuna(x, Y, n_trials)
        param_grid = generate_param_grid(optimal_params, num_samples = num_samples)

        tree_models = (GradientBoostingSurvivalAnalysis, RandomSurvivalForest, GradientBoostingClassifier, RandomForestClassifier)
        if isinstance(self.model, tree_models):
            # Remove grid values the tree estimators reject (min_samples_leaf >= 1,
            # min_samples_split >= 2); keep the grid unchanged if nothing would remain.
            for key, lower_bound in (('min_samples_leaf', 1), ('min_samples_split', 2)):
                if key in param_grid:
                    valid_values = [value for value in param_grid[key] if value >= lower_bound]
                    if valid_values:
                        param_grid[key] = valid_values

        # -inf so that negative scores (e.g. the negated IBS) can still win
        best_score = -np.inf
        best_params = None
        progress_bar = tqdm(total=count_combinations(param_grid), desc="Progress of GridSearchCV")

        for params in get_param_combinations(param_grid):
            # fit() re-applies self.params, so the candidate must be stored there;
            # otherwise every combination would be evaluated with optuna's parameters
            self.params = {**params, 'random_state': self.random_state}
            self.model = self.model.set_params(**self.params)

            score = self.k_fold_cross_validation(x, Y, k=5)
            if score > best_score:
                best_score = score
                best_params = dict(params)

            progress_bar.update(1)
        progress_bar.close()

        if best_params is None:
            raise ValueError("Grid search produced no candidate hyper-parameters")

        self.best_params = {**best_params, 'random_state': self.random_state}
        self.params = dict(self.best_params)
        self.model.set_params(**self.params)

        if plot:
            plt.plot(list_score)
            plt.xlabel('Numbers of iterations')
            plt.ylabel('Model score')
            plt.title('Evolution of the model score during optuna optimisation')
            plt.show()

        print(f'Best hyperparameters with optuna - GridSearch : {Fore.BLUE}{self.best_params}{Style.RESET_ALL} \nwith a score: {Fore.BLUE}{best_score}{Style.RESET_ALL}; and the scorer: {Fore.BLUE}{self.scorer}{Style.RESET_ALL}')


    """
---------------   Interpretability methods   ---------------
    """

    def get_interpretability_methods(self, x_train, x_test, Y_train, Y_test, feature = None, index = 0, plot = False):
        interpretability_methods = {
            'SHAP': self.get_shap_values(x_train, x_test, feature, plot),
            'LIME': self.get_lime_explanation(x_train, x_test, index),
            'PI': self.get_pi_values(x_test, Y_test, Y_train)
        }
        return interpretability_methods

    def get_shap_values(self, x_train, x_test, feature = None, plot = False):
        """Compute SHAP values of ``self.predict`` on `x_test`, with `x_train` as background.

        Tree ensembles go through ``shap.Explainer``; SVMs go through ``shap.KernelExplainer``.
        For both, a summary plot is shown only when `plot` is True, and a dependence plot is
        shown when `feature` is given.

        Returns
        -------
        The SHAP values, or None when the estimator type is not supported.
        """
        tree_models = (GradientBoostingSurvivalAnalysis, RandomSurvivalForest, RandomForestClassifier, GradientBoostingClassifier)
        kernel_models = (FastKernelSurvivalSVM, SVC)

        if isinstance(self.model, tree_models):
            # NOTE(review): shap.Explainer only receives the prediction function here,
            # so it cannot exploit the tree structure (this is not TreeExplainer).
            explainer = shap.Explainer(self.predict, x_train)
        elif isinstance(self.model, kernel_models):
            explainer = shap.KernelExplainer(self.predict, x_train)
        else:
            return None

        shap_values = explainer.shap_values(x_test)

        if plot:
            shap.summary_plot(shap_values, x_test)
        if feature is not None:
            shap.dependence_plot(feature, shap_values, x_test)

        return shap_values

    def permutation_importance_feature(self, x_test, Y_test, Y_train, n_permutations=100):
        baseline_score = self.score(x_test, Y_test, Y_train)[self.scorer]
        feature_importance = {}
        
        for i in range(x_test.shape[1]):
            scores = []
            for j in range(n_permutations):
                x_permuted = x_test.copy()
                x_permuted.iloc[:, i] = np.random.permutation(x_permuted.iloc[:, i])
                score = self.score(x_permuted, Y_test, Y_train)[self.scorer]
                scores.append(score)
            feature_importance[x_test.columns[i]] = abs(baseline_score - np.mean(scores))
        
        return feature_importance


    def get_pi_values(self, x_test, Y_test, Y_train, plot = False, n_permutations = 100):
        feature_importances = self.permutation_importance_feature(x_test, Y_test, Y_train, n_permutations)
        sorted_importances = {k: v for k, v in sorted(feature_importances.items(), key=lambda item: item[1])}
        total_importance = sum(sorted_importances.values())
        importance_percent = {k: v / total_importance * 100 for k, v in sorted_importances.items()}
        n_features = len(feature_importances)
        if plot:
            # Display
            plt.figure(figsize=(10, 6))
            plt.barh(list(importance_percent.keys())[n_features-10:], list(importance_percent.values())[n_features-10:], color='blue')
            plt.xlabel('Importance (%)')
            plt.ylabel('Variable')
            plt.title('Variables importance')
            plt.show()

        return importance_percent
            
        

    def get_lime_explanation(self, x_train, x_test, Y_test, sample_idx = [0]):
        """Display LIME explanations for the rows of `x_test` at positions `sample_idx`.

        - Classifiers (RF, GB, SVC) are explained in classification mode via ``predict_proba``.
        - FastKernelSurvivalSVM outputs a risk score, so it is explained in regression mode
          via ``self.predict`` (5 features). The sample's ``Y_test[index][1]`` (time) and
          ``Y_test[index][0]`` (event) are printed as well.
        - Other estimator types are ignored.

        Explanations are shown with ``show_in_notebook``; nothing is returned.
        `sample_idx` is only read, never mutated, so its list default is safe.
        """
        # `import lime` alone does not load the lime_tabular submodule
        import lime.lime_tabular

        feature_names = list(x_train.columns)

        if isinstance(self.model, (RandomForestClassifier, GradientBoostingClassifier, SVC)):
            # predict_proba columns follow self.model.classes_, so label classes in that order
            class_names = [f'event_{label}' for label in self.model.classes_]
            explainer = lime.lime_tabular.LimeTabularExplainer(training_data = x_train.values, feature_names = feature_names, class_names = class_names, mode = 'classification')

            for index in sample_idx:
                # LIME expects the instance as a 1-d numpy array
                explanation = explainer.explain_instance(x_test.iloc[index].to_numpy(), self.model.predict_proba)
                print(f'Lime explanation for index {index}')
                explanation.show_in_notebook()

        elif isinstance(self.model, FastKernelSurvivalSVM):
            explainer = lime.lime_tabular.LimeTabularExplainer(training_data = x_train.values, mode='regression', feature_names=feature_names)

            for index in sample_idx:
                explanation = explainer.explain_instance(x_test.iloc[index].to_numpy(), self.predict, num_features=5)
                print('Sample  : ', index)
                print('Time   : ', Y_test[index][1])
                print('Event   :'  , Y_test[index][0])
                explanation.show_in_notebook()
                
    """
---------------   Hyperparameters importance   ---------------
    """

    def hyperparameters_importances(self, hyperparameters_range, x, Y, n_trials = 50, n_samples = 15, plot = False):
        """Estimate how strongly each hyper-parameter moves the cross-validated score.

        Procedure:
        1. Compute a baseline CV score with the estimator's current parameters.
        2. For each hyper-parameter, build candidate values with
           ``_hyperparameter_candidates``.
        3. For each candidate value, run `n_trials` CV evaluations with that value fixed and
           the other hyper-parameters drawn by ``random_number_dict``.

        A hyper-parameter's importance is the mean, over its candidate values, of the
        average absolute difference to the baseline score.

        Parameters
        ----------
        hyperparameters_range : dict
            name -> list of options (str values) or [low, high] bounds (int or float values).
        x, Y : training features and target.
        n_trials : int, random draws per candidate value.
        n_samples : int, controls how many int/float candidates are generated.
        plot : bool, show a bar chart of the importances.

        Returns
        -------
        dict name -> importance, in increasing order of importance.

        Afterwards the estimator is reset either to ``self.best_params`` (if ``optimize``
        was run) or to the parameters it had on entry, and re-fitted on (x, Y).
        """
        # Snapshot, so the random trial parameters do not leak when best_params is empty
        initial_estimator_params = self.model.get_params()
        self.params = {}
        base_score = self.k_fold_cross_validation(x, Y)

        candidates = {
            name: self._hyperparameter_candidates(values, n_samples)
            for name, values in hyperparameters_range.items()
        }
        n_sim = n_trials * sum(len(values) for values in candidates.values())
        progress_bar = tqdm(total = n_sim, desc = "Progress of Hyperparameters importance")

        importance = {}
        for param_name, values in candidates.items():
            print(f'Hyperparameters : {param_name}')
            other_ranges = {key: bounds for key, bounds in hyperparameters_range.items() if key != param_name}
            average_diffs = []
            for value in values:
                # Reset per candidate so each value's average is independent
                score_diff = 0
                for _ in range(n_trials):
                    # Random values for the other hyper-parameters, fixed value for this one
                    self.params = random_number_dict(other_ranges)
                    self.params[param_name] = value
                    modified_score = self.k_fold_cross_validation(x, Y)
                    score_diff += abs(base_score - modified_score)
                    progress_bar.update(1)
                average_diffs.append(score_diff / n_trials)
            importance[param_name] = sum(average_diffs) / len(average_diffs) if average_diffs else 0

        progress_bar.close()
        sorted_importance = {k: v for k, v in sorted(importance.items(), key=lambda item: item[1])}

        if plot :
            plt.figure(figsize=(10, 6))
            plt.barh(list(sorted_importance.keys()), list(sorted_importance.values()), color='blue')
            plt.xlabel('Importance')
            plt.ylabel('Hyperparameters')
            plt.title('Hyperparameters importance')
            plt.show()

        if self.best_params:
            self.params = dict(self.best_params)
        else:
            self.model.set_params(**initial_estimator_params)
            self.params = {}
        self.fit(x, Y)
        return sorted_importance

    @staticmethod
    def _hyperparameter_candidates(param_values, n_samples):
        """Candidate values for one hyper-parameter.

        - str options: all of them.
        - int bounds: an evenly stepped inclusive range.
        - float bounds: `n_samples` values from ``generate_random_numbers``.
        - any other type: no candidates (empty list).
        """
        first = param_values[0]
        if isinstance(first, str):
            return list(param_values)
        if isinstance(first, int):
            low, high = param_values[0], param_values[1]
            step = int((high + 1 - low) / n_samples) + 1
            return list(range(low, high + 1, step))
        if isinstance(first, float):
            return list(generate_random_numbers(param_values[0], param_values[1], n_samples))
        return []
    
    """
---------------   Learning curve   ---------------
    """
    
    def learning_curve(self, X_train, X_test, y_train, y_test, n_sample = 25):
        """Plot training vs. validation score as the training set grows.

        For `n_sample` fractions evenly spaced between 10% and 100%, the model is re-fitted
        on the first ``int(fraction * len(X_train))`` training rows. That subset and the
        full test set are then scored with ``self.scorer``. The model is left fitted on
        the last (largest) subset.
        """
        n_rows = len(X_train)
        sizes = [int(fraction * n_rows) for fraction in np.linspace(0.1, 1.0, n_sample)]
        train_scores = []
        test_scores = []

        for size in sizes:
            X_subset, y_subset = X_train[:size], y_train[:size]
            self.fit(X_subset, y_subset)
            train_scores.append(self.score(X_subset, y_subset, y_subset)[self.scorer])
            test_scores.append(self.score(X_test, y_test, y_train)[self.scorer])

        plt.figure()
        plt.title("Learning Curve")
        plt.xlabel("Training Examples")
        plt.ylabel("Score")
        plt.plot(sizes, train_scores, 'o-', color="r", label="Training score")
        plt.plot(sizes, test_scores, 'o-', color="g", label="Validation score")
        plt.legend(loc="best")
        plt.grid(True)
        plt.show()


    """
---------------   Accuracy of use of machine learning in survival analysis   ---------------
    """

    def compare_to_cox(self, df, x_test, Y_test, Y_train):
        """Compare this survival model's concordance with a Cox proportional-hazards model.

        Steps:
        - A ``CoxPHFitter`` is fitted on `df`, which must contain the ``self.time`` and
          ``self.event`` columns.
        - The Cox concordance index is computed on that same `df`.
        - The ML model's ``concordance_censored`` is computed on (x_test, Y_test).
        - The relative difference between the two is printed.

        Raises
        ------
        ValueError
            For classification models, or if ``self.scorer`` is not a ``Scorer`` metric.
        """
        survival_models = (GradientBoostingSurvivalAnalysis, RandomSurvivalForest, FastKernelSurvivalSVM)
        if not isinstance(self.model, survival_models):
            raise ValueError('No comparaison with Cox for Classification models')

        # Every other column of df is used as a covariate.
        # NOTE(review): no category level is dropped here; df presumably must already be
        # encoded without collinear dummy columns — confirm.
        cox_model = CoxPHFitter()
        cox_model.fit(df, duration_col=self.time, event_col=self.event)

        if not hasattr(Scorer, self.scorer):
            raise ValueError(f"Metric '{self.scorer}' is not a valid scoring metric.\n Metric possible are : concordance_censored, concordance_index_ipcw, cumulative_dynamic_auc, ibs, accuracy and roc_auc")
        # A higher partial hazard means shorter survival, hence the negation.
        # NOTE(review): the Cox score is in-sample (df is used for fitting and scoring),
        # whereas the ML model is scored on the test set.
        cox_score = concordance_index(df[self.time], -cox_model.predict_partial_hazard(df), df[self.event])

        ml_score = self.score(x_test, Y_test, Y_train, metrics = ['concordance_censored'])['concordance_censored']

        diff_relative = abs(ml_score - cox_score) / cox_score
        percentage = Fore.RED + f'{round(diff_relative * 100,2)}%' + Style.RESET_ALL
        details = f'compared with the Cox model, with a score of {round(ml_score,3)} for the ML model and {round(cox_score,3)} for cox model'
        if ml_score > cox_score:
            print('The ML model improves the score of', percentage, details)
        else:
            print('the Cox model should be preferred to the ML model, as the latter has a lower score of around', percentage, details)