In [None]:
import numpy as np
import pandas as pd
from survivors.external import ClassifWrapSA, RegrWrapSA, SAWrapSA

In [2]:
from sklearn.metrics import root_mean_squared_error, r2_score

# Метрики регрессии
# - RMSE ожидаемого времени
# - R^2 ожидаемого времени

rmse_exp_time = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: root_mean_squared_error(y_tst["time"], pred_time)
r2_exp_time = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: r2_score(y_tst["time"], pred_time)

In [None]:
from sklearn.metrics import root_mean_squared_error, roc_auc_score, log_loss
# Метрики классификации
# - AUC вероятности события
# - log-loss вероятности события
# - RMSE исхода

def find_sf_at_truetime(pred_sf, event_time, bins):
    idx_pred = np.clip(np.searchsorted(bins, event_time), 0, len(bins) - 1)
    proba = np.take_along_axis(pred_sf, idx_pred[:, np.newaxis], axis=1).squeeze()
    return proba

## example
# true_times = np.array([1, 19, 21, 31])
# bins = np.array([10,20,30])
# sf = np.array([[0.9, 0.8, 0.7], 
#                [0.7, 0.6, 0.5], 
#                [0.5, 0.4, 0.3],
#                [0.05, 0.04, 0.03]])
# print(find_sf_at_truetime(sf, true_times, bins))  # [0.9  0.6  0.3  0.03] 

auc_event = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: roc_auc_score(y_tst["cens"].astype(int), find_sf_at_truetime(pred_sf, y_tst["time"], bins))
log_loss_event = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: log_loss(y_tst["cens"], find_sf_at_truetime(pred_sf, y_tst["time"], bins))
rmse_event = lambda y_tr, y_tst, pred_time, pred_sf, pred_hf, bins: root_mean_squared_error(y_tst["cens"], find_sf_at_truetime(pred_sf, y_tst["time"], bins))

In [4]:
from survivors.experiments import grid as exp
import survivors.datasets as ds

l_metrics = ["CI", "IBS", "AUPRC", "RMSE_TIME", "R2_TIME", "AUC_EVENT", "LOGLOSS_EVENT", "RMSE_EVENT"]
X, y, features, categ, _ = ds.load_pbc_dataset()
experim = exp.Experiments(folds=5, mode="CV+SAMPLE")
experim.add_new_metric("RMSE_TIME", rmse_exp_time)
experim.add_new_metric("R2_TIME", r2_exp_time)
experim.add_new_metric("AUC_EVENT", auc_event)
experim.add_new_metric("LOGLOSS_EVENT", log_loss_event)
experim.add_new_metric("RMSE_EVENT", rmse_event)
experim.set_metrics(l_metrics)

In [5]:
# Гиперпараметры и модели классификации

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

CLASS_PARAM_GRIDS = {
    "logistic_regression": dict(),
    "svc": dict(),
    "knn_classifier": dict(),
    "decision_tree_classifier": dict(),
    "random_forest_classifier": dict(),
    "gradient_boosting_classifier": dict()
}

# CLASS_PARAM_GRIDS = {
#     "logistic_regression": {
#         "penalty": ["l2"],
#         "C": [0.01, 0.1, 1, 10],
#         "solver": ["liblinear", "lbfgs"],
#         "class_weight": [None, "balanced"],
#         "max_iter": [1000],
#     },
#     "svc": {
#         "kernel": ["linear", "rbf"],
#         "C": [0.1, 1, 10],
#         "class_weight": [None, "balanced"],
#         "probability": [True],
#     },
#     "knn_classifier": {
#         "n_neighbors": [5, 10, 20],
#         "weights": ["uniform", "distance"],
#     },
#     "decision_tree_classifier": {
#         "max_depth": [5, 10, 20],
#         "min_samples_split": [2, 10],
#         "min_samples_leaf": [1, 5],
#         "criterion": ["gini", "entropy"],
#     },
#     "random_forest_classifier": {
#         "n_estimators": [100, 300],
#         "max_depth": [10, 30],
#         "min_samples_split": [2, 10],
#         "min_samples_leaf": [1, 5],
#     },
#     "gradient_boosting_classifier": {
#         "n_estimators": [100, 300],
#         "learning_rate": [0.05, 0.1],
#         "max_depth": [2, 3],
#         "subsample": [0.7, 1.0],
#     }
# }

