In [1]:
import numpy as np
import pandas as pd
from survivors.external import ClassifWrapSA, RegrWrapSA, SAWrapSA

In [2]:
from sklearn.metrics import mean_squared_error, r2_score, roc_auc_score

# Метрики регрессии
# - RMSE ожидаемого времени
# - R^2 ожидаемого времени

rmse_exp_time = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: np.sqrt(mean_squared_error(y_tst["time"], pred_time))
r2_exp_time = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: r2_score(y_tst["time"], pred_time)

In [3]:
from sklearn.metrics import roc_auc_score, log_loss, mean_squared_error
# Метрики классификации
# - AUC вероятности события
# - log-loss вероятности события
# - rmse исхода

def find_sf_at_truetime(pred_sf, event_time, bins):
    idx_pred = np.clip(np.searchsorted(bins, event_time), 0, len(bins) - 1)
    proba = np.take_along_axis(pred_sf, idx_pred[:, np.newaxis], axis=1).squeeze()
    return proba

## example
# true_times = np.array([1, 19, 21, 31])
# bins = np.array([10,20,30])
# sf = np.array([[0.9, 0.8, 0.7], 
#                [0.7, 0.6, 0.5], 
#                [0.5, 0.4, 0.3],
#                [0.05, 0.04, 0.03]])
# print(find_sf_at_truetime(sf, true_times, bins))  # [0.9  0.6  0.3  0.03] 

auc_event = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: roc_auc_score(y_tst["cens"].astype(int), find_sf_at_truetime(pred_sf, y_tst["time"], bins))
log_loss_event = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: log_loss(y_tst["cens"], find_sf_at_truetime(pred_sf, y_tst["time"], bins))
rmse_event = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: np.sqrt(mean_squared_error(y_tst["cens"], find_sf_at_truetime(pred_sf, y_tst["time"], bins)))

In [4]:
from survivors.experiments import grid as exp
import survivors.datasets as ds

l_metrics = ["CI", "IBS", "AUPRC", "RMSE_TIME", "R2_TIME", "AUC_EVENT", "LOGLOSS_EVENT", "RMSE_EVENT"]
X, y, features, categ, _ = ds.load_pbc_dataset()
experim = exp.Experiments(folds=5, mode="CV+SAMPLE")
experim.add_new_metric("RMSE_TIME", rmse_exp_time)
experim.add_new_metric("R2_TIME", r2_exp_time)
experim.add_new_metric("AUC_EVENT", auc_event)
experim.add_new_metric("LOGLOSS_EVENT", log_loss_event)
experim.add_new_metric("RMSE_EVENT", rmse_event)
experim.set_metrics(l_metrics)

In [None]:
# Гиперпараметры и модели классификации

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

CLASS_PARAM_GRIDS = {
    "logistic_regression": {
        "penalty": ["l2"],
        "C": [0.01, 0.1, 1, 10],
        "solver": ["liblinear", "lbfgs"],
        "class_weight": [None, "balanced"],
        "max_iter": [1000],
    },
    "svc": {
        "kernel": ["linear", "rbf"],
        "C": [0.1, 1, 10],
        "class_weight": [None, "balanced"],
        "probability": [True],
    },
    "knn_classifier": {
        "n_neighbors": [5, 10, 20],
        "weights": ["uniform", "distance"],
    },
    "decision_tree_classifier": {
        "max_depth": [5, 10, 20],
        "min_samples_split": [2, 10],
        "min_samples_leaf": [1, 5],
        "criterion": ["gini", "entropy"],
    },
    "random_forest_classifier": {
        "n_estimators": [100, 300],
        "max_depth": [10, 30],
        "min_samples_split": [2, 10],
        "min_samples_leaf": [1, 5],
    },
    "gradient_boosting_classifier": {
        "n_estimators": [100, 300],
        "learning_rate": [0.05, 0.1],
        "max_depth": [2, 3],
        "subsample": [0.7, 1.0],
    }
}

