In [1]:
import pandas as pd
import numpy as np
import os
import time
import copy
import pathlib, tempfile

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from graphviz import Digraph
from joblib import Parallel, delayed

from survivors.tree.find_split import best_attr_split

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder
from sksurv.linear_model import CoxPHSurvivalAnalysis

from survivors import metrics as metr
from survivors import constants as cnt

from survivors.tree import CRAID

%load_ext line_profiler

In [2]:
# Auxiliary functions
def join_dict(a, b):
    """
    Merge two mappings into a new dict.

    Keys of `a` come first, followed by new keys of `b`; when a key exists in
    both, the value from `b` wins. Neither argument is modified.
    """
    merged = dict(a.items())
    merged.update(b.items())
    return merged


class LeafModel(object):
    """
    Statistics of the observations that fall into a tree node.

    `fit` stores a survival estimator, a hazard estimator and the column
    means of the node data; the predict_* methods answer from them and
    ignore their `X` argument.
    """
    def __init__(self):
        self.survival = None
        self.hazard = None
        self.features_mean = dict()

    def fit(self, X_node):
        """
        Fit the node statistics.

        X_node : DataFrame that must contain the cnt.TIME_NAME and
            cnt.CENS_NAME columns; its column means are also stored.
        """
        self.survival = metr.get_survival_func(X_node[cnt.TIME_NAME], X_node[cnt.CENS_NAME])
        self.hazard = metr.get_hazard_func(X_node[cnt.TIME_NAME], X_node[cnt.CENS_NAME])
        self.features_mean = X_node.mean(axis=0).to_dict()

    def predict_mean_feature(self, X, feature_name):
        """Return the stored mean of `feature_name` (KeyError if unknown)."""
        return self.features_mean[feature_name]

    def predict_survival_at_times(self, X, bins):
        """Return the survival estimator evaluated at `bins` as a numpy array."""
        return self.survival.survival_function_at_times(bins).to_numpy()

    def predict_hazard_at_times(self, X, bins):
        """Return the cumulative hazard evaluated at `bins` as a numpy array."""
        # Query the hazard estimator fitted for this purpose in fit();
        # previously self.survival was queried and self.hazard was never used.
        return self.hazard.cumulative_hazard_at_times(bins).to_numpy()


