In [1]:
import pandas as pd
import numpy as np
import os
import time
import copy
import pathlib, tempfile

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from joblib import Parallel, delayed
from scipy import stats

from survivors import metrics as metr
from survivors import constants as cnt
from survivors import criteria as crit
from numba import njit, jit, int32, float64
from lifelines import KaplanMeierFitter, NelsonAalenFitter
from lifelines.utils import concordance_index

from survivors.ensemble import BootstrapCRAID
import survivors.datasets as ds

import cProfile
import pstats

%load_ext line_profiler
%load_ext scalene

Scalene extension successfully loaded. Note: Scalene currently only
supports CPU+GPU profiling inside Jupyter notebooks. For full Scalene
profiling, use the command line version.


In [186]:
import numpy as np
from numba import njit

from scipy import stats
from survivors.tree.stratified_model import KaplanMeier, FullProbKM, NelsonAalen, KaplanMeierZeroAfter
from survivors.metrics import ibs_WW, auprc
from survivors.constants import get_y

""" Auxiliary functions """


@njit('f4(f4[:], f4[:], f4[:], f4[:], u4, f4[:])', cache=True)
def lr_hist_statistic(time_hist_1, time_hist_2, cens_hist_1, cens_hist_2,
                      weightings, obs_weights):
    """
    Weighted log-rank style statistic for two groups given as per-bin histograms.

    Parameters
    ----------
    time_hist_1, time_hist_2 : float32 arrays
        Per-bin observation counts of each group.
    cens_hist_1, cens_hist_2 : float32 arrays
        Per-bin counts used as the "observed" term of each group.
    weightings : uint32
        Weight scheme code: 2 -> N_j, 3 -> sqrt(N_j), 4 and 7 -> cumprod(1 - O_j / (N_j + 1)),
        5 -> obs_weights, 6 -> O_j / N_j, 8 -> N_j / (N_1_j * N_2_j); any other code -> 1.
        Code 7 additionally evaluates the complementary weights (1 - w) and keeps the larger statistic.
    obs_weights : float32 array
        Per-bin weights, read only when ``weightings == 5``.

    Returns
    -------
    float32
        ``sum(w * (O_1 - E_1)) ** 2 / sum(w ** 2 * V)``, or 0 when no bin is usable
        or the denominator is zero.
    """
    # Reverse cumulative sums: for bin j, the count of observations in bin j or any later bin.
    N_1_j = np.cumsum(time_hist_1[::-1])[::-1]
    N_2_j = np.cumsum(time_hist_2[::-1])[::-1]
    # Keep bins with a non-zero combined observed count where both groups still have observations.
    ind = np.where((cens_hist_1 + cens_hist_2 != 0) & (N_1_j * N_2_j != 0))[0]
    # Alternative filter kept for reference (requires only that one group is non-empty):
#     ind = np.where((cens_hist_1 + cens_hist_2 != 0) & (N_1_j + N_2_j > 0))[0]
    if ind.shape[0] == 0:
        return 0.0

    N_1_j = N_1_j[ind]
    N_2_j = N_2_j[ind]
    O_1_j = cens_hist_1[ind]
    O_2_j = cens_hist_2[ind]

    N_j = N_1_j + N_2_j
    O_j = O_1_j + O_2_j
    # Expected observed count of group 1 in proportion to its share of N_j.
    E_1_j = N_1_j * O_j / N_j
    
    # Columns of res: 0 = weight, 1 = observed minus expected, 2 = variance term.
    res = np.zeros((N_j.shape[0], 3), dtype=np.float32)
    res[:, 1] = O_1_j - E_1_j
    # NOTE(review): the denominator is N_j * N_j, not N_j * (N_j - 1); the trailing
    # "# N_j" marker suggests this is intentional - confirm.
    res[:, 2] = E_1_j * (N_j - O_j) * N_2_j / (N_j * (N_j))  # N_j
    
    res[:, 0] = 1.0
    if weightings == 2:
        res[:, 0] = N_j
    elif weightings == 3:
        res[:, 0] = np.sqrt(N_j)
    elif weightings == 4:
        res[:, 0] = np.cumprod((1.0 - O_j / (N_j + 1)))
    elif weightings == 5:
        res[:, 0] = obs_weights[ind]
    elif weightings == 6:
        res[:, 0] = O_j/N_j
    elif weightings == 7:
        # Same weights as code 4; the complementary pass happens after the first statistic.
        res[:, 0] = np.cumprod((1.0 - O_j / (N_j + 1)))
    elif weightings == 8:
        res[:, 0] = N_j/(N_1_j*N_2_j)
    # Integral (trapezoid) variant kept for reference; it needs a `bins` array not passed here:
#     var = np.trapz((res[:, 0] * res[:, 0] * res[:, 2]), bins[ind])
#     num = np.trapz((res[:, 0] * res[:, 1]), bins[ind])
    var = (res[:, 0] * res[:, 0] * res[:, 2]).sum()
    num = (res[:, 0] * res[:, 1]).sum()
    
    if var == 0:
        return 0
    stat_val = np.power(num, 2) / var

    if weightings == 7:
        # Second pass with complementary weights; the result is the larger of the two statistics.
        # NOTE(review): unlike the first pass, this division has no zero-denominator guard.
        res[:, 0] = 1 - res[:, 0]
        stat_val2 = np.power((res[:, 0] * res[:, 1]).sum(), 2) / ((res[:, 0] * res[:, 0] * res[:, 2]).sum())
        stat_val = max(stat_val, stat_val2)
    return stat_val