experim.add_method(ClassifWrapSA(LogisticRegression()), CLASS_PARAM_GRIDS['logistic_regression'])
experim.add_method(ClassifWrapSA(SVC()), CLASS_PARAM_GRIDS['svc'])
experim.add_method(ClassifWrapSA(KNeighborsClassifier()), CLASS_PARAM_GRIDS['knn_classifier'])
experim.add_method(ClassifWrapSA(DecisionTreeClassifier()), CLASS_PARAM_GRIDS['decision_tree_classifier'])
experim.add_method(ClassifWrapSA(RandomForestClassifier()), CLASS_PARAM_GRIDS['random_forest_classifier'])
experim.add_method(ClassifWrapSA(GradientBoostingClassifier()), CLASS_PARAM_GRIDS['gradient_boosting_classifier'])

In [6]:
# Гиперпараметры и модели регрессии

from sklearn.linear_model import ElasticNet
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

REGR_PARAM_GRIDS = {
    "elastic_net": {
        "alpha": [0.001, 0.01, 0.1],
        "l1_ratio": [0.2, 0.5, 0.8],
        "max_iter": [1000, 5000],
    }, 
    "decision_tree_regressor": {
        "max_depth": [5, 10, 20],
        "min_samples_split": [2, 10],
        "min_samples_leaf": [1, 5],
        "criterion": ["squared_error", "friedman_mse"],
    },
    "random_forest_regressor": {
        "n_estimators": [100, 300],
        "max_depth": [10, 30],
        "min_samples_split": [2, 10],
        "min_samples_leaf": [1, 5],
    },
    "gradient_boosting_regressor": {
        "n_estimators": [100, 300],
        "learning_rate": [0.05, 0.1],
        "max_depth": [2, 3],
        "subsample": [0.7, 1.0],
    },
    "svr": {
        "kernel": ["linear", "rbf"],
        "C": [0.1, 1, 10],
        "epsilon": [0.1, 0.2],
    },
    "knn_regressor": {
        "n_neighbors": [5, 10, 20],
        "weights": ["uniform", "distance"],
    }
}

experim.add_method(RegrWrapSA(ElasticNet()), REGR_PARAM_GRIDS['elastic_net'])
experim.add_method(RegrWrapSA(DecisionTreeRegressor()), REGR_PARAM_GRIDS['decision_tree_regressor'])
experim.add_method(RegrWrapSA(RandomForestRegressor()), REGR_PARAM_GRIDS['random_forest_regressor'])
experim.add_method(RegrWrapSA(GradientBoostingRegressor()), REGR_PARAM_GRIDS['gradient_boosting_regressor'])
experim.add_method(RegrWrapSA(SVR()), REGR_PARAM_GRIDS['svr'])
experim.add_method(RegrWrapSA(KNeighborsRegressor()), REGR_PARAM_GRIDS['knn_regressor'])

In [7]:
# Гиперпараметры моделей выживаемости (внешние для survivors)

from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.tree import SurvivalTree
from sksurv.ensemble import RandomSurvivalForest
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from lifelines import KaplanMeierFitter

EXTERNAL_SURV_PARAM_GRIDS = {
    "km": {},
    "cox_ph": {
        'alpha': [100, 10, 1, 0.1, 0.01, 0.001],
        'ties': ["breslow"]
    },
    "random_survival_forest": {
        'n_estimators': [50],
        'max_depth': [5, 20],
        'min_samples_leaf': [0.001, 0.01, 0.1, 0.25],
        "random_state": [123]
    },
    "survival_tree": {
        'max_depth': [None, 20, 30],
        'min_samples_leaf': [1, 10, 20],
        'max_features': [None, "sqrt"],
        "random_state": [123]
    },
    "gbds": {
        'loss': ["coxph"],
        'learning_rate': [0.01, 0.05, 0.1, 0.5],
        'n_estimators': [50],
        'min_samples_leaf': [1, 10, 50, 100],
        'max_features': ["sqrt"],
        "random_state": [123]
    },
}