# Decision tree node class
class Node(object):
    """
    Node of decision tree.
    Allows separating data into 2 subnodes (references are stored in edges).

    Attributes
    ----------
    df : Pandas DataFrame
        Data of Node (set to None by set_edges, restored by set_leaf)
    numb : int
        Number or name of Node
    depth : int
        Distance from root node
    edges : array-like
        Subbranches after separating
    features : list
        Available features
    categ : list
        Names of categorical features
    woe : boolean
        Mode of categorical preparation; applied to categorical features
        only and inherited by subnodes
    rule : dict
        Allows defining which data belongs to the node.
        name: str
            query in pandas terms
        attr: str
            feature of separation
        pos_nan: int
            Indicator of nan (1 if missing values of `attr` go to this node)
    is_leaf : boolean
        True if node doesn't have subnodes
    verbose : int
        Print best split of node
    info : dict
        Parameters for finding the best split
    leaf_model : LeafModel
        Statistics fitted on the node data in check_params

    Methods
    -------
    check_params : Fill empty parameters, map max_features to int, fit leaf_model
    find_best_split : Choose best split of node according to parameters
    split : Try to create subnodes by best split
    set_edges : Attach subnodes and drop node data
    get_df_node : Return data for node
    set_leaf : Delete subnodes and restore data

    predict : Return statistic values of data
    predict_rules : Return full rules from node to leaf
    predict_scheme : Return all possible outcomes for additional features determination

    get_figure : Create picture of data (hist, survival function)
    get_rule : Return rule of node
    get_description : Return common values of data (size, depth, death, cens)
    build_viz : Create and fill graphviz digraph
    translate : Replace rules and features by dictionary
    """
    __slots__ = ("df", "numb",
                 "depth", "edges", "features", "leaf_model",
                 "categ", "woe", "rule", "is_leaf", "verbose", "info")

    def __init__(self, df,  numb=1, depth=0,
                 features=[], categ=[], woe=False,
                 rule=None,
                 verbose=0, **info):
        self.df = df
        self.numb = numb
        self.depth = depth
        self.edges = np.array([])
        self.features = features
        self.categ = categ
        self.woe = woe
        # Build a fresh rule dict per node: translate() writes into self.rule,
        # so a shared default dict would leak changes between root nodes.
        if rule is None:
            rule = {"name": "", "attr": "", "pos_nan": 0}
        self.rule = rule
        self.is_leaf = True
        self.verbose = verbose
        self.info = info
        self.leaf_model = LeafModel()
        self.check_params()

    def check_params(self):
        """Fill missing split parameters, convert max_features to an int, fit leaf_model."""
        self.info.setdefault("bonf", True)
        self.info.setdefault("n_jobs", 16)
        self.info.setdefault("max_features", 1.0)
        self.info.setdefault("signif", 1.1)
        self.info.setdefault("thres_cont_bin_max", 100)
        if self.info["max_features"] == "sqrt":
            self.info["max_features"] = int(np.trunc(np.sqrt(len(self.features))+0.5))
        elif isinstance(self.info["max_features"], float):
            # A float is a fraction of the available features
            self.info["max_features"] = int(self.info["max_features"]*len(self.features))
        self.leaf_model.fit(self.df)

    """ GROUP FUNCTIONS: CREATE LEAFS """

    def find_best_split(self):
        """
        Evaluate a random subset of features in parallel with best_attr_split.

        Returns
        -------
        (attr, split) : name of the feature with the smallest p_value and the
            dict returned by best_attr_split for it (p_value divided by
            sign_split when "bonf" is enabled and sign_split > 0).
        """
        numb_feats = np.clip(self.info["max_features"], 1, len(self.features))
        n_jobs = min(numb_feats, self.info["n_jobs"])
        selected_feats = np.random.choice(self.features, size=numb_feats, replace=False)

        args = []
        for feat in selected_feats:
            t = self.info.copy()
            # WOE preparation concerns categorical features only; continuous
            # features are always split as "cont".
            t["type_attr"] = ("woe" if self.woe else "categ") if feat in self.categ else "cont"
            t["arr"] = self.df.loc[:, [feat, cnt.CENS_NAME, cnt.TIME_NAME]].to_numpy().T
            args.append(t)
        with Parallel(n_jobs=n_jobs, verbose=0, batch_size=10) as parallel:
            ml = parallel(delayed(best_attr_split)(**a) for a in args)

        attrs = {f: ml[ind] for ind, f in enumerate(selected_feats)}
        attr = min(attrs, key=lambda x: attrs[x]["p_value"])

        if attrs[attr]["sign_split"] > 0 and self.info["bonf"]:
            attrs[attr]["p_value"] = attrs[attr]["p_value"] / attrs[attr]["sign_split"]
        return (attr, attrs[attr])

    def split(self):
        """
        Build subnodes from the best split.

        Returns
        -------
        Object array of new Node instances (one per branch), or the tuple
        (attr, best_split) when the best split is not significant.
        NOTE(review): the two return types differ; kept for existing callers.
        """
        node_edges = np.array([], dtype = object)
        attr, best_split = self.find_best_split()
        # The best feature has no significant split
        if best_split["sign_split"] == 0:
            if self.verbose > 0:
                print(f'Конец ветви, незначащее p-value: {best_split["p_value"]}')
            return (attr, best_split)

        if self.verbose > 0:
            print('='*6, best_split["p_value"], attr)
        for v, p_n in zip(best_split["values"], best_split["pos_nan"]):
            query = attr + v
            if p_n == 1:
                # `attr != attr` holds only for NaN, so missing values join this branch
                query = "(" + attr + v + ") or (" + attr + " != " + attr + ")"
            d_node = self.df.query(query).copy()
            N = Node(df=d_node,
                     features=self.features, categ=self.categ, woe=self.woe,
                     depth=self.depth+1,
                     rule={"name": attr + v, "attr": attr, "pos_nan": p_n},
                     verbose=self.verbose, **self.info)
            node_edges = np.append(node_edges, N)
        return node_edges

    def set_edges(self, edges):
        """Attach subnodes; the node's own data is dropped (rebuilt by get_df_node)."""
        self.edges = edges
        self.is_leaf = False
        self.df = None

    """ GROUP FUNCTIONS: CLEAR AND DEL """
    def get_df_node(self):
        """Return the node data, concatenating the leaves' data for internal nodes."""
        if self.is_leaf:
            return self.df
        return pd.concat([edge.get_df_node() for edge in self.edges])

    def set_leaf(self):
        """Turn the node into a leaf, restoring its data from the subtree."""
        if self.is_leaf:
            return
        self.df = self.get_df_node()
        del self.edges
        self.edges = np.array([])
        self.is_leaf = True

    def prepare_df_for_attr(self, X):
        """Add the split feature as an all-NaN column if X lacks it (mutates X)."""
        attr = self.edges[0].rule['attr']
        if attr not in X.columns:
            X.loc[:, attr] = np.nan
        return X

    def get_split_nan_index(self, X):
        """Return the index of rows of X that match none of the subnode rules."""
        ind_nan = X.index
        for edge in self.edges:
            ind_nan = ind_nan.difference(X.query(edge.rule["name"]).index)
        return ind_nan

    """ GROUP FUNCTIONS: PREDICT """

    def predict(self, X, target, name_tg="res", bins=None, end_list=[]):
        """
        Return statistic values of data

        Parameters
        ----------
        X : Pandas dataframe
            Contain input features of events (column `name_tg` is written).
        target : str
            Column name or mode of leaf sample.
            Column name : must be in dataset.columns
                Return mean of feature
            Mode :
                "surv" return survival function
                "hazard" return cumulative hazard function
                "depth" return leafs depth
                "num_node" return leafs numb (names)
        bins : array-like
            Points of timeline
        name_tg : str, optional
            Name of return column. The default is "res".
        end_list : list, optional
            Numbers of end node (instead leaf). The default is [].

        Returns
        -------
        X[name_tg] : array-like
            Values by target

        """
        if (self.numb in end_list) or self.is_leaf:
            if target == "surv" or target == "hazard":
                if target == "surv":
                    func_at_times = self.leaf_model.predict_survival_at_times(X, bins)
                else:
                    func_at_times = self.leaf_model.predict_hazard_at_times(X, bins)
                # Every row gets the same array of function values
                X.loc[:, name_tg] = X[name_tg].apply(lambda x: func_at_times)
            elif target == "depth":
                X.loc[:, name_tg] = self.depth
            elif target == "num_node":
                X.loc[:, name_tg] = self.numb
            else:
                dataset = self.get_df_node()
                if target in dataset.columns:
                    X.loc[:, name_tg] = self.leaf_model.predict_mean_feature(X, target)
        else:
            X = self.prepare_df_for_attr(X)
            ind_nan = self.get_split_nan_index(X)

            for edge in self.edges:
                ind = X.query(edge.rule["name"]).index
                # Rows matching no rule are routed to the branch that accepts NaN
                if edge.rule["pos_nan"] == 1:
                    ind = ind.append(ind_nan)
                if len(ind) > 0:
                    X.loc[ind, name_tg] = edge.predict(X=X.loc[ind, :], target=target, bins=bins,
                                                       name_tg=name_tg, end_list=end_list)
        return X[name_tg]

    def predict_rules(self, X, name_tg="res"):
        """Write into X[name_tg] the '&'-joined rules from this node down to each row's leaf."""
        if self.is_leaf:
            X.loc[:, name_tg] = self.get_rule()
        else:
            X = self.prepare_df_for_attr(X)
            ind_nan = self.get_split_nan_index(X)

            for edge in self.edges:
                ind = X.query(edge.rule["name"]).index
                if edge.rule["pos_nan"] == 1:
                    ind = ind.append(ind_nan)
                if len(ind) > 0:
                    X.loc[ind, name_tg] = edge.predict_rules(X.loc[ind, :], name_tg)
                    # The root has an empty rule name and contributes nothing
                    if len(self.rule["name"]) > 0:
                        X.loc[ind, name_tg] = self.get_rule() + '&' + X.loc[ind, name_tg]
        return X[name_tg]

    def get_values_column(self, columns):
        """NOTE(review): placeholder; ignores `columns` and assumes a binary feature."""
        return [0, 1]

    def predict_scheme(self, X, scheme_feat):
        """
        Return all possible outcomes for additional features determination

        Parameters
        ----------
        X : Pandas dataframe
            Contain input features of events; columns 'res' and 'store_str'
            are read and written.
        scheme_feat : list
            Features with missing values (relatively).
            If feature in list was used in node,
                then node consider all possible replaces for value in branches
                Thus, method allow to return all outcomes for different values

        Returns
        -------
        X['res'] : array-like
            For each observation contain dict
                key : rule of overdefined feature (; is separator)
                value : list of sample values
                    censoring flag, time, all values for scheme_feat

        """
        def scheme_output_format(r):
            to_array = lambda col: np.array(self.df.get(col))
            return {r['store_str']:
                    [to_array(cnt.CENS_NAME),
                     to_array(cnt.TIME_NAME),
                     # to_array(self.info["sum"]), # TODO FOR SCHEME'S SUM
                     {sch: to_array(sch) for sch in scheme_feat}]}

        def join_scheme_leafs(X_sub, ind_nan=[]):
            # Recurse into every subnode and merge its result dicts into 'res'
            for edge in self.edges:
                ind = X_sub.query(edge.rule['name']).index
                if len(ind_nan) > 0:
                    if edge.rule["pos_nan"] == 1:
                        ind = ind.append(ind_nan)
                if len(ind) > 0:
                    X_sub.loc[ind, 'tmp'] = edge.predict_scheme(X_sub.loc[ind, :], scheme_feat)
                    X_sub.loc[ind, 'res'] = X_sub.loc[ind, :].apply(lambda r: join_dict(r['res'], r['tmp']), axis=1)
            return X_sub['res']

        if self.is_leaf:
            return X.apply(scheme_output_format, axis=1)
        attr = self.edges[0].rule['attr']
        X = self.prepare_df_for_attr(X)
        ind_nan = self.get_split_nan_index(X)
        ind_has = X.index.difference(ind_nan)
        if attr not in scheme_feat:
            X.loc[:, 'res'] = join_scheme_leafs(X, ind_nan)
        else:
            if len(ind_has) > 0:
                X.loc[ind_has, 'res'] = join_scheme_leafs(X.loc[ind_has, :])
            if len(ind_nan) > 0:
                # Try every candidate value for the missing feature and record it in store_str
                pred_store = X.loc[ind_nan, 'store_str'].copy()
                for val in self.get_values_column(attr):
                    X.loc[ind_nan, attr] = val
                    X.loc[ind_nan, 'store_str'] = pred_store + attr + '==' + str(val) + ';'
                    X.loc[ind_nan, 'res'] = join_scheme_leafs(X.loc[ind_nan, :])
        return X['res']

    """ GROUP FUNCTIONS: VISUALIZATION """

    def get_figure(self, mode="hist", target=None, save_path=""):
        """
        Plot the node data.

        mode : "hist" draws a histogram of column `target`;
               "surv" draws the survival function of the node data.
        save_path : if non-empty, the figure is saved there and closed
                    (interactive mode is switched off); otherwise it is shown.
        """
        if len(save_path) > 0:
            plt.ioff()
        fig, ax = plt.subplots(figsize=(8, 6))
        local_df = self.get_df_node()
        if mode == "hist":
            local_df[target].hist(bins=25)
            ax.set_xlim([0, np.max(local_df[target])])
        elif mode == "surv":
            kmf = metr.get_survival_func(local_df[cnt.TIME_NAME], local_df[cnt.CENS_NAME])
            ax.set_xlim([0, np.max(local_df[cnt.TIME_NAME])])
            ax.set_ylim([0, 1])
            plt.xticks(range(0, np.max(local_df[cnt.TIME_NAME])+1, 1000))
            kmf.plot_survival_function(legend=False, fontsize=25)
            ax.set_xlabel('Time', fontsize=25)
            ax.set_ylabel('Survival probability', fontsize=25)
        if len(save_path) > 0:
            plt.savefig(save_path)
            # build_viz saves one figure per node; close it so open figures
            # do not accumulate in pyplot.
            plt.close(fig)
        else:
            plt.show()

    def get_rule(self):
        """Return the node rule in parentheses, marking that missing values are accepted."""
        if not(self.rule["pos_nan"]):
            return f'({self.rule["name"]})'
        return f'(({self.rule["name"]})| не указано)'

    def get_description(self, full=False):
        """Return the rule name, plus size/cens/depth/mean time lines when `full`."""
        s = ""
        d = self.get_df_node()
        m_cens = round(d[cnt.CENS_NAME].mean(), 2)
        m_time = round(d[cnt.TIME_NAME].mean(), 2)
        if full:
            label = "\n".join([self.rule["name"] + s,
                               "size = %s" % (d.shape[0]),
                               "cens/size = %s" % (m_cens),
                               "depth = %s" % (self.depth),
                               "death = %s" % (m_time)])
        else:
            label = self.rule["name"] + s
        return label

    def build_viz(self, dot=None, path_dir="", depth=None, **args):
        """
        Recursively add this node and its subtree to a graphviz Digraph.

        path_dir : prefix of the per-node image files (<numb>.png)
        depth : deepest level to draw; None draws the whole tree
        args : forwarded to get_figure (e.g. mode, target)
        """
        if dot is None:
            dot = Digraph()
        img_path = path_dir + str(self.numb) + '.png'
        self.get_figure(save_path=img_path, **args)
        dot.node(str(self.numb), label=self.get_description(),
                 image=img_path, fontsize='30')
        # Do not expand subnodes once the requested depth is reached
        if depth is not None and self.depth >= depth:
            return dot
        for edge in self.edges:
            # Forward `depth` so the limit applies at every level, not only the root
            dot = edge.build_viz(dot, path_dir, depth=depth, **args)
            dot.edge(str(self.numb), str(edge.numb))
        return dot

    def translate(self, describe):
        """Rename features, categorical names, leaf data columns and rule names via `describe`."""
        if self.is_leaf:
            self.df = self.df.rename(describe, axis=1)
        self.features = [describe.get(f, f) for f in self.features]
        self.categ = [describe.get(c, c) for c in self.categ]
        self.rule["name"] = describe.get(self.rule["name"], self.rule["name"])
        for edge in self.edges:
            edge.translate(describe)