def weight_hist_stat(time_hist_1, time_hist_2, cens_hist_1=None, cens_hist_2=None, 
                     weights_hist=None, weightings=""):
    """
    Python-side wrapper around ``lr_hist_statistic``.

    Fills in defaults (the observed histograms fall back to the time histograms,
    the bin weights fall back to ones), translates the weighting name to its
    numeric code (unknown names map to 1, i.e. "logrank") and casts every
    histogram to float32 before calling the compiled routine.

    Any exception raised along the way is printed and 0.0 is returned instead.
    """
    codes = {"logrank": 1, "wilcoxon": 2, "tarone-ware": 3, "peto": 4, "weights": 5,
             "diff": 6, "maxcombo": 7, "frac": 8}
    try:
        events_1 = time_hist_1 if cens_hist_1 is None else cens_hist_1
        events_2 = time_hist_2 if cens_hist_2 is None else cens_hist_2
        bin_weights = np.ones_like(time_hist_1) if weights_hist is None else weights_hist
        code = codes.get(weightings, 1)
        return lr_hist_statistic(time_hist_1.astype("float32"),
                                 time_hist_2.astype("float32"),
                                 events_1.astype("float32"),
                                 events_2.astype("float32"),
                                 np.uint32(code),
                                 bin_weights.astype("float32"))
    except Exception as err:
        # Best-effort: report the problem and treat the split as carrying no signal.
        print(err)
        return 0.0

    
def optimal_criter_split_hist(left_time_hist, left_cens_hist,
                              right_time_hist, right_cens_hist,
                              na_time_hist, na_cens_hist, weights_hist, criterion, dis_coef, 
                              apr_t_distr, apr_e_distr, l_reg):
    """
    Score a candidate split and decide which branch receives the missing-value rows.

    Both branches are regularised by adding ``l_reg`` times the a-priori
    histograms, shared between the branches in proportion to their sizes.
    When missing-value histograms are present, the statistic is evaluated with
    the missing rows on the left and on the right; the larger value wins.

    Returns
    -------
    (stat_val, none_to)
        ``none_to`` is 1 when the missing rows go to the right branch, 0 otherwise
        (also 0 when there are no missing rows).

    ``dis_coef`` is accepted but not used.
    """
    n_left = np.sum(left_time_hist)
    n_right = np.sum(right_time_hist)
    left_share = n_left/(n_left + n_right)
    right_share = 1 - left_share

    # Regularisation terms, one per (histogram kind, branch).
    reg_time_left = l_reg*apr_t_distr*left_share
    reg_time_right = l_reg*apr_t_distr*right_share
    reg_cens_left = l_reg*apr_e_distr*left_share
    reg_cens_right = l_reg*apr_e_distr*right_share

    if na_time_hist.shape[0] == 0:
        stat_val = weight_hist_stat(left_time_hist + reg_time_left,
                                    right_time_hist + reg_time_right,
                                    left_cens_hist + reg_cens_left,
                                    right_cens_hist + reg_cens_right,
                                    weights_hist, weightings=criterion)
        return (stat_val, 0)

    stat_na_left = weight_hist_stat(left_time_hist + na_time_hist + reg_time_left,
                                    right_time_hist + reg_time_right,
                                    left_cens_hist + na_cens_hist + reg_cens_left,
                                    right_cens_hist + reg_cens_right,
                                    weights_hist, weightings=criterion)
    stat_na_right = weight_hist_stat(left_time_hist + reg_time_left,
                                     right_time_hist + na_time_hist + reg_time_right,
                                     left_cens_hist + reg_cens_left,
                                     right_cens_hist + na_cens_hist + reg_cens_right,
                                     weights_hist, weightings=criterion)
    # Missing values follow the branch assignment with the larger statistic.
    send_na_right = int(stat_na_left < stat_na_right)
    return (max(stat_na_left, stat_na_right), send_na_right)


def get_attrs(max_stat_val, values, none_to, l_sh, r_sh, nan_sh):
    """
    Pack the description of one candidate split into a dict.

    ``pos_nan`` marks the branch that receives missing values ([left, right]);
    ``min_split`` is the size of the smaller branch once the ``nan_sh`` missing
    rows have been added to the branch that receives them.
    """
    if none_to:
        # Missing values go to the right branch.
        pos_nan = [0, 1]
        min_split = min(l_sh, r_sh + nan_sh)
    else:
        # Missing values go to the left branch.
        pos_nan = [1, 0]
        min_split = min(l_sh + nan_sh, r_sh)
    return {
        "stat_val": max_stat_val,
        "values": values,
        "pos_nan": pos_nan,
        "min_split": min_split,
    }


def transform_woe_np(x_feat, y):
    """
    Replace every value of a feature by a weight-of-evidence style score.

    For each distinct value of ``x_feat`` the score is
    ``log(P(value | y=1) / P(value | y=0)) - log(P(other | y=1) / P(other | y=0))``,
    with 1e-5 added to every numerator and denominator so that empty cells do
    not produce infinities.

    Parameters
    ----------
    x_feat : 1-D array of feature values (no NaN expected).
    y : 1-D array of 0/1 labels aligned with ``x_feat``.

    Returns
    -------
    (woe_x_feat, descr_np)
        ``woe_x_feat`` has the shape of ``x_feat`` and holds the score of each row;
        ``descr_np`` is a 2 x k array: row 0 the sorted distinct values, row 1 their scores.
    """
    N_T = y.shape[0]
    N_D = y.sum()
    N_D_ = N_T - N_D
    x_uniq = np.unique(x_feat)
    # Index of each row's value inside x_uniq (digitize is 1-based for exact matches).
    x_dig = np.digitize(x_feat, x_uniq) - 1

    # Row 0: per-value counts for y == 0; row 1: per-value counts for y == 1.
    df_woe_iv = np.vstack([np.bincount(x_dig[y == 0], minlength=x_uniq.shape[0]),
                           np.bincount(x_dig[y == 1], minlength=x_uniq.shape[0])])
    all_0 = df_woe_iv[0].sum()
    all_1 = df_woe_iv[1].sum()

    p_bd = (df_woe_iv[1] + 1e-5) / (N_D + 1e-5)
    p_bd_ = (df_woe_iv[0] + 1e-5) / (N_D_ + 1e-5)
    p_b_d = (all_1 - df_woe_iv[1] + 1e-5) / (N_D + 1e-5)
    p_b_d_ = (all_0 - df_woe_iv[0] + 1e-5) / (N_D_ + 1e-5)

    woe_pl = np.log(p_bd / p_bd_)
    woe_mn = np.log(p_b_d / p_b_d_)
    descr_np = np.vstack([x_uniq, woe_pl - woe_mn])
    # x_dig already addresses the matching column of descr_np, so index directly
    # instead of a per-element dict lookup through np.vectorize (which is a
    # Python-level loop and raises on empty input).
    woe_x_feat = descr_np[1][x_dig]
    return (woe_x_feat, descr_np)