# experim.add_method(SAWrapSA(KaplanMeierFitter()), EXTERNAL_SURV_PARAM_GRIDS['km'])
# experim.add_method(CoxPHSurvivalAnalysis, EXTERNAL_SURV_PARAM_GRIDS['cox_ph'])
# experim.add_method(RandomSurvivalForest, EXTERNAL_SURV_PARAM_GRIDS['random_survival_forest'])
# experim.add_method(SurvivalTree, EXTERNAL_SURV_PARAM_GRIDS['survival_tree'])
# experim.add_method(GradientBoostingSurvivalAnalysis, EXTERNAL_SURV_PARAM_GRIDS['gbds'])

In [8]:
# Гиперпараметры моделей выживаемости (внутри survivors)

from survivors.tree import CRAID
from survivors.ensemble import ParallelBootstrapCRAID

INTERNAL_SURV_PARAM_GRIDS = {
    "CRAID": {
        "depth": [10],
        "criterion": ["wilcoxon", "logrank"],
        "l_reg": [0, 0.01, 0.1],
        "min_samples_leaf": [0.05, 0.01, 0.001],
        "categ": [categ]
    },
    "ParallelBootstrapCRAID": {
        "n_estimators": [50],
        "depth": [7],
        "size_sample": [0.3, 0.7],
        "l_reg": [0, 0.01, 0.1],
        "criterion": ["tarone-ware", "wilcoxon"],
        "min_samples_leaf": [0.05, 0.01],
        "ens_metric_name": ["IBS_REMAIN"],
        "max_features": ["sqrt"],
        "categ": [categ]
    }
}

# experim.add_method(CRAID, INTERNAL_SURV_PARAM_GRIDS["CRAID"])
# experim.add_method(ParallelBootstrapCRAID, INTERNAL_SURV_PARAM_GRIDS["ParallelBootstrapCRAID"])

In [None]:
import warnings
warnings.filterwarnings("ignore")

experim.run_effective(X, y, verbose=0, stratify_best=[])
df_results = experim.get_best_by_mode()

<survivors.external.mlwrap.ClassifWrapSA object at 0x168e8b7d0> {'penalty': ['l2'], 'C': [0.01, 0.1, 1, 10], 'solver': ['liblinear', 'lbfgs'], 'class_weight': [None, 'balanced'], 'max_iter': [1000]}
<survivors.external.mlwrap.ClassifWrapSA object at 0x1682fbd90> {'kernel': ['linear', 'rbf'], 'C': [0.1, 1, 10], 'class_weight': [None, 'balanced'], 'probability': [True]}


In [None]:
df_results