experim.add_method(ClassifWrapSA(LogisticRegression()), CLASS_PARAM_GRIDS['logistic_regression'])
experim.add_method(ClassifWrapSA(SVC()), CLASS_PARAM_GRIDS['svc'])
experim.add_method(ClassifWrapSA(KNeighborsClassifier()), CLASS_PARAM_GRIDS['knn_classifier'])
experim.add_method(ClassifWrapSA(DecisionTreeClassifier()), CLASS_PARAM_GRIDS['decision_tree_classifier'])
experim.add_method(ClassifWrapSA(RandomForestClassifier()), CLASS_PARAM_GRIDS['random_forest_classifier'])
experim.add_method(ClassifWrapSA(GradientBoostingClassifier()), CLASS_PARAM_GRIDS['gradient_boosting_classifier'])

In [6]:
# Гиперпараметры и модели регрессии

from sklearn.linear_model import ElasticNet
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

REGR_PARAM_GRIDS = {
    "elastic_net": dict(),
    "decision_tree_regressor": dict(),
    "random_forest_regressor": dict(),
    "gradient_boosting_regressor": dict(),
    "svr": dict(),
    "knn_regressor": dict()
}

# REGR_PARAM_GRIDS = {
#     "elastic_net": {
#         "alpha": [0.001, 0.01, 0.1],
#         "l1_ratio": [0.2, 0.5, 0.8],
#         "max_iter": [1000, 5000],
#     }, 
#     "decision_tree_regressor": {
#         "max_depth": [5, 10, 20],
#         "min_samples_split": [2, 10],
#         "min_samples_leaf": [1, 5],
#         "criterion": ["squared_error", "friedman_mse"],
#     },
#     "random_forest_regressor": {
#         "n_estimators": [100, 300],
#         "max_depth": [10, 30],
#         "min_samples_split": [2, 10],
#         "min_samples_leaf": [1, 5],
#     },
#     "gradient_boosting_regressor": {
#         "n_estimators": [100, 300],
#         "learning_rate": [0.05, 0.1],
#         "max_depth": [2, 3],
#         "subsample": [0.7, 1.0],
#     },
#     "svr": {
#         "kernel": ["linear", "rbf"],
#         "C": [0.1, 1, 10],
#         "epsilon": [0.1, 0.2],
#     },
#     "knn_regressor": {
#         "n_neighbors": [5, 10, 20],
#         "weights": ["uniform", "distance"],
#     }
# }

experim.add_method(RegrWrapSA(ElasticNet()), REGR_PARAM_GRIDS['elastic_net'])
experim.add_method(RegrWrapSA(DecisionTreeRegressor()), REGR_PARAM_GRIDS['decision_tree_regressor'])
experim.add_method(RegrWrapSA(RandomForestRegressor()), REGR_PARAM_GRIDS['random_forest_regressor'])
experim.add_method(RegrWrapSA(GradientBoostingRegressor()), REGR_PARAM_GRIDS['gradient_boosting_regressor'])
experim.add_method(RegrWrapSA(SVR()), REGR_PARAM_GRIDS['svr'])
experim.add_method(RegrWrapSA(KNeighborsRegressor()), REGR_PARAM_GRIDS['knn_regressor'])

In [7]:
# Гиперпараметры моделей выживаемости (внешние для survivors)

from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.tree import SurvivalTree
from sksurv.ensemble import RandomSurvivalForest
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from lifelines import KaplanMeierFitter

EXTERNAL_SURV_PARAM_GRIDS = {
    "km": dict(),
    "cox_ph": dict(),
    "random_survival_forest": dict(),
    "survival_tree": dict(),
    "gbsa": dict(),
}

# EXTERNAL_SURV_PARAM_GRIDS = {
#     "km": {},
#     "cox_ph": {
#         'alpha': [100, 10, 1, 0.1, 0.01, 0.001],
#         'ties': ["breslow"]
#     },
#     "random_survival_forest": {
#         'n_estimators': [50],
#         'max_depth': [5, 20],
#         'min_samples_leaf': [0.001, 0.01, 0.1, 0.25],
#         "random_state": [123]
#     },
#     "survival_tree": {
#         'max_depth': [None, 20, 30],
#         'min_samples_leaf': [1, 10, 20],
#         'max_features': [None, "sqrt"],
#         "random_state": [123]
#     },
#     "gbsa": {
#         'loss': ["coxph"],
#         'learning_rate': [0.01, 0.05, 0.1, 0.5],
#         'n_estimators': [50],
#         'min_samples_leaf': [1, 10, 50, 100],
#         'max_features': ["sqrt"],
#         "random_state": [123]
#     },
# }