def get_sa_hists(time, cens, minlength=1, weights=None):
    """
    Build per-time-bin histograms of a sample.

    Returns ``(time_hist, cens_hist)``: the number of rows per integer time bin
    and the per-bin sum of ``cens``. Both are padded to ``minlength`` bins.
    An empty sample yields two empty arrays (length 0, regardless of ``minlength``).

    NOTE(review): ``weights`` is accepted but never used - the histograms are
    unweighted. Confirm whether callers passing ``weights`` expect otherwise.
    """
    if time.shape[0] == 0:
        return np.array([]), np.array([])
    rows_per_bin = np.bincount(time, minlength=minlength)
    cens_per_bin = np.bincount(time, weights=cens, minlength=minlength)
    return rows_per_bin, cens_per_bin


def select_best_split_info(attr_dicts, type_attr, bonf=True, descr_woe=None):
    """
    Pick the candidate with the largest statistic and finalise its description.

    The winning dict is updated in place and returned. It gains:
    ``p_value`` (chi-square survival function with 1 degree of freedom, multiplied
    by the number of candidates when ``bonf`` is true), ``sign_split`` (number of
    candidates), ``src_val`` (the raw threshold) and ``values`` rewritten as the
    two textual branch conditions.

    For "woe"/"categ" features ``descr_woe`` (row 0: categories, row 1: scores)
    is used to list the categories on each side of the threshold.
    """
    n_candidates = len(attr_dicts)
    winner = max(attr_dicts, key=lambda cand: cand["stat_val"])

    winner["p_value"] = stats.chi2.sf(winner["stat_val"], df=1)
    winner["sign_split"] = n_candidates
    if n_candidates <= 0:
        return winner

    threshold = winner["values"]
    winner["src_val"] = threshold
    if type_attr == "cont":
        winner["values"] = [f" <= {threshold}", f" > {threshold}"]
    elif type_attr in ("woe", "categ"):
        goes_left = descr_woe[1] <= threshold
        left_cats = list(descr_woe[0, goes_left])
        right_cats = list(descr_woe[0, ~goes_left])
        winner["values"] = [f" in {e}" for e in [left_cats, right_cats]]
    if bonf:
        # Bonferroni correction over all candidates that were compared.
        winner["p_value"] *= n_candidates
    return winner


def split_time_to_bins(time, apr_times):
    """
    Map each duration to the index of its bin.

    Bin edges are the sorted distinct values of ``apr_times`` or, when that is
    None, of ``time`` itself; the index is the left insertion point of the
    duration among those edges.
    """
    reference = time if apr_times is None else apr_times
    edges = np.unique(reference)
    return np.searchsorted(edges, time)