In [3]:
def format_to_pandas(X, columns):
    """
    Return `X` as a DataFrame restricted to / labelled by `columns`.

    A DataFrame is column-selected with .loc, a numpy ndarray is wrapped in a
    new DataFrame with the given column names, and any other input yields None.
    The check is done on the type name only.
    """
    kind = type(X).__name__
    if kind == "DataFrame":
        return X.loc[:, columns]
    if kind == "ndarray":
        return pd.DataFrame(X, columns=columns)
    return None

# Functions of pruning

def ols(a, b):
    """
    Sum of squared element-wise differences between `a` and `b`.

    Both arguments must support `-` and `**` (e.g. numpy arrays or pandas
    Series); the squared residuals are added up with the builtin sum.
    """
    squared = (a - b) ** 2
    return sum(squared)

def find_best_uncut(tree, X, y, target, mode_f, choose_f):
    """
    Perform one pruning step on `tree` (modified in place).

    For every spanning leaf number, the tree predicts `target` as if that
    node were a leaf (end_list=[leaf]) and the prediction is scored with
    mode_f(y, y_pred), rounded to 4 digits. `choose_f` (e.g. min or max)
    picks the best-scoring node, which is then cut via delete_leafs_by_span.

    Returns
    -------
    (tree, score) : the pruned tree and the chosen score.
    """
    scores = {
        leaf: round(mode_f(y, tree.predict(X, target=target, end_list=[leaf])), 4)
        for leaf in tree.get_spanning_leaf_numbers()
    }
    best_leaf, best_score = choose_f(scores.items(), key=lambda item: item[1])
    tree.delete_leafs_by_span([best_leaf])
    return tree, best_score
    
def cutted_tree(tree_, X, target, mode_f, choose_f, verbose = 0):
    """
    Prune a copy of `tree_` step by step and return the best pruned tree.

    The copy is pruned with find_best_uncut until a single node remains;
    after every step the metric and a snapshot of the tree are stored under
    the current number of leaves. The best metric is chosen with `choose_f`,
    compared on the first 5 characters of its text form, and the snapshot
    with the fewest leaves reaching it is returned. `tree_` is not modified.
    With verbose > 0 the metric-vs-leaves curve is plotted and printed.
    """
    def truncate(value):
        # Coarse comparison: keep only the first 5 characters, e.g. 0.123456 -> 0.123
        return float(str(value)[:5])

    y = pd.to_numeric(X[target])
    tree = copy.deepcopy(tree_)
    metric_by_leaves = dict()
    tree_by_leaves = dict()

    first_pred = tree.predict(X, target = target)
    n_leaves = tree.get_leaf_numbers().shape[0]
    metric_by_leaves[n_leaves] = mode_f(y, first_pred)
    tree_by_leaves[n_leaves] = copy.deepcopy(tree)

    while len(tree.nodes) > 1:
        tree, step_metric = find_best_uncut(tree, X, y, target, mode_f, choose_f)
        n_leaves = tree.get_leaf_numbers().shape[0]
        metric_by_leaves[n_leaves] = step_metric
        tree_by_leaves[n_leaves] = copy.deepcopy(tree)

    best_metric = truncate(choose_f(metric_by_leaves.values()))
    min_leaf = min(k for k, v in metric_by_leaves.items() if truncate(v) == best_metric)

    if verbose > 0:
        plt.clf()
        plt.plot(list(metric_by_leaves.keys()), list(metric_by_leaves.values()), 'o')
        plt.xlabel("Количество листов")
        plt.ylabel(f"Лучшее значение метрики {mode_f.__name__}")
        plt.title(f"Обрезка дерева по переменной {target}")
        plt.show()
        print(metric_by_leaves)
        print(best_metric, min_leaf)

    return tree_by_leaves[min_leaf]