Unnamed: 0,METHOD,PARAMS,TIME,CI,IBS,AUPRC,MSE,CRIT,TIMES,MEMS,MEM,CI_mean,IBS_mean,AUPRC_mean,MSE_mean
0,ClassifWrapSA(KNeighborsClassifier),{'n_neighbors': 5},1.234413,"[0.5110389610389611, 0.5311688311688312, 0.572...","[0.3177899785972074, 0.23926117737863478, 0.22...","[0.4071347142857144, 0.4892759285714288, 0.417...","[2875195.4971428574, 3070796.2092857147, 22593...",,"[0.23901057243347168, 0.24139070510864258, 0.2...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.544221,0.261694,0.428563,2672510.0
1,ClassifWrapSA(RandomForestClassifier),{'n_estimators': 10},1.524835,"[0.5834415584415584, 0.6704545454545454, 0.630...","[0.20812438406816028, 0.17427406454265962, 0.1...","[0.4482053214285715, 0.46070507142857137, 0.46...","[2421587.9942857143, 1739820.9342857148, 23849...",,"[0.24786972999572754, 0.3294212818145752, 0.32...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.626364,0.190776,0.455705,2142657.0
2,RegrWrapSA(KNeighborsRegressor),{'n_neighbors': 5},1.414418,"[0.5704545454545454, 0.5581168831168831, 0.608...","[0.42874361542697326, 0.3189030111692633, 0.30...","[0.679910643939394, 0.7332104437229439, 0.7253...","[1424609.1785714286, 1186656.4285714286, 11469...",,"[0.23729300498962402, 0.3060445785522461, 0.28...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.571558,0.358311,0.708013,1299685.0
3,SAWrapSA(KaplanMeierFitter),{},1.394307,"[0.5, 0.5, 0.5, 0.5, 0.5]","[0.20329872878125196, 0.18216530965070882, 0.1...","[0.6078425806496679, 0.6005546928734674, 0.607...","[1976689.964480705, 1788654.2615088334, 219485...",,"[0.24472260475158691, 0.27121639251708984, 0.3...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.5,0.192197,0.605376,1953460.0
4,SAWrapSA(CoxPHSurvivalAnalysis),{'alpha': 100},1.678197,"[0.6746753246753247, 0.7032467532467532, 0.684...","[0.1401302204109691, 0.11912412566025657, 0.11...","[0.6891068122578905, 0.6969656045797799, 0.699...","[1842316.951513259, 1577546.2837628557, 176631...",,"[0.3267405033111572, 0.2817802429199219, 0.326...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.668701,0.134738,0.693439,1781999.0
5,SAWrapSA(CRAID),{'depth': 10},7.10268,"[0.6451298701298701, 0.662987012987013, 0.6266...","[0.2031049643034096, 0.17609808101774893, 0.22...","[0.7365212149141546, 0.7340941337882605, 0.703...","[2244500.0687337946, 2075063.6883098758, 28734...",,"[1.4453179836273193, 1.3720934391021729, 1.352...","[5.76953125, 0.00390625, 0.0, 0.0078125, 0.011...",5.792969,0.642078,0.20999,0.715095,2433688.0
6,CRAID,{'depth': 10},8.274698,"[0.6668831168831169, 0.688961038961039, 0.5922...","[0.21673855633594108, 0.11971061681469919, 0.1...","[0.7335604735367447, 0.797497639607317, 0.7609...","[1027434.0053542962, 745379.9369352116, 127947...",,"[1.6378607749938965, 1.746692180633545, 1.6171...","[0.02734375, 0.0, 0.0, 0.0, 0.0078125]",0.035156,0.639481,0.191457,0.758411,1069884.0
7,ClassifWrapSA(KNeighborsClassifier),{'n_neighbors': 5},1.386411,"[0.5110389610389611, 0.5311688311688312, 0.572...","[0.3177899785972074, 0.23926117737863478, 0.22...","[0.4071347142857144, 0.4892759285714288, 0.417...","[2875195.4971428574, 3070796.2092857147, 22593...",,"[0.23865318298339844, 0.24819111824035645, 0.2...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.544221,0.261694,0.428563,2672510.0
8,ClassifWrapSA(RandomForestClassifier),{'n_estimators': 10},1.494038,"[0.5613636363636364, 0.6467532467532467, 0.630...","[0.23497124837875813, 0.1591442549339761, 0.17...","[0.46963346428571423, 0.48213321428571404, 0.4...","[2936870.331428572, 1890204.7071428576, 232979...",,"[0.29004549980163574, 0.32977914810180664, 0.2...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.608247,0.196236,0.462848,2322670.0
9,RegrWrapSA(KNeighborsRegressor),{'n_neighbors': 5},1.264567,"[0.5704545454545454, 0.5581168831168831, 0.608...","[0.42874361542697326, 0.3189030111692633, 0.30...","[0.679910643939394, 0.7332104437229439, 0.7253...","[1424609.1785714286, 1186656.4285714286, 11469...",,"[0.2384166717529297, 0.2931070327758789, 0.243...","[0.0, 0.0, 0.0, 0.0, 0.0]",0.0,0.571558,0.358311,0.708013,1299685.0