def hist_best_attr_split(arr, criterion="logrank", type_attr="cont", weights=None, thres_cont_bin_max=100,
                         signif=1.0, signif_stat=0.0, min_samples_leaf=10, bonf=True, verbose=0, balance=False, 
                         apr_time=None, apr_event=None, l_reg=0, **kwargs):
    """
    Search one feature for the binary split with the largest histogram statistic.

    Parameters
    ----------
    arr : 2-D array
        Row 0 holds the feature values, row 1 the event indicator, row 2 the
        durations; one column per observation.
    criterion : str
        Weighting name forwarded to ``weight_hist_stat``. The names "confident",
        "confident_weights", "fullprob", "ibswei", "T-ET" and "kde" first build a
        per-time-bin weight vector and are then evaluated as "weights".
    type_attr : str
        "cont" for numeric features; "woe" or "categ" to encode the feature with
        ``transform_woe_np`` before searching thresholds.
    weights : array, optional
        Per-observation weights (ones by default). Read by the "confident_weights"
        and "weights" criteria and passed to ``get_sa_hists``.
    thres_cont_bin_max : int
        Above this many distinct values, candidate thresholds are quantiles
        instead of midpoints between neighbouring distinct values.
    signif, signif_stat : float
        ``p_value`` and ``stat_val`` of the default "no split" result; a candidate
        is kept only if its statistic exceeds ``signif_stat``.
    min_samples_leaf : int
        A candidate is skipped unless both branches hold more than this many rows
        (missing-value rows are not counted here).
    bonf : bool
        Forwarded to ``select_best_split_info`` (Bonferroni correction).
    verbose : int
        When positive, print the final p-value and the number of thresholds.
    balance : bool
        Only affects ``dis_coef``, which is forwarded to ``optimal_criter_split_hist``.
    apr_time, apr_event : arrays, optional
        Reference sample that defines the time bins and the histograms used for
        regularisation; without it the node's own durations define the bins and
        the regularisation histograms are zero.
    l_reg : float
        Regularisation strength forwarded to ``optimal_criter_split_hist``.
    **kwargs
        Ignored.

    Returns
    -------
    dict
        Always has "stat_val", "p_value", "sign_split", "values" and "pos_nan".
        ``sign_split == 0`` means no acceptable split was found; otherwise the dict
        comes from ``select_best_split_info`` and also carries "min_split" and "src_val".
    """
    # Default answer: "no split".
    best_attr = {"stat_val": signif_stat, "p_value": signif,
                 "sign_split": 0, "values": [], "pos_nan": [1, 0]}
    # Too few observations to form two leaves.
    if arr.shape[1] < 2 * min_samples_leaf:
        return best_attr
    vals = arr[0].astype("float")
    cens = arr[1].astype("uint")
    dur = arr[2].astype("float")
    
    # No events at all: nothing to compare.
    if np.sum(cens) == 0:
        return best_attr
    if weights is None:
        weights = np.ones_like(dur)
        
    weights_hist = None
    # From here on `dur` holds integer bin indices, not raw durations.
    dur = split_time_to_bins(dur, apr_time)
    
    if apr_time is None:
        time_bins = np.unique(dur)
        max_bin = dur.max()
        # No reference sample: the regularisation histograms are all zero.
        apr_t_distr = np.zeros(max_bin + 1)
        apr_e_distr = np.zeros(max_bin + 1)
    else:
        time_bins = np.unique(apr_time)
        apr_time_1 = split_time_to_bins(apr_time, apr_time)
        max_bin = apr_time_1.max()
        apr_t_distr, apr_e_distr = get_sa_hists(apr_time_1, apr_event, minlength=max_bin + 1)
    
    # Mask of rows whose feature value is missing.
    ind = np.isnan(vals)

    # split nan and not-nan
    dur_notna = dur[~ind]
    cens_notna = cens[~ind]
    vals_notna = vals[~ind]
    weights_notna = weights[~ind]

    dis_coef = 1
    if balance:
        # Ratio of rows without an event to rows with an event.
        dis_coef = (cens.shape[0] - np.sum(cens)) / np.sum(cens)

    if dur_notna.shape[0] < min_samples_leaf:
        return best_attr

    descr_woe = None
    if type_attr == "woe" or type_attr == "categ":
        # Categories are replaced by scores so they can be thresholded like numbers.
        vals_notna, descr_woe = transform_woe_np(vals_notna, cens_notna)

    # find splitting values
    uniq_set = np.unique(vals_notna)
    if uniq_set.shape[0] > thres_cont_bin_max:
        uniq_set = np.quantile(vals_notna, [i / float(thres_cont_bin_max) for i in range(1, thres_cont_bin_max)])
    else:
        # Midpoints between neighbouring distinct values.
        uniq_set = (uniq_set[:-1] + uniq_set[1:]) * 0.5
    uniq_set = np.unique(np.round(uniq_set, 3))

    # Bin u contains values <= uniq_set[u] (and above the previous threshold).
    index_vals_bin = np.digitize(vals_notna, uniq_set, right=True)

    # find global hist by times
    na_time_hist, na_cens_hist = get_sa_hists(dur[ind], cens[ind],
                                              minlength=max_bin + 1, weights=weights[ind])

    # All non-missing rows start on the right; the loop below moves them left bin by bin.
    r_time_hist, r_cens_hist = get_sa_hists(dur_notna, cens_notna,
                                            minlength=max_bin + 1, weights=weights_notna)
    l_time_hist = np.zeros_like(r_time_hist, dtype=np.float32)
    l_cens_hist = l_time_hist.copy()
    
    num_nan = ind.sum()
    num_r = dur_notna.shape[0]
    num_l = 0

    # Criteria below derive a per-time-bin weight vector and then fall through to "weights".
    if criterion == "confident" or criterion == "confident_weights":
        kmf = KaplanMeier()
        if criterion == "confident_weights":
            kmf.fit(dur, cens, weights=weights)
        else:
            kmf.fit(dur, cens)
        ci = kmf.get_confidence_interval_()
        # Weight shrinks as the confidence interval widens (the first row of ci is skipped).
        weights_hist = 1 / (ci[1:, 1] - ci[1:, 0] + 1)  # (ci[1:, 1] + ci[1:, 0] + 1e-5)
        criterion = "weights"
    elif criterion == "fullprob":
        kmf = FullProbKM()
        kmf.fit(dur, cens)
        weights_hist = kmf.survival_function_at_times(np.unique(dur))
        criterion = "weights"
    elif criterion == "ibswei":
        kmf = KaplanMeierZeroAfter()
        # This branch works on the raw durations from arr[2], not on the bin indices in `dur`.
        dur_ = arr[2].copy()
        kmf.fit(dur_, cens)

        dd = np.unique(dur_)
        sf = kmf.survival_function_at_times(dd)
        # The same survival curve is repeated once per distinct time.
        sf = np.repeat(sf[np.newaxis, :], dd.shape[0], axis=0)

        y_ = get_y(cens=np.ones_like(dd), time=dd)
        # Score the curve twice: with every pseudo-row flagged True, then False.
        y_["cens"] = True
        ibs_ev = ibs_WW(y_, y_, sf, dd, axis=0)
        y_["cens"] = False
        ibs_cn = ibs_WW(y_, y_, sf, dd, axis=0)

        # Blend both scores by the observed event share.
        ratio = np.sum(cens)/cens.shape[0]
        weights_hist = ibs_ev*ratio + ibs_cn*(1-ratio)
        criterion = "weights"
    elif criterion == "T-ET":
        kmf = KaplanMeierZeroAfter()
        dur_ = arr[2].copy()
        kmf.fit(dur_, cens)

        dd = np.unique(dur_)
        # Area under the survival curve over the distinct times.
        ET = np.trapz(kmf.survival_function_at_times(dd), dd)
        weights_hist = (dd - ET)  # **2
        criterion = "weights"
    elif criterion == "kde":
        na = NelsonAalen()
        na.fit(dur, cens, np.ones(len(dur)))
        weights_hist = na.get_smoothed_hazard_at_times(np.unique(dur))
        criterion = "weights"
    elif criterion == "weights":
        weights_hist = np.bincount(dur, weights=weights,
                                   minlength=max_bin + 1)
        # Reverse cumulative sum: total weight in each bin and all later bins.
        weights_hist = np.cumsum(weights_hist[::-1])[::-1]  # np.sqrt()

        weights_hist = weights_hist / weights_hist.sum()

    # for each split values get branches
    attr_dicts = []
    
    for u in np.unique(index_vals_bin):
        curr_mask = index_vals_bin == u
        curr_n = curr_mask.sum()
        curr_time_hist, curr_cens_hist = get_sa_hists(dur_notna[curr_mask], cens_notna[curr_mask],
                                                      minlength=max_bin + 1, weights=weights_notna[curr_mask])
        # Move the rows of bin u from the right branch to the left branch.
        l_time_hist += curr_time_hist
        l_cens_hist += curr_cens_hist
        r_time_hist -= curr_time_hist
        r_cens_hist -= curr_cens_hist
        num_l += curr_n
        num_r -= curr_n

        # Both branches must hold strictly more than min_samples_leaf rows.
        if min(num_l, num_r) <= min_samples_leaf:
            continue
            
        # Debug plots of the current histograms (disabled):