In [30]:
# Auxiliary functions
def join_dict(a, b):
    """Return a new dict with the items of `a` followed by those of `b` (`b` wins on clashes)."""
    pairs = [*a.items(), *b.items()]
    return dict(pairs)

def rule_to_string(r):
    """
    Format a rule dict as "(<attr><name>)".

    When r["pos_nan"] is truthy the result is additionally wrapped as
    "((<attr><name>)| nan)" to show that missing values follow this rule.
    """
    base = f"({r['attr']}{r['name']})"
    return f"({base}| nan)" if r["pos_nan"] else base


class LeafModel(object):
    """
    Statistics of the observations that fall into a tree node.

    `fit` stores the node shape, a default time grid, survival and hazard
    estimators, column means and the raw values of selected columns; the
    predict_* methods answer from them and ignore their `X` argument.
    """
    def __init__(self):
        self.shape = None
        self.survival = None
        self.hazard = None
        # Initialised here so every attribute exists before fit() is called
        self.default_bins = None
        self.features_mean = dict()
        self.lists = dict()

    def fit(self, X_node, need_features=None):
        """
        Fit the node statistics.

        X_node : DataFrame containing the cnt.TIME_NAME and cnt.CENS_NAME columns.
        need_features : columns whose raw values are kept as lists
            (defaults to [cnt.TIME_NAME, cnt.CENS_NAME]).
        """
        if need_features is None:
            need_features = [cnt.TIME_NAME, cnt.CENS_NAME]
        self.shape = X_node.shape
        self.default_bins = cnt.get_bins(time=X_node[cnt.TIME_NAME].to_numpy(),
                                         cens=X_node[cnt.CENS_NAME].to_numpy(), mode='a', num_bins=100)
        self.survival = metr.get_survival_func(X_node[cnt.TIME_NAME], X_node[cnt.CENS_NAME])
        self.hazard = metr.get_hazard_func(X_node[cnt.TIME_NAME], X_node[cnt.CENS_NAME])
        self.features_mean = X_node.mean(axis=0).to_dict()
        self.lists = X_node.loc[:, need_features].to_dict(orient="list")

    def get_shape(self):
        """Return the shape of the data passed to fit (None before fit)."""
        return self.shape

    def predict_list_feature(self, feature_name):
        """Return the stored list of values of `feature_name`, or None if it was not kept."""
        return self.lists.get(feature_name)

    def predict_mean_feature(self, X, feature_name):
        """Return the stored mean of `feature_name` (KeyError if unknown)."""
        return self.features_mean[feature_name]

    def predict_survival_at_times(self, X, bins=None):
        """Survival estimator at `bins` (default: the grid built in fit) as a numpy array."""
        if bins is None:
            bins = self.default_bins
        return self.survival.survival_function_at_times(bins).to_numpy()

    def predict_hazard_at_times(self, X, bins=None):
        """Cumulative hazard at `bins` (default: the grid built in fit) as a numpy array."""
        if bins is None:
            bins = self.default_bins
        # Query the hazard estimator fitted for this purpose in fit();
        # previously self.survival was queried and self.hazard was never used.
        return self.hazard.cumulative_hazard_at_times(bins).to_numpy()


class Rule(object):
    """
    Split rule of a subnode: a feature, a condition suffix and a NaN flag.

    The condition is a string appended to the feature name (e.g. " >= 5");
    `has_nan` marks that missing values of the feature follow this rule.
    """
    def __init__(self, feature : str, condition : str, has_nan : int):
        self.feature, self.condition, self.has_nan_ = feature, condition, has_nan

    def get_feature(self):
        """Return the feature name."""
        return self.feature

    def get_condition(self):
        """Return the condition suffix."""
        return self.condition

    def has_nan(self):
        """Return the stored NaN flag unchanged."""
        return self.has_nan_

    def translate(self, describe):
        """Rename the feature through the `describe` mapping (unchanged if absent)."""
        self.feature = describe.get(self.feature, self.feature)

    def to_str(self):
        """Return "(<feature><condition>)", wrapped as "((...)| nan)" when the NaN flag is set."""
        core = f"({self.feature}{self.condition})"
        return f"({core}| nan)" if self.has_nan_ else core
        
    
    