experim.add_method(SAWrapSA(KaplanMeierFitter()), EXTERNAL_SURV_PARAM_GRIDS['km'])
experim.add_method(CoxPHSurvivalAnalysis, EXTERNAL_SURV_PARAM_GRIDS['cox_ph'])
experim.add_method(RandomSurvivalForest, EXTERNAL_SURV_PARAM_GRIDS['random_survival_forest'])
experim.add_method(SurvivalTree, EXTERNAL_SURV_PARAM_GRIDS['survival_tree'])
experim.add_method(GradientBoostingSurvivalAnalysis, EXTERNAL_SURV_PARAM_GRIDS['gbsa'])

In [8]:
# Гиперпараметры моделей выживаемости (внутри survivors)

from survivors.tree import CRAID
from survivors.ensemble import ParallelBootstrapCRAID

INTERNAL_SURV_PARAM_GRIDS = {
    "CRAID": dict(),
    "ParallelBootstrapCRAID": dict()
}

# INTERNAL_SURV_PARAM_GRIDS = {
#     "CRAID": {
#         "depth": [10],
#         "criterion": ["wilcoxon", "logrank"],
#         "l_reg": [0, 0.01, 0.1],
#         "min_samples_leaf": [0.05, 0.01, 0.001],
#         "categ": [categ]
#     },
#     "ParallelBootstrapCRAID": {
#         "n_estimators": [50],
#         "depth": [7],
#         "size_sample": [0.3, 0.7],
#         "l_reg": [0, 0.01, 0.1],
#         "criterion": ["tarone-ware", "wilcoxon"],
#         "min_samples_leaf": [0.05, 0.01],
#         "ens_metric_name": ["IBS_REMAIN"],
#         "max_features": ["sqrt"],
#         "categ": [categ]
#     }
# }

experim.add_method(CRAID, INTERNAL_SURV_PARAM_GRIDS["CRAID"])
experim.add_method(ParallelBootstrapCRAID, INTERNAL_SURV_PARAM_GRIDS["ParallelBootstrapCRAID"])

In [9]:
import warnings
warnings.filterwarnings("ignore")

experim.run_effective(X, y, verbose=0, stratify_best=[])
df_results = experim.get_best_by_mode()

<survivors.external.mlwrap.ClassifWrapSA object at 0x15057db90> {}
<survivors.external.mlwrap.ClassifWrapSA object at 0x147fefd90> {}
<survivors.external.mlwrap.ClassifWrapSA object at 0x151282510> {}
<survivors.external.mlwrap.ClassifWrapSA object at 0x143bdf9d0> {}
<survivors.external.mlwrap.ClassifWrapSA object at 0x151283ed0> {}
<survivors.external.mlwrap.ClassifWrapSA object at 0x147ff4210> {}
<survivors.external.mlwrap.RegrWrapSA object at 0x152025190> {}
<survivors.external.mlwrap.RegrWrapSA object at 0x15057ef90> {}
<survivors.external.mlwrap.RegrWrapSA object at 0x151282a50> {}
<survivors.external.mlwrap.RegrWrapSA object at 0x14712c810> {}
<survivors.external.mlwrap.RegrWrapSA object at 0x146f58c50> {}
<survivors.external.mlwrap.RegrWrapSA object at 0x1469fee90> {}
<survivors.external.mlwrap.SAWrapSA object at 0x106147ed0> {}
<class 'sksurv.linear_model.coxph.CoxPHSurvivalAnalysis'> {}
<class 'sksurv.ensemble.forest.RandomSurvivalForest'> {}
<class 'sksurv.tree.tree.SurvivalT

In [10]:
df_results