#         plt.plot(l_time_hist)
#         plt.plot(r_time_hist)
#         plt.plot(apr_t_distr)
#         plt.show()
        
        max_stat_val, none_to = optimal_criter_split_hist(
            l_time_hist, l_cens_hist, r_time_hist, r_cens_hist,
            na_time_hist, na_cens_hist, weights_hist, criterion, dis_coef, 
            apr_t_distr, apr_e_distr, l_reg)
        
        if max_stat_val > signif_stat:
            # uniq_set[u] is the threshold separating bin u from the bins above it.
            attr_loc = get_attrs(max_stat_val, uniq_set[u], none_to, num_l, num_r, num_nan)
            attr_dicts.append(attr_loc)
            
    if len(attr_dicts) == 0:
        return best_attr
    best_attr = select_best_split_info(attr_dicts, type_attr, bonf, descr_woe=descr_woe)
    
    if verbose > 0:
        print(best_attr["p_value"], len(uniq_set))
    return best_attr

In [187]:
from survivors.tree.node import Node, Rule
from survivors.tree import CRAID

class Node1(Node):
    """
    Tree node whose split search uses the notebook-local ``hist_best_attr_split``.

    Child nodes are created as ``Node1`` as well, so the whole tree uses the
    same split search.
    """

    def find_best_split(self):
        """
        Evaluate a random subset of features and return the best one.

        Returns
        -------
        (attr, split_info)
            Name of the feature with the largest ``stat_val`` and the dict
            returned for it by ``hist_best_attr_split``.
        """
        numb_feats = self.info["max_features"]
        numb_feats = np.clip(numb_feats, 1, len(self.features))

        selected_feats = list(np.random.choice(self.features, size=numb_feats, replace=False))
        args = self.get_comb_fast(selected_feats)

        # A plain comprehension instead of np.vectorize: without `otypes`,
        # np.vectorize calls the function an extra time on the first element to
        # infer the output type, so the first feature was searched twice.
        ml = [hist_best_attr_split(**feat_kwargs) for feat_kwargs in args]
        attrs = {f: ml[ind] for ind, f in enumerate(selected_feats)}
        attr = max(attrs, key=lambda x: attrs[x]["stat_val"])

        return (attr, attrs[attr])

    def split(self):
        """
        Split this node on its best feature.

        Returns
        -------
        numpy array of child ``Node1`` objects; empty when the best split found is
        not significant (``sign_split == 0``). ``self.rule_edges`` receives the
        rule leading to each child, in the same order.

        Raises
        ------
        ValueError
            If the chosen split sends every row into a single branch.
        """
        node_edges = np.array([], dtype=int)
        self.rule_edges = np.array([], dtype=Rule)

        attr, best_split = self.find_best_split()

        # The best split is not significant
        if best_split["sign_split"] == 0:
            if self.verbose > 0:
                print(f'Конец ветви, незначащее p-value: {best_split["p_value"]}')
            return node_edges
        if self.verbose > 0:
            print('='*6, best_split["p_value"], attr)

        # Branch number (index into "values"/"pos_nan") for every row of this node.
        branch_ind = self.ind_for_nodes(self.df[attr], best_split, attr in self.categ)

        for n_b in np.unique(branch_ind):
            rule = Rule(feature=attr,
                        condition=best_split["values"][n_b],
                        has_nan=best_split["pos_nan"][n_b])
            d_node = self.df[branch_ind == n_b].copy()
            child = Node1(df=d_node, full_rule=self.full_rule + [rule],
                          features=self.features, categ=self.categ,
                          depth=self.depth + 1, verbose=self.verbose, **self.info)
            node_edges = np.append(node_edges, child)
            self.rule_edges = np.append(self.rule_edges, rule)

        if self.rule_edges.shape[0] == 1:
            # Dump the inputs of the degenerate split before failing.
            print(branch_ind, self.df[attr], best_split, attr in self.categ)
            raise ValueError('ERROR: Only one branch created!')

        return node_edges

class CRAID1(CRAID):
    # Tree that is grown from Node1 nodes, so splits are searched with the
    # notebook-local hist_best_attr_split.
    def fit(self, X, y):
        """
        Grow the tree on features ``X`` and structured target ``y``.

        ``y`` is indexed by ``cnt.CENS_NAME`` and ``cnt.TIME_NAME``. Nodes are
        stored in ``self.nodes`` under consecutive integer numbers, the root being 0.
        Returns None.
        """
        if len(self.features) == 0:
            self.features = X.columns
        self.bins = cnt.get_bins(time=y[cnt.TIME_NAME])  # , cens = y[cnt.CENS_NAME])
        X = X.reset_index(drop=True)
        X_tr = X.copy()
        X_tr[cnt.CENS_NAME] = y[cnt.CENS_NAME].astype(np.int32)
        X_tr[cnt.TIME_NAME] = y[cnt.TIME_NAME].astype(np.float32)

        # A float min_samples_leaf is a fraction of the training rows (at least 1 row).
        if not ("min_samples_leaf" in self.info):
            self.info["min_samples_leaf"] = 0.01
        if isinstance(self.info["min_samples_leaf"], float):
            self.info["min_samples_leaf"] = max(int(self.info["min_samples_leaf"] * X_tr.shape[0]), 1)

        cnt.set_seed(self.random_state)

        if self.balance in ["balance", "balance+correct"]:
            freq = X_tr[cnt.CENS_NAME].value_counts()
            self.correct_proba = freq[1] / (freq[1] + freq[0])  # or freq[1] / (freq[0])

            # NOTE(review): get_oversample is not imported in the visible cells;
            # this branch needs it in the notebook namespace - confirm.
            X_tr = get_oversample(X_tr, target=cnt.CENS_NAME)
        elif self.balance in ["balance+weights"]:
            freq = X_tr[cnt.CENS_NAME].value_counts()

            # Rows with a non-zero censoring flag get weight freq[0] / freq[1], all others 1.
            X_tr["weights_obs"] = np.where(X_tr[cnt.CENS_NAME], freq[0] / freq[1], 1)
            self.info["weights_feature"] = "weights_obs"
        elif self.balance in ["only_log_rank"]:
            self.info["balance"] = True

        if self.cut:
            # Hold out 20% of the rows for cutting the tree after it is grown.
            X_val = X_tr.sample(n=int(0.2 * X_tr.shape[0]), random_state=self.random_state)
            X_tr = X_tr.loc[X_tr.index.difference(X_val.index), :]

        self.nodes[0] = Node1(X_tr, features=self.features, categ=self.categ, **self.info)
        # Despite the name this is a FIFO queue: nodes are taken from the front and
        # children appended at the back, so the tree grows level by level.
        stack_nodes = np.array([0], dtype=int)
        while stack_nodes.shape[0] > 0:
            node = self.nodes[stack_nodes[0]]
            stack_nodes = stack_nodes[1:]
            if node.depth >= self.depth:
                continue
            sub_nodes = node.split()
            if sub_nodes.shape[0] > 0:
                # Children get the next free node numbers.
                sub_numbers = np.array([len(self.nodes) + i for i in range(sub_nodes.shape[0])])
                for i in range(sub_nodes.shape[0]):
                    sub_nodes[i].numb = sub_numbers[i]
                self.nodes.update(dict(zip(sub_numbers, sub_nodes)))
                node.set_edges(sub_numbers)
                stack_nodes = np.append(stack_nodes, sub_numbers)

        if self.cut:
            # NOTE(review): roc_auc_score is not imported in the visible cells;
            # this branch needs it in the notebook namespace - confirm.
            self.cut_tree(X_val, cnt.CENS_NAME, mode_f=roc_auc_score, choose_f=max)

        return