# Decision tree node class (the split rule of each subnode is stored on the
# parent in `rule_edges`)
class Node(object):
    """
    Node of decision tree.

    Attributes
    ----------
    df : Pandas DataFrame or None
        Data of the node; set to None by set_edges
    numb : int
        Number of the node
    depth : int
        Distance from the root node
    edges : array-like
        Subnodes attached by set_edges; entries are returned by get_edges and
        used as graphviz node ids in set_dot_edges
        (presumably node numbers assigned by the tree — TODO confirm)
    rule_edges : array-like of Rule
        Split rule of each subnode, filled by split in the same order as
        the subnodes it returns
    features : list
        Available features
    categ : list
        Names of categorical features
    woe : boolean
        Mode of categorical preparation; applied to categorical features
        only and inherited by subnodes
    is_leaf : boolean
        True if the node has no subnodes
    verbose : int
        Print the best split of the node when > 0
    info : dict
        Parameters for finding the best split
    leaf_model : LeafModel
        Statistics fitted on the node data in check_params
    """
    __slots__ = ("df", "numb",
                 "depth", "edges", "rule_edges", "features", "leaf_model",
                 "categ", "woe", "is_leaf", "verbose", "info")

    def __init__(self, df,  numb=0, depth=0,
                 features=[], categ=[], woe=False,
                 verbose=0, **info):
        self.df = df
        self.numb = numb
        self.depth = depth
        self.edges = np.array([], dtype = object)
        self.rule_edges = np.array([], dtype = object)
        self.features = features
        self.categ = categ
        self.woe = woe
        self.is_leaf = True
        self.verbose = verbose
        self.info = info
        self.leaf_model = LeafModel()
        self.check_params()

    def check_params(self):
        """Fill missing split parameters, convert max_features to an int, fit leaf_model."""
        self.info.setdefault("bonf", True)
        self.info.setdefault("n_jobs", 16)
        self.info.setdefault("max_features", 1.0)
        self.info.setdefault("signif", 1.1)
        self.info.setdefault("thres_cont_bin_max", 100)
        if self.info["max_features"] == "sqrt":
            self.info["max_features"] = int(np.trunc(np.sqrt(len(self.features))+0.5))
        elif isinstance(self.info["max_features"], float):
            # A float is a fraction of the available features
            self.info["max_features"] = int(self.info["max_features"]*len(self.features))
        self.leaf_model.fit(self.df)

    """ GROUP FUNCTIONS: CREATE LEAFS """

    def find_best_split(self):
        """
        Evaluate a random subset of features in parallel with best_attr_split.

        Returns (attr, split): the feature with the smallest p_value and the
        dict returned by best_attr_split for it (p_value divided by
        sign_split when "bonf" is enabled and sign_split > 0).
        """
        numb_feats = self.info["max_features"]
        numb_feats = np.clip(numb_feats, 1, len(self.features))
        n_jobs = min(numb_feats, self.info["n_jobs"])
        selected_feats = np.random.choice(self.features, size=numb_feats, replace=False)

        args = np.array([])
        for feat in selected_feats:
            t = self.info.copy()
            # WOE preparation concerns categorical features only
            t["type_attr"] = ("woe" if self.woe else "categ") if feat in self.categ else "cont"
            t["arr"] = self.df.loc[:, [feat, cnt.CENS_NAME, cnt.TIME_NAME]].to_numpy().T
            args = np.append(args, t)
        with Parallel(n_jobs=n_jobs, verbose=0, batch_size=10) as parallel:
            ml = parallel(delayed(best_attr_split)(**a) for a in args)

        attrs = {f: ml[ind] for ind, f in enumerate(selected_feats)}
        attr = min(attrs, key=lambda x: attrs[x]["p_value"])

        if attrs[attr]["sign_split"] > 0 and self.info["bonf"]:
            attrs[attr]["p_value"] = attrs[attr]["p_value"] / attrs[attr]["sign_split"]
        return (attr, attrs[attr])

    def split(self):
        """
        Build one subnode per branch of the best split.

        Returns an object array of new Node instances (empty when the best
        split is not significant) and fills self.rule_edges with the matching
        Rules. The subnodes are not attached here; see set_edges.
        """
        node_edges = np.array([], dtype = object)
        self.rule_edges = np.array([], dtype = object)

        attr, best_split = self.find_best_split()
        # The best split is not significant
        if best_split["sign_split"] == 0:
            if self.verbose > 0:
                print(f'Конец ветви, незначащее p-value: {best_split["p_value"]}')
            return node_edges

        if self.verbose > 0:
            print('='*6, best_split["p_value"], attr)
        for v, p_n in zip(best_split["values"], best_split["pos_nan"]):
            query = attr + v
            if p_n == 1:
                # `attr != attr` holds only for NaN, so missing values join this branch
                query = "(" + attr + v + ") or (" + attr + " != " + attr + ")"
            rule = Rule(feature=attr, condition=v, has_nan=p_n)
            d_node = self.df.query(query).copy()
            # Subnodes inherit the categorical preparation mode of the parent
            N = Node(df=d_node, features=self.features, categ=self.categ, woe=self.woe,
                     depth=self.depth+1, verbose=self.verbose, **self.info)
            node_edges = np.append(node_edges, N)
            self.rule_edges = np.append(self.rule_edges, rule)

        return node_edges

    def set_edges(self, edges):
        """Attach subnodes and drop the node data (leaf_model keeps its statistics)."""
        self.edges = edges
        self.is_leaf = False
        self.df = None

    def set_leaf(self):
        """Turn the node back into a leaf by removing its subnodes (df stays as it is)."""
        if self.is_leaf:
            return
        self.edges = np.array([])
        self.is_leaf = True

    def prepare_df_for_attr(self, X):
        """Return the split feature of X as a numpy array, adding an all-NaN column if absent (mutates X)."""
        attr = self.rule_edges[0].get_feature()
        if attr not in X.columns:
            X.loc[:, attr] = np.nan
        return X[attr].to_numpy()

    def get_edges(self, X):
        """
        Route every row of X to one of the two subnodes.

        The condition of the rule without the NaN flag is evaluated; rows
        satisfying it get that edge, all others (including NaN) get the other.
        NOTE(review): conditions are eval'ed; they come from best_attr_split,
        never evaluate rules from untrusted sources.
        """
        X_np = self.prepare_df_for_attr(X)
        rule_id = 1 if self.rule_edges[0].has_nan() else 0
        query = self.rule_edges[rule_id].get_condition()
        if self.rule_edges[0].get_feature() in self.categ:
            # Categorical conditions carry a list literal of accepted values
            values = np.isin(X_np, eval(query[query.find("["):]))
        else:
            values = eval("X_np" + query)
        return np.where(values, self.edges[rule_id], self.edges[1-rule_id])

    def predict(self, X, target, bins=None):
        """
        Fill X["res"] with this node's prediction and return it.

        target : "surv" / "hazard" (function values at `bins`, the same array
                 for every row; requires an existing "res" column), "depth",
                 "num_node", or a column name (its stored mean).
        """
        if target == "surv" or target == "hazard":
            if target == "surv":
                func_at_times = self.leaf_model.predict_survival_at_times(X, bins)
            else:
                func_at_times = self.leaf_model.predict_hazard_at_times(X, bins)
            X["res"] = X["res"].apply(lambda x: func_at_times)
        elif target == "depth":
            X["res"] = self.depth
        elif target == "num_node":
            X["res"] = self.numb
        else:
            X["res"] = self.leaf_model.predict_mean_feature(X, target)
        return X["res"]

    """ GROUP FUNCTIONS: VISUALIZATION """

    def get_figure(self, mode="hist", bins=None, target=None, save_path=""):
        """
        Save a picture of the node statistics to `save_path` and close it.

        mode : "hist" — histogram of the stored values of `target`
               (defaults to cnt.CENS_NAME); "surv" — step plot of the
               survival function at `bins` (defaults to the leaf model grid).
        """
        if target is None:
            target = cnt.CENS_NAME
        plt.ioff()
        fig, ax = plt.subplots(figsize=(8, 6))
        if mode == "hist":
            lst = self.leaf_model.predict_list_feature(target)
            plt.hist(lst, bins=25)
            ax.set_xlim([0, np.max(lst)])
            ax.set_xlabel(f'{target}', fontsize=25)
        elif mode == "surv":
            if bins is None:
                # Same grid the leaf model falls back to, so x and y stay aligned
                bins = self.leaf_model.default_bins
            sf = self.leaf_model.predict_survival_at_times(X=None, bins=bins)
            plt.step(bins, sf)
            ax.set_xlabel('Time', fontsize=25)
            ax.set_ylabel('Survival probability', fontsize=25)
        plt.savefig(save_path)
        plt.close(fig)

    def get_description(self):
        """Return a multi-line label: size, mean censoring flag, depth and mean time."""
        m_cens = round(self.leaf_model.predict_mean_feature(X=None, feature_name=cnt.CENS_NAME), 2)
        m_time = round(self.leaf_model.predict_mean_feature(X=None, feature_name=cnt.TIME_NAME), 2)
        label = "\n".join([f"size = {self.leaf_model.get_shape()[0]}",
                           f"cens/size = {m_cens}",
                           f"depth = {self.depth}",
                           f"death = {m_time}"])
        return label

    def set_dot_node(self, dot, path_dir="", depth=None, **args):
        """Add this node (with its picture) to `dot` unless it is deeper than `depth`."""
        if not(depth is None) and depth < self.depth :
            return dot
        img_path = path_dir + str(self.numb) + '.png'
        self.get_figure(save_path=img_path, **args)
        dot.node(str(self.numb), label=self.get_description(),
                 image=img_path, fontsize='30')
        return dot

    def set_dot_edges(self, dot):
        """Add an edge labelled with the split rule from this node to each subnode."""
        if not(self.is_leaf):
            for e in range(len(self.rule_edges)):
                s = self.rule_edges[e].to_str()
                dot.edge(str(self.numb), str(self.edges[e]), label=s, fontsize='30')
        return dot

    def translate(self, describe):
        """Rename features, categorical names, node data columns and rule features via `describe`."""
        # A leaf created by set_leaf has no data left (set_edges dropped it)
        if self.is_leaf and self.df is not None:
            self.df = self.df.rename(describe, axis=1)
        self.features = [describe.get(f, f) for f in self.features]
        self.categ = [describe.get(c, c) for c in self.categ]
        for e in range(len(self.rule_edges)):
            self.rule_edges[e].translate(describe)