Unnamed: 0,METHOD,PARAMS,TIME,CI,IBS,AUPRC,RMSE_TIME,R2_TIME,AUC_EVENT,LOGLOSS_EVENT,...,RMSE_TIME_CV,RMSE_TIME_CV_mean,R2_TIME_CV,R2_TIME_CV_mean,AUC_EVENT_CV,AUC_EVENT_CV_mean,LOGLOSS_EVENT_CV,LOGLOSS_EVENT_CV_mean,RMSE_EVENT_CV,RMSE_EVENT_CV_mean
0,ClassifWrapSA(LogisticRegression),{},0.470072,"[0.6612715540040199, 0.65915582354808, 0.61779...","[0.179061436771067, 0.17601380346840126, 0.219...","[0.4225872894793855, 0.420086837993334, 0.4320...","[1235.4445277274824, 1233.2532537721427, 1428....","[-0.33611605189793026, -0.2338423600730739, -0...","[0.19778024417314094, 0.17869034406215317, 0.2...","[1.336142675632844, 1.4381445886165842, 1.3750...",...,"[1346.4918884534025, 1068.8949110684803, 1450....",1355.44628,"[-0.40549699231760283, -0.046221192267934974, ...",-0.565113,"[0.2598639455782313, 0.13636363636363638, 0.22...",0.204076,"[1.2662724614602978, 1.3764784932840821, 1.374...",1.494434,"[0.6669713158842351, 0.7112868934936267, 0.683...",0.699998
1,ClassifWrapSA(SVC),{},0.477243,"[0.5568602560033852, 0.5559081772982122, 0.540...","[0.2244593948049679, 0.21732513491458116, 0.24...","[0.4223144926045513, 0.4320257766693709, 0.419...","[1476.1428634797037, 1503.3768845724887, 1526....","[-0.9074557154320972, -0.8335420446665127, -1....","[0.3662597114317425, 0.27391786903440624, 0.33...","[0.9242056886408869, 0.9734559485849168, 0.937...",...,"[1634.1284992751732, 1410.2370192653952, 1576....",1511.736295,"[-1.0701167813775498, -0.8211164434307903, -1....",-0.927902,"[0.32925170068027215, 0.23796791443850268, 0.2...",0.292135,"[0.9259000149502187, 1.0057593538360963, 0.977...",0.961612,"[0.5963894347513112, 0.6255916218302449, 0.615...",0.609924
2,ClassifWrapSA(KNeighborsClassifier),{},0.588481,"[0.5869565217391305, 0.5571776155717761, 0.574...","[0.23356018498084488, 0.2195358213278248, 0.25...","[0.4405708985507247, 0.43332466666666664, 0.43...","[1603.8794883168368, 1581.0180791713503, 1582....","[-1.2518587634128857, -1.0278172293928227, -1....","[0.3577136514983352, 0.3281908990011099, 0.405...","[8.70652540749222, 6.703141526430811, 5.457617...",...,"[1695.6401437636634, 1752.368742384352, 1503.1...",1631.815474,"[-1.2288960637169226, -1.8119299005023835, -0....",-1.266331,"[0.4149659863945578, 0.31483957219251335, 0.27...",0.367818,"[5.234806580438786, 10.311550020725312, 4.7261...",6.034343,"[0.6290582530372572, 0.7206148168652337, 0.660...",0.656414
3,ClassifWrapSA(DecisionTreeClassifier),{},0.420431,"[0.5812969427694912, 0.5789696392679573, 0.572...","[0.29705787369623254, 0.31079246015241574, 0.3...","[0.49999, 0.44926637681159404, 0.4347739130434...","[2132.5922495635564, 2110.6362578718954, 2101....","[-2.9811939414564748, -2.613948215884461, -2.8...","[0.26392896781354047, 0.276803551609323, 0.307...","[27.163332988899878, 26.118589412403733, 25.07...",...,"[2327.9275884663716, 1871.103365282482, 2389.9...",2154.591521,"[-3.201084065483239, -2.205893225210997, -3.98...",-2.936194,"[0.30476190476190473, 0.322192513368984, 0.299...",0.306832,"[25.74546670651225, 25.101830038849446, 25.745...",25.874194,"[0.8451542547285166, 0.8345229603962802, 0.845...",0.847202
4,ClassifWrapSA(RandomForestClassifier),{},2.124683,"[0.6486829577911774, 0.6425473394689517, 0.617...","[0.17576135116404193, 0.1687545142641859, 0.18...","[0.4307884855072462, 0.441512908695652, 0.4539...","[1274.4009284559063, 1337.97452032129, 1382.09...","[-0.42170614260579997, -0.45228155796878755, -...","[0.19311875693673697, 0.1547169811320755, 0.14...","[1.6019827239921607, 1.7340897397013881, 1.672...",...,"[1523.7474670130125, 1261.9871542305245, 1437....",1389.527422,"[-0.7999002239248487, -0.45835536104804553, -0...",-0.642004,"[0.16122448979591836, 0.12967914438502676, 0.1...",0.13582,"[1.5475047897427807, 2.2186257519662145, 1.577...",1.68791,"[0.7238007322461065, 0.7491924223550888, 0.728...",0.730306
5,ClassifWrapSA(GradientBoostingClassifier),{},2.319716,"[0.6105469163228605, 0.6411721146725907, 0.612...","[0.23219125438474147, 0.20667665184153478, 0.2...","[0.43189528704314145, 0.4613706711685964, 0.49...","[1637.7710700090954, 1599.8619600205016, 1701....","[-1.3480323286714841, -1.0764437031161305, -1....","[0.21620421753607105, 0.15205327413984462, 0.1...","[1.9560733430041737, 2.1711591838571076, 2.099...",...,"[1810.8705761670406, 1565.03520651318, 1642.13...",1662.455813,"[-1.542127109276413, -1.2428576854856819, -1.3...",-1.354985,"[0.13197278911564625, 0.12433155080213905, 0.1...",0.129908,"[2.2029597576908873, 2.371167908127635, 2.1570...",2.283121,"[0.7868597771982991, 0.8144917265479706, 0.794...",0.804541
6,RegrWrapSA(ElasticNet),{},0.396897,"[0.6674071723262457, 0.6703691949645615, 0.705...","[0.29499853581487623, 0.3093619273392539, 0.29...","[0.7308081034987557, 0.7113018328209633, 0.754...","[950.6549553875468, 991.0049073790076, 906.365...","[0.2088779046808723, 0.20327811564256082, 0.28...","[0.6125416204217536, 0.6148723640399556, 0.565...","[14.365224176822052, 14.104038282698015, 16.45...",...,"[913.9865386629842, 828.6610973647274, 837.951...",885.001444,"[0.35240716287595963, 0.3712074383625019, 0.38...",0.338789,"[0.5809523809523811, 0.6136363636363636, 0.621...",0.583937,"[15.447280023907352, 14.803643356244546, 14.80...",15.576007,"[0.6546536707079771, 0.6408699444616557, 0.640...",0.657082
7,RegrWrapSA(DecisionTreeRegressor),{},0.436924,"[0.6319686871892521, 0.6626467788003808, 0.640...","[0.294418163708312, 0.2571864011408699, 0.3188...","[0.6720331115502854, 0.6930764617186358, 0.666...","[1183.567575815483, 1159.3907047063212, 1149.8...","[-0.22626349408549218, -0.09047261753881863, -...","[0.6219755826859045, 0.5277469478357381, 0.535...","[14.104038282698015, 17.4994549063105, 16.9770...",...,"[1158.3171675816109, 1236.535370789807, 1294.9...",1269.077364,"[-0.04010557422341132, -0.40012426261403267, -...",-0.369444,"[0.519047619047619, 0.5601604278074865, 0.5307...",0.540807,"[17.37819002689577, 16.090916691570158, 17.378...",16.605826,"[0.6943650748294136, 0.6681531047810609, 0.694...",0.678424
8,RegrWrapSA(RandomForestRegressor),{},3.82151,"[0.6966571458796149, 0.7359039458373003, 0.720...","[0.2636575354710466, 0.26988738652747524, 0.26...","[0.76111052774118, 0.7453737615283268, 0.77230...","[897.4632062217031, 852.7946662319266, 871.742...","[0.2949320287411876, 0.41001076546463466, 0.33...","[0.6148723640399556, 0.6243063263041065, 0.586...","[14.104038282698015, 13.842852388573977, 15.67...",...,"[871.8372022016823, 840.737643807083, 858.2851...",880.378671,"[0.4107586289262515, 0.3527463912113461, 0.356...",0.344813,"[0.6047619047619047, 0.6885026737967914, 0.628...",0.623104,"[14.803643356244546, 12.22909668559332, 14.160...",14.160007,"[0.6408699444616557, 0.5824823725107175, 0.626...",0.626314
9,RegrWrapSA(GradientBoostingRegressor),{},1.878663,"[0.6631228181529674, 0.7061779329313446, 0.716...","[0.2813152632597307, 0.27626553549985494, 0.27...","[0.743434113599766, 0.7301493551456595, 0.7572...","[976.7153509654664, 922.8210705553214, 915.050...","[0.1649091820308457, 0.309139853148341, 0.2718...","[0.6066592674805772, 0.6054384017758047, 0.553...","[14.62641007094609, 14.365224176822053, 16.977...",...,"[875.9694017813309, 820.8842113190246, 875.417...",886.340655,"[0.4051598012237153, 0.3829543439508427, 0.330...",0.335706,"[0.6285714285714287, 0.5975935828877006, 0.620...",0.614963,"[14.16000668858174, 14.803643356244546, 14.160...",14.417461,"[0.6267831705280087, 0.6408699444616557, 0.626...",0.632041


In [15]:
pd.Series(y["time"]).describe()

count     418.000000
mean     1917.782297
std      1104.672992
min        41.000000
25%      1092.750000
50%      1730.000000
75%      2613.500000
max      4795.000000
dtype: float64