In [188]:
class BootstrapCRAID1(BootstrapCRAID):
    """Bootstrap ensemble whose members are ``CRAID1`` trees."""

    def fit(self, X, y):
        """
        Fit ``n_estimators`` trees, each on its own sample of the training rows.

        Estimator ``i`` uses ``i`` as the random state both for drawing its sample
        and for the tree itself; the rows left out of the sample are passed to
        ``add_model`` together with the fitted tree.
        """
        self.features = X.columns
        train = X.reset_index(drop=True)
        train[cnt.CENS_NAME] = y[cnt.CENS_NAME].astype(np.int32)
        train[cnt.TIME_NAME] = y[cnt.TIME_NAME].astype(np.float32)

        self.X_train = train
        self.y_train = y
        self.update_params()

        for seed in range(self.n_estimators):
            drawn = self.X_train.sample(n=self.size_sample, replace=self.bootstrap, random_state=seed)
            # Rows that were never drawn for this estimator.
            left_out = self.X_train.loc[self.X_train.index.difference(drawn.index), :]

            X_sub_tr, y_sub_cr = cnt.pd_to_xy(drawn.reset_index(drop=True))

            tree = CRAID1(features=self.features, random_state=seed, **self.tree_kwargs)
            tree.fit(X_sub_tr, y_sub_cr)

            self.add_model(tree, left_out)
        print(f"fitted: {len(self.models)} models.")

In [189]:
from survivors.ensemble import BoostingCRAID