In [198]:
class CRAID1(object):
    """
    Survival decision tree built from ``Node`` objects (reworked copy of CRAID).

    Nodes are stored in ``self.nodes`` keyed by their number: the root is 0 and
    children are numbered in breadth-first order of creation during ``fit``.

    Parameters
    ----------
    depth : int, optional
        Maximal depth of the tree. The default is 0 (root only).
    random_state : int, optional
        Seed passed to ``cnt.set_seed`` and used to draw the validation sample.
    features : list or None, optional
        Features used for splitting. Empty/None means all columns of X.
    categ : list or None, optional
        Names of categorical features. None means no categorical features.
    cut : bool, optional
        If True, 20% of the training rows are held out and used to prune the
        tree with ``cut_tree``.
    **info
        Extra keyword arguments forwarded to every ``Node``
        (e.g. ``criterion``, ``signif``, ``min_samples_leaf``).
    """
    def __init__(self, depth=0,
                 random_state=123,
                 features=None,
                 categ=None,
                 cut=False,
                 **info):
        self.info = info
        self.cut = cut
        self.remove_files = []
        self.nodes = dict()
        self.depth = depth
        # None replaces the former shared mutable ``[]`` defaults (same meaning).
        self.features = [] if features is None else features
        self.categ = [] if categ is None else categ
        self.random_state = random_state
        self.name = "CRAID_%s" % (self.random_state)
        self.coxph = None
        self.ohenc = None
        self.bins = []
        self.list_rules = dict()

    def fit(self, X, y):
        """
        Grow the tree on (X, y), optionally prune it, and fit the CoxPH model.

        Parameters
        ----------
        X : Pandas DataFrame
            Input features.
        y : array-like
            Must be indexable by ``cnt.CENS_NAME`` and ``cnt.TIME_NAME``
            (presumably a structured array, sksurv style — TODO confirm).

        Returns
        -------
        self
        """
        from collections import deque

        if len(self.features) == 0:
            self.features = X.columns
        self.bins = cnt.get_bins(time=y[cnt.TIME_NAME])
        X = X.reset_index(drop=True)
        X_tr = X.copy()
        X_tr[cnt.CENS_NAME] = y[cnt.CENS_NAME].astype(np.int32)
        X_tr[cnt.TIME_NAME] = y[cnt.TIME_NAME].astype(np.int32)

        if "min_samples_leaf" not in self.info:
            self.info["min_samples_leaf"] = 0.01 * X_tr.shape[0]
        cnt.set_seed(self.random_state)

        if self.cut:
            X_val = X_tr.sample(n=int(0.2 * X_tr.shape[0]), random_state=self.random_state)
            X_tr = X_tr.loc[X_tr.index.difference(X_val.index), :]

        # Start from an empty tree: new node numbers are derived from
        # len(self.nodes), so nodes left over from a previous fit would
        # corrupt the numbering of a refit.
        self.nodes = dict()
        self.nodes[0] = Node(X_tr, features=self.features, categ=self.categ, **self.info)
        queue = deque([0])  # FIFO -> breadth-first numbering of the nodes
        while queue:
            node = self.nodes[queue.popleft()]
            if node.depth >= self.depth:
                continue
            sub_nodes = node.split()
            if len(sub_nodes) > 0:
                first = len(self.nodes)
                sub_numbers = np.arange(first, first + len(sub_nodes))
                for number, sub_node in zip(sub_numbers, sub_nodes):
                    sub_node.numb = number
                self.nodes.update(dict(zip(sub_numbers, sub_nodes)))
                node.set_edges(sub_numbers)
                queue.extend(sub_numbers)

        if self.cut:
            self.cut_tree(X_val, cnt.CENS_NAME, mode_f=roc_auc_score, choose_f=max)

        # The CoxPH model is fitted on the full X (validation part included).
        self.fit_cox_hazard(X, y)
        self.count_list_rules()
        return self

    def fit_cox_hazard(self, X, y):
        """Fit a CoxPH model (alpha=0.1) on one-hot encoded node numbers of X."""
        self.coxph = CoxPHSurvivalAnalysis(alpha=0.1)
        self.ohenc = OneHotEncoder(handle_unknown='ignore')
        pred_node = self.predict(X, mode="target", target="num_node").to_numpy().reshape(-1, 1)
        ohenc_node = self.ohenc.fit_transform(pred_node).toarray()
        self.coxph.fit(ohenc_node, y)

    def count_list_rules(self):
        """
        Map every leaf number to the list of split rules on its path from the root.

        Relies on a parent having a smaller number than its children, which
        holds for the breadth-first numbering used in ``fit``.
        """
        rules = {0: []}
        for k_node in sorted(self.nodes.keys()):
            node = self.nodes[k_node]
            if not node.is_leaf:
                for edge, rule in zip(node.edges, node.rule_edges):
                    rules[edge] = rules[k_node] + [rule]
                del rules[k_node]  # only leaves are kept
        self.list_rules = rules
        return self.list_rules

    def predict_cox_hazard(self, X, bins):
        """
        Cumulative hazard of the CoxPH model at ``bins`` for every row of X.

        ``bins`` is clipped to the range of the training time grid
        (``self.bins``), presumably to stay inside the domain of the fitted
        step functions.
        """
        bins = np.clip(bins, self.bins.min(), self.bins.max())
        pred_node = self.predict(X, mode="target", target="num_node").to_numpy().reshape(-1, 1)
        ohenc_node = self.ohenc.transform(pred_node).toarray()
        hazards = self.coxph.predict_cumulative_hazard_function(ohenc_node)
        return np.array([hazard(bins) for hazard in hazards])

    def predict(self, X, mode="target", target=cnt.TIME_NAME, end_list=(), bins=None):
        """
        Route every row of X down the tree and collect a per-row result.

        Parameters
        ----------
        X : array-like or Pandas DataFrame
            Input features (converted with ``format_to_pandas``).
        mode : str, optional
            "target" : value of ``target`` predicted by the reached node.
            "rules" : " & "-joined split rules leading to the reached leaf.
            "scheme" : not supported here, use ``predict_schemes``.
        target : str, optional
            Predicted feature. "surv" / "hazard" return functions evaluated
            at ``bins`` regardless of ``mode``.
        end_list : collection of int, optional
            Node numbers treated as leaves (routing stops there).
        bins : array-like, optional
            Time points used for "surv" / "hazard".

        Returns
        -------
        Pandas Series
            One result per row (NaN for an unrecognised ``mode``).

        Raises
        ------
        NotImplementedError
            For ``mode="scheme"``.
        """
        X = format_to_pandas(X, self.features)
        X.loc[:, "number_node"] = 0
        X.loc[:, "res"] = np.nan
        for i in sorted(self.nodes.keys()):
            ind = X[X["number_node"] == i].index
            if ind.shape[0] == 0:
                continue
            node = self.nodes[i]
            if not (node.is_leaf or i in end_list):
                # Internal node: move its rows to the chosen children.
                X.loc[ind, "number_node"] = node.get_edges(X.loc[ind, :])
                continue
            if target == "surv" or target == "hazard":
                X.loc[ind, "res"] = node.predict(X.loc[ind, :], target, bins)
            elif mode == "target":
                X.loc[ind, "res"] = node.predict(X.loc[ind, :], target)
            elif mode == "scheme":
                # The former branch indexed X[ind, "res"] and called a
                # non-existent self.tree, so it could never succeed.
                raise NotImplementedError(
                    'mode="scheme" is not supported by CRAID1; '
                    'use predict_schemes(X, scheme_feats)')
            elif mode == "rules":
                X.loc[ind, "res"] = " & ".join([s.to_str() for s in self.list_rules[i]])
        return X["res"]

    def predict_at_times(self, X, bins, mode="surv"):
        """
        Return survival or hazard function.

        Parameters
        ----------
        X : Pandas dataframe
            Contain input features of events.
        bins : array-like
            Points of timeline.
        mode : str, optional
            Type of function. The default is "surv".
            "surv" : send building function in nodes
            "hazard" : send building function in nodes
            "cox-hazard" : fit CoxPH model on node numbers (input)
                                          and time/cens (output)
                           predict cumulative HF from model

        Returns
        -------
        array-like
            Vector of function values in times (bins).

        """
        X = format_to_pandas(X, self.features)
        if mode == "cox-hazard":
            return self.predict_cox_hazard(X, bins)
        return np.array(self.predict(X, target=mode, bins=bins).to_list())

    def predict_schemes(self, X, scheme_feats):
        """
        For every row, list the leaves reachable when some splits are left open.

        At a node whose first split rule uses a feature from ``scheme_feats``
        the row is sent to all children; at other nodes only to the child
        chosen by ``get_edges``.

        Returns
        -------
        Pandas Series
            One list of leaf numbers per row.
        """
        X = format_to_pandas(X, self.features)
        num_node_to_key = dict(zip(sorted(self.nodes.keys()), range(len(self.nodes))))
        # node_bin[row, k] is True when the row reaches the k-th node (sorted order).
        node_bin = np.zeros((X.shape[0], len(self.nodes)), dtype=bool)
        node_bin[:, 0] = 1
        X["res"] = np.nan
        X["res"] = X["res"].apply(lambda x: list())
        for i in sorted(self.nodes.keys()):
            i_num = num_node_to_key[i]
            ind = np.where(node_bin[:, i_num])[0]
            ind_x = X.index[ind]
            if ind.shape[0] > 0:
                if self.nodes[i].is_leaf:
                    X.loc[ind_x, "res"] = X.loc[ind_x, "res"].apply(lambda x: x + [i])
                else:
                    if self.nodes[i].rule_edges[0].get_feature() in scheme_feats:
                        for e in self.nodes[i].edges:
                            node_bin[ind, num_node_to_key[e]] = 1
                    else:
                        pred_edges = self.nodes[i].get_edges(X.iloc[ind, :])
                        for e in set(pred_edges):
                            node_bin[ind, num_node_to_key[e]] = pred_edges == e
        return X["res"]

    def cut_tree(self, X, target, mode_f=roc_auc_score, choose_f=max):
        """
        Method of prunning tree.
        Find best subtree, which reaches best value of metric "mode_f""

        Parameters
        ----------
        X : Pandas dataframe
            Contain input features of events.
        target : str
            Feature name for metric counting.
        mode_f : function, optional
            Metric for selecting. The default is roc_auc_score.
        choose_f : function, optional
            Type of best value (max or min). The default is max.

        """
        self.nodes = cutted_tree(self, X, target, mode_f, choose_f).nodes
        self.count_list_rules()

    def visualize(self, path_dir=None, **kwargs):
        """
        Render the tree with graphviz to ``<path_dir>/<name>_`` (format png).

        Node images are produced by each node's ``set_dot_node`` in a temporary
        directory that is removed once rendering is finished.

        Parameters
        ----------
        path_dir : str, optional
            Output directory. The default is the current working directory.
        **kwargs
            Forwarded to ``Node.set_dot_node`` (e.g. ``mode``, ``target``,
            ``depth``). ``bins`` is always set to the training time grid.
        """
        if path_dir is None:
            path_dir = os.getcwd()
        kwargs["bins"] = self.bins

        with tempfile.TemporaryDirectory() as tmp_dir:
            dot = Digraph(node_attr={'shape': 'none'})
            ordered_nodes = sorted(self.nodes.keys())
            # Trailing separator so node images land *inside* tmp_dir even
            # when the node concatenates path_dir and the file name.
            img_dir = os.path.join(tmp_dir, "")
            for i in ordered_nodes:
                dot = self.nodes[i].set_dot_node(dot, path_dir=img_dir, **kwargs)
            for i in ordered_nodes:
                dot = self.nodes[i].set_dot_edges(dot)
            # os.path.join: os.getcwd() has no trailing separator, so plain
            # concatenation wrote the result next to, not inside, path_dir.
            dot.render(os.path.join(path_dir, self.name + "_"), view=False, format="png")

    def translate(self, describe):
        """Rename features in the model and in every node via the ``describe`` mapping."""
        self.features = [describe.get(f, f) for f in self.features]
        self.categ = [describe.get(c, c) for c in self.categ]
        for node in self.nodes.values():
            node.translate(describe)

    def get_leaf_numbers(self):
        """Numbers of all leaf nodes."""
        return np.array([i for i in self.nodes.keys() if self.nodes[i].is_leaf])

    def get_spanning_leaf_numbers(self):
        """Numbers of nodes that have exactly two children which are both leaves."""
        leafs = self.get_leaf_numbers()
        return np.array([i for i in self.nodes.keys()
                         if np.intersect1d(self.nodes[i].edges, leafs).shape[0] == 2])

    def delete_leafs_by_span(self, list_span_leaf):
        """
        Turn every listed node into a leaf by deleting its child nodes.

        NOTE: ``list_rules`` is not refreshed here; call ``count_list_rules``
        afterwards (as ``cut_tree`` does).
        """
        for i in list_span_leaf:
            for e in self.nodes[i].edges:
                del self.nodes[e]
            self.nodes[i].set_leaf()

In [199]:
# NOTE(review): these imports belong in the top imports cell (In [1]) so the
# notebook survives Restart & Run All without relying on execution order.
from survivors.datasets import load_pbc_dataset
from survivors.experiments.grid import generate_sample

# Load the PBC dataset (features X, targets y plus feature metadata).
X, y, features, categ, sch_nan = load_pbc_dataset()
# Take the first train/test split yielded by generate_sample
# (the 5 is presumably the number of splits/folds — TODO confirm).
a = generate_sample(X, y, 5)
X_train, y_train, X_test, y_test, bins = next(a)

In [200]:
# Fit the reference CRAID and the reworked CRAID1 with identical parameters
# and report the wall-clock fit time of each (one helper instead of two
# copy-pasted timing blocks).
params = {"criterion": "peto", "depth": 5, "min_samples_leaf": 3, "signif": 0.05, "cut": True}


def fit_timed(model_cls, label, X_fit, y_fit, **model_params):
    """Build ``model_cls(**model_params)``, fit it on (X_fit, y_fit), print the fit time and return the model."""
    t_start = time.perf_counter()
    model = model_cls(**model_params)
    model.fit(X_fit, y_fit)
    print(f"FULL_TIME {label}: {time.perf_counter() - t_start} seconds")
    return model


craid_tree = fit_timed(CRAID, "CRAID", X_train, y_train, **params)
craid_tree1 = fit_timed(CRAID1, "CRAID1", X_train, y_train, **params)