class IBSCleverBoostingCRAID1(BoostingCRAID):
    """
    Boosting-style ensemble of ``CRAID1`` trees.

    Every tree receives the durations and event flags of its own training
    sample as ``apr_time`` / ``apr_event``. Observation weights are initialised
    to ones and ``update_weight`` is a no-op, so they never change.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.name = "IBSCleverBoostingCRAID"

    def fit(self, X, y):
        """Fit ``n_estimators`` trees, each on a sample drawn with random state ``i``."""
        self.features = X.columns
        train = X.reset_index(drop=True)
        train[cnt.CENS_NAME] = y[cnt.CENS_NAME].astype(np.int32)
        train[cnt.TIME_NAME] = y[cnt.TIME_NAME].astype(np.float32)

        self.X_train = train
        # Remember the original row position so weights can be looked up after sampling.
        self.X_train["ind_start"] = self.X_train.index
        self.y_train = y

        self.weights = np.ones(self.X_train.shape[0], dtype=float)
        self.bettas = []
        self.l_ibs = []
        self.l_weights = []
        self.update_params()

        for i in range(self.n_estimators):
            # Sampling is uniform; weighting by self.weights is switched off.
            x_sub = self.X_train.sample(n=self.size_sample,
                                        replace=self.bootstrap, random_state=i)

            rows_left_out = self.X_train.loc[self.X_train.index.difference(x_sub.index), :]
            print(f"UNIQUE ({i}):{np.unique(x_sub.index).shape[0]}, DIST:", np.bincount(x_sub["cens"]))
            x_sub = x_sub.reset_index(drop=True)
            X_sub_tr, y_sub_tr = cnt.pd_to_xy(x_sub)
            if self.weighted_tree:
                X_sub_tr["weights_obs"] = self.weights[x_sub['ind_start']]

            tree = CRAID1(features=self.features,
                          apr_time=y_sub_tr["time"].copy(),
                          apr_event=y_sub_tr["cens"].copy(),
                          random_state=i, **self.tree_kwargs)
            tree.fit(X_sub_tr, y_sub_tr)

            wei_i, betta_i = self.count_model_weights(tree, X_sub_tr, y_sub_tr)
            self.add_model(tree, rows_left_out, wei_i, betta_i)
            self.update_weight(x_sub['ind_start'], wei_i)

    def predict(self, x_test, aggreg=True, **kwargs):
        """Predict with every tree; aggregate over trees unless ``aggreg`` is false."""
        per_model = np.array([self.models[i].predict(x_test, **kwargs)
                              for i in range(len(self.models))])
        if not aggreg:
            return per_model
        # No explicit weights are supplied to the aggregation.
        return self.get_aggreg(per_model, None)

    def predict_at_times(self, x_test, bins, aggreg=True, mode="surv"):
        """
        Predict a function of time on ``bins`` with every tree.

        When aggregating survival curves, the first column is forced to 1 and the
        last column to 0.
        """
        per_model = np.array([self.models[i].predict_at_times(x_test, bins=bins, mode=mode)
                              for i in range(len(self.models))])
        if not aggreg:
            return per_model
        combined = self.get_aggreg(per_model, None)
        if mode == "surv":
            combined[:, -1] = 0
            combined[:, 0] = 1
        return combined

    def count_model_weights(self, model, X_sub, y_sub):
        """
        Score ``model`` and return ``(1 - score_per_row, abs(mean score))``.

        With ``all_weight`` set, the whole training set is scored instead of the
        sample passed in.
        """
        if self.all_weight:
            X_sub, y_sub = self.X_train, self.y_train
        pred_sf = model.predict_at_times(X_sub, bins=self.bins, mode="surv")

        # Despite the attribute names mentioning IBS, the score used here is metr.auprc.
        scores = metr.auprc(self.y_train, y_sub, pred_sf, self.bins, axis=0)
        betta = np.mean(scores)
        return 1 - scores, abs(betta)

    def update_weight(self, index, wei_i):
        """Intentionally does nothing: observation weights stay at their initial values."""
        # self.weights += wei_i
        pass

    def get_aggreg(self, x, wei=None):
        """
        Combine per-tree predictions ``x`` (trees along axis 0) according to ``self.aggreg_func``.

        "median" -> median over trees; "wei" -> average weighted by ``wei`` (falls back
        to ``self.bettas``); "argmean" / "argwei" -> averages restricted to the entries
        whose rank in ``wei`` along axis 1 exceeds half the number of trees; anything
        else -> plain mean.
        """
        mode = self.aggreg_func
        if mode == 'median':
            return np.median(x, axis=0)
        if mode == "wei":
            if wei is None:
                wei = np.array(self.bettas)
            wei = wei / np.sum(wei)
            return np.sum((x.T * wei).T, axis=0)
        if mode == "argmean":
            ranks = np.argsort(np.argsort(wei, axis=1), axis=1)
            wei = np.where(ranks > len(self.bettas)//2, 1, 0)
            wei = wei / np.sum(wei, axis=1).reshape(-1, 1)
            return np.sum((x.T * wei).T, axis=0)
        if mode == "argwei":
            ranks = np.argsort(np.argsort(wei, axis=1), axis=1)
            wei = np.where(ranks > len(self.bettas)//2, 1/np.array(self.bettas), 0)
            wei = wei / np.sum(wei, axis=1).reshape(-1, 1)
            return np.sum((x.T * wei).T, axis=0)
        return np.mean(x, axis=0)

    def plot_curve(self, X_tmp, y_tmp, bins, label="", metric="ibs"):
        """
        Plot the chosen metric of the ensemble truncated to its first k trees, for every k.

        ``metric == "ibs"`` uses ``metr.ibs_WW``; any other value uses ``metr.auprc``.
        """
        per_model = []
        metric_values = []
        for i in range(len(self.models)):
            per_model.append(self.models[i].predict_at_times(X_tmp, bins=bins, mode="surv"))

            # Aggregate the trees seen so far with the matching prefix of bettas.
            partial = self.get_aggreg(np.array(per_model), np.array(self.bettas)[:i+1])
            partial[:, -1] = 0
            partial[:, 0] = 1
            if metric == "ibs":
                metric_values.append(metr.ibs_WW(self.y_train, y_tmp, partial, bins))
            else:
                metric_values.append(metr.auprc(self.y_train, y_tmp, partial, bins))
        plt.plot(range(len(self.models)), metric_values, label=label)

In [190]:
from sklearn.model_selection import train_test_split
from survivors.experiments.grid import generate_sample, prepare_sample, count_metric

# Dataset selection: the GBSG loader is kept as a commented-out alternative.
# X, y, features, categ, sch_nan = ds.load_gbsg_dataset()
X, y, features, categ, sch_nan = ds.load_wuhan_dataset()
# # y["time"] += 1

# Disabled experiment: invert the event flag and drop the nucleic-acid detection features.
# # y["cens"] = ~y["cens"]
# features = list(set(features) - {"max_2019_nCoV_nucleic_acid_detection", 
#                                  "mean_2019_nCoV_nucleic_acid_detection", 
#                                  "min_2019_nCoV_nucleic_acid_detection"})
# X = X[features]

# Hold out 33% of the rows, stratified by the censoring flag, with a fixed seed.
X_TR, X_HO = train_test_split(X, stratify=y[cnt.CENS_NAME],
                              test_size=0.33, random_state=42)
# Only the indices of the split are used; prepare_sample builds the train / hold-out
# sets and the hold-out time bins from the full X and y.
X_tr, y_tr, X_HO, y_HO, bins_HO = prepare_sample(X, y, X_TR.index, X_HO.index)

# Hold-out features together with their target columns.
df = X_HO.copy()
df["time"] = y_HO["time"]
df["cens"] = y_HO["cens"]

  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.m

  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  

  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df

  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df

  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df

  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  

  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  

  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  mapped = lib.map_infer(
  df_agg['min_' + c] = df_agg[c].apply(np.nanmin)
  mapped = lib.map_infer(
  df_agg['max_' + c] = df_agg[c].apply(np.nanmax)
  mapped = lib.map_infer(
  df_agg['mean_' + c] = df_agg[c].apply(np.nanmean)
  

  df_agg['time'] = df_agg.loc[:, ['Admission time', 'Discharge time']].apply(lambda x: (x['Discharge time'] - x['Admission time']).days, axis=1)


In [192]:
from survivors.ensemble import BootstrapCRAID
# Ensemble hyper-parameters. The same mapping also supplies the metric name
# handed to tolerance_find_best further down.
param_bstr = dict(
    balance=None,
    categ=categ,
    criterion='peto',
    depth=10,
    ens_metric_name='IBS_REMAIN',
    leaf_model='base_zero_after',
    max_features=0.3,
    min_samples_leaf=0.01,  # 0.01
    n_estimators=50,
    n_jobs=5,
    size_sample=0.7,
)

# NOTE(review): BootstrapCRAID1 is not the BootstrapCRAID imported in this
# cell; presumably a notebook-local variant defined in an earlier cell -- confirm.
bstr = BootstrapCRAID1(**param_bstr)
bstr.fit(X_tr, y_tr)
bstr.tolerance_find_best(param_bstr["ens_metric_name"])

# Predictions for X_HO: a "time" target, then survival and hazard values
# evaluated at bins_HO.
pred_time = bstr.predict(X_HO, target="time")
pred_surv = bstr.predict_at_times(X_HO, bins=bins_HO, mode="surv")
pred_haz = bstr.predict_at_times(X_HO, bins=bins_HO, mode="hazard")

# Print the requested metrics computed by count_metric for these predictions.
print(
    count_metric(
        y_tr, y_HO, pred_time, pred_surv, pred_haz, bins_HO,
        ['CI', "IBS_REMAIN", "BAL_IBS_REMAIN", "IAUC_WW_TI",
         "AUPRC", "EVENT_AUPRC", "BAL_AUPRC"],
    )
)

fitted: 50 models.
[0.1429 0.1375 0.1032 0.1005 0.094  0.0957 0.0878 0.0884 0.0864 0.0851
 0.0868 0.0832 0.082  0.0818 0.079  0.0786 0.0788 0.08   0.0786 0.0779
 0.0767 0.0773 0.0774 0.0772 0.0764 0.0764 0.076  0.0763 0.0766 0.0765
 0.0758 0.0767 0.0766 0.077  0.0767 0.0763 0.0763 0.0764 0.0763 0.0764
 0.0765 0.0766 0.0765 0.0756 0.0755 0.0755 0.0754 0.0756 0.0758 0.0757]
fitted: 47 models.
[0.75551748 0.0749408  0.14910095 0.85885689 0.7424502  0.52176386
 0.72907527]


  false_pos = cumsum_fp / n_controls


In [None]:
[0.62778972 0.14300061 0.33354884 0.77477676 0.69121148 0.50949231
 0.6706261 ]

In [None]:
[0.6281024  0.14091349 0.34395037 0.77027241 0.69435691 0.52027288
 0.67463645]

In [None]:
[0.62614813 0.14374146 0.34480665 0.75541951 0.69278564 0.51871816
 0.67306706]

In [None]:
[0.63091655 0.1422004  0.34096467 0.76252518 0.69420142 0.52163777
 0.67465319]

In [None]:
[0.63162009 0.14071645 0.34137235 0.76280546 0.69149669 0.51923545
 0.67198272]

In [199]:
from survivors.ensemble import BootstrapCRAID
# Ensemble hyper-parameters. The same mapping also supplies the metric name
# handed to tolerance_find_best further down.
param_bstr = dict(
    aggreg_func='mean',
    all_weight=True,
    balance="only_log_rank",
    categ=categ,
    l_reg=0.01,
    criterion='peto',
    depth=10,
    ens_metric_name='IBS_REMAIN',
    leaf_model='base_zero_after',
    max_features=0.3,
    min_samples_leaf=0.001,  # 0.01
    n_estimators=50,
    n_jobs=5,
    size_sample=0.7,
)

# NOTE(review): IBSCleverBoostingCRAID1 is not imported in this cell (only
# BootstrapCRAID is); presumably defined in an earlier notebook cell -- confirm.
bstr = IBSCleverBoostingCRAID1(**param_bstr)
bstr.fit(X_tr, y_tr)
bstr.tolerance_find_best(param_bstr["ens_metric_name"])

# Predictions for X_HO: a "time" target, then survival and hazard values
# evaluated at bins_HO.
pred_time = bstr.predict(X_HO, target="time")
pred_surv = bstr.predict_at_times(X_HO, bins=bins_HO, mode="surv")
pred_haz = bstr.predict_at_times(X_HO, bins=bins_HO, mode="hazard")

# Print the requested metrics computed by count_metric for these predictions.
print(
    count_metric(
        y_tr, y_HO, pred_time, pred_surv, pred_haz, bins_HO,
        ['CI', "IBS_REMAIN", "BAL_IBS_REMAIN", "IAUC_WW_TI",
         "AUPRC", "EVENT_AUPRC", "BAL_AUPRC"],
    )
)

UNIQUE (0):120, DIST: [93 82]
UNIQUE (1):128, DIST: [97 78]
UNIQUE (2):126, DIST: [89 86]
UNIQUE (3):127, DIST: [94 81]
UNIQUE (4):128, DIST: [97 78]
UNIQUE (5):126, DIST: [93 82]
UNIQUE (6):119, DIST: [94 81]
UNIQUE (7):128, DIST: [97 78]
UNIQUE (8):128, DIST: [95 80]
UNIQUE (9):121, DIST: [96 79]
UNIQUE (10):134, DIST: [87 88]
UNIQUE (11):118, DIST: [99 76]
UNIQUE (12):126, DIST: [83 92]
UNIQUE (13):131, DIST: [97 78]
UNIQUE (14):123, DIST: [90 85]
UNIQUE (15):132, DIST: [94 81]
UNIQUE (16):128, DIST: [112  63]
UNIQUE (17):128, DIST: [108  67]
UNIQUE (18):121, DIST: [90 85]
UNIQUE (19):130, DIST: [89 86]
UNIQUE (20):127, DIST: [93 82]
UNIQUE (21):131, DIST: [106  69]
UNIQUE (22):131, DIST: [100  75]
UNIQUE (23):121, DIST: [105  70]
UNIQUE (24):130, DIST: [91 84]
UNIQUE (25):125, DIST: [100  75]
UNIQUE (26):124, DIST: [100  75]
UNIQUE (27):133, DIST: [90 85]
UNIQUE (28):123, DIST: [90 85]
UNIQUE (29):127, DIST: [97 78]
UNIQUE (30):121, DIST: [96 79]
UNIQUE (31):126, DIST: [102  73]
UN

  false_pos = cumsum_fp / n_controls


In [None]:
[0.75517478 0.07205391 0.13473618 0.86944936 0.74515232 0.52232319
 0.73164753]

In [None]:
[0.75551748 0.0749408  0.14910095 0.85885689 0.7424502  0.52176386
 0.72907527]

In [None]:
[0.62802423 0.14211188 0.34170921 0.75739075 0.69467389 0.51757682
 0.67461211]

In [None]:
[0.63107289 0.14144822 0.3462126  0.77005228 0.6896938  0.51910505
 0.67036929]

In [None]:
[0.63056478 0.14079638 0.34322143 0.76938389 0.69300919 0.51992004
 0.67340143]

In [None]:
[0.63146375 0.14081558 0.34032002 0.76382246 0.6913002  0.51914928
 0.67179873]

In [None]:
[0.62364667 0.14340442 0.34647329 0.77301765 0.69317095 0.51897566
 0.67343789]

In [None]:
[0.63306625 0.13905343 0.37198661 0.78508546 0.68627601 0.52907648
 0.66846826]

In [None]:
[0.6214188  0.14391302 0.35822151 0.76261641 0.68925319 0.51655777
 0.66969004]

In [None]:
[0.63162009 0.14071645 0.34137235 0.76280546 0.69149669 0.51923545
 0.67198272]