FULL_TIME CRAID: 2.831928449915722 seconds
FULL_TIME CRAID1: 1.637558045098558 seconds


In [201]:
# Position of each node number in sorted order (numbers have gaps after pruning).
dict(enumerate(sorted(craid_tree1.nodes)))

{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 11, 8: 12}

In [202]:
# Column index a freshly appended "res" column gets (= number of original columns).
X_ = X_test.assign(res=0)
X_.columns.get_loc("res")

17

In [203]:
# Inspect the held-out features (pandas truncates the rendered frame).
X_test

Unnamed: 0,trt,age,sex,ascites,hepato,spiders,edema,bili,chol,albumin,copper,alk,ast,trig,platelet,protime,stage
0,1.0,58.765229,1,1.0,1.0,1.0,1.0,14.5,261.0,2.60,156.0,1718.0,137.95,172.0,190.0,12.2,4.0
1,1.0,56.446270,1,0.0,1.0,1.0,0.0,1.1,302.0,4.14,54.0,7394.8,113.52,88.0,221.0,10.6,3.0
2,1.0,70.072553,0,0.0,0.0,0.0,0.5,1.4,176.0,3.48,210.0,516.0,96.10,55.0,151.0,12.0,4.0
3,1.0,54.740589,1,0.0,1.0,1.0,0.5,1.8,244.0,2.54,64.0,6121.8,60.63,92.0,183.0,10.3,4.0
4,2.0,38.105407,1,0.0,1.0,1.0,0.0,3.4,279.0,3.53,143.0,671.0,113.15,72.0,136.0,10.9,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
124,2.0,47.181383,1,0.0,1.0,0.0,0.0,1.3,316.0,3.51,75.0,1162.0,147.25,137.0,238.0,10.0,4.0
126,2.0,44.104038,1,0.0,0.0,0.0,0.0,0.5,268.0,4.08,9.0,1174.0,86.80,95.0,453.0,10.0,2.0
128,1.0,63.613963,1,0.0,1.0,0.0,0.0,0.9,420.0,3.87,30.0,1009.0,57.35,232.0,,9.7,3.0
131,1.0,40.553046,1,0.0,0.0,0.0,0.0,1.9,448.0,3.83,60.0,1052.0,127.10,175.0,181.0,9.8,3.0


In [206]:
# Leaves each test row can reach when the bili/protime splits are left unresolved.
craid_tree1.predict_schemes(X_test, scheme_feats=["bili", "protime"])

0           [3, 4, 5]
1      [3, 4, 11, 12]
2      [3, 4, 11, 12]
3      [3, 4, 11, 12]
4      [3, 4, 11, 12]
            ...      
124    [3, 4, 11, 12]
126    [3, 4, 11, 12]
128    [3, 4, 11, 12]
131    [3, 4, 11, 12]
133    [3, 4, 11, 12]
Name: res, Length: 84, dtype: object

In [115]:
# Split rules on the path to the reached leaf, per test row (CRAID1).
craid_tree1.predict(X_test, mode="rules")

0      ((bili >= 2.35)| nan) & ((protime >= 11.75)| nan)
1      (bili < 2.35) & ((ascites < 0.5)| nan) & (bili...
2      (bili < 2.35) & ((ascites < 0.5)| nan) & (bili...
3      (bili < 2.35) & ((ascites < 0.5)| nan) & ((bil...
4              ((bili >= 2.35)| nan) & (protime < 11.75)
                             ...                        
124    (bili < 2.35) & ((ascites < 0.5)| nan) & (bili...
126    (bili < 2.35) & ((ascites < 0.5)| nan) & (bili...
128    (bili < 2.35) & ((ascites < 0.5)| nan) & (bili...
131    (bili < 2.35) & ((ascites < 0.5)| nan) & ((bil...
133    (bili < 2.35) & ((ascites < 0.5)| nan) & (bili...
Name: res, Length: 84, dtype: object

In [40]:
# Same for the reference CRAID; it labels the missing-value branch differently
# and joins rules without spaces (see output), so the strings are not compared directly.
craid_tree.predict(X_test, mode="rules")

0      ((bili >= 2.35)| не указано)&((protime >= 11.7...
1      (bili < 2.35)&((ascites < 0.5)| не указано)&(b...
2      (bili < 2.35)&((ascites < 0.5)| не указано)&(b...
3      (bili < 2.35)&((ascites < 0.5)| не указано)&((...
4         ((bili >= 2.35)| не указано)&(protime < 11.75)
                             ...                        
124    (bili < 2.35)&((ascites < 0.5)| не указано)&(b...
126    (bili < 2.35)&((ascites < 0.5)| не указано)&(b...
128    (bili < 2.35)&((ascites < 0.5)| не указано)&(b...
131    (bili < 2.35)&((ascites < 0.5)| не указано)&((...
133    (bili < 2.35)&((ascites < 0.5)| не указано)&(b...
Name: res, Length: 84, dtype: object

In [41]:
# Prediction speed: reference CRAID vs. CRAID1.
%timeit craid_tree.predict(X_test, target="time")
%timeit craid_tree1.predict(X_test, target="time")

25.2 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
7.46 ms ± 5.72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [42]:
# CRAID1 must reproduce the reference CRAID predictions.
# np.testing checks shapes, treats NaN == NaN and reports the mismatching
# entries; the former `(a == b).all()` failed silently on NaN and could
# broadcast arrays of different shapes.
np.testing.assert_array_equal(np.asarray(craid_tree.predict(X_test, target="time")),
                              np.asarray(craid_tree1.predict(X_test, target="time")))
np.testing.assert_array_equal(np.asarray(craid_tree.predict(X_test, target="cens")),
                              np.asarray(craid_tree1.predict(X_test, target="cens")))
np.testing.assert_allclose(craid_tree.predict_at_times(X_test, bins=bins, mode="surv"),
                           craid_tree1.predict_at_times(X_test, bins=bins, mode="surv"))
np.testing.assert_allclose(craid_tree.predict_at_times(X_test, bins=bins, mode="cox-hazard"),
                           craid_tree1.predict_at_times(X_test, bins=bins, mode="cox-hazard"))


In [None]:
# NOTE(review): dead cell, everything below is commented out. The PATH tweak
# hard-codes a machine-specific Anaconda location (presumably so graphviz's
# executables are found); configure graphviz outside the notebook and delete this cell.
# import graphviz  # doctest: +NO_EXE
# doctest_mark_exe()

# os.environ['PATH'] = os.environ['PATH'] + ';' + r"C:\ProgramData\Anaconda3\envs\survive\Library\bin"
# # craid_tree.visualize(path_dir = "./", mode = "surv", target = "time")

In [8]:
# Draw the (pruned) CRAID1 tree with a survival-curve image in every node.
# NOTE(review): execution count In [8] is out of order w.r.t. the fit cell; re-run after it.
craid_tree1.visualize(mode="surv")

In [125]:
# Default output folder of visualize() (path_dir=None falls back to os.getcwd()).
# NOTE(review): the saved output exposes a local user path; clear it before sharing.
os.getcwd()

'C:\\Users\\vasiliev\\Desktop\\PycharmProjects\\dev-survivors\\demonstration'