In [3]:
import pandas as pd
import numpy as np
import os
import time
import copy
import pathlib, tempfile

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from graphviz import Digraph
from joblib import Parallel, delayed
from scipy import stats

from survivors import metrics as metr
from survivors import constants as cnt
from survivors import criteria as crit
from numba import njit, jit
from lifelines import KaplanMeierFitter, NelsonAalenFitter

%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [None]:
@njit('f8(f8[:], f8[:], i8[:], i8[:], i8, f8[:])', cache=True)
def lr_statistic(dur_1, dur_2, cens_1, cens_2, wei_1, wei_2):
    times = np.unique(np.hstack((dur_1, dur_2)))
    dur_1 = np.searchsorted(times, dur_1) + 1
    dur_2 = np.searchsorted(times, dur_2) + 1
    times_range = np.array([1, times.shape[0]], dtype=np.int32)

    bins = times_range[1] - times_range[0] + 1
    n_1_j = np.histogram(dur_1, bins=bins, range=times_range)[0]
    n_2_j = np.histogram(dur_2, bins=bins, range=times_range)[0]
    O_1_j = np.histogram(dur_1 * cens_1, bins=bins, range=times_range)[0]
    O_2_j = np.histogram(dur_2 * cens_2, bins=bins, range=times_range)[0]

    N_1_j = np.cumsum(n_1_j[::-1])[::-1]
    N_2_j = np.cumsum(n_2_j[::-1])[::-1]
    ind = np.where(N_1_j * N_2_j != 0)
    N_1_j = N_1_j[ind]
    N_2_j = N_2_j[ind]
    O_1_j = O_1_j[ind]
    O_2_j = O_2_j[ind]

    N_j = N_1_j + N_2_j
    O_j = O_1_j + O_2_j
    E_1_j = N_1_j * O_j / N_j
    res = np.zeros((N_j.shape[0], 3), dtype=np.float32)
    res[:, 1] = O_1_j - E_1_j
    res[:, 2] = E_1_j * (N_j - O_j) * N_2_j / (N_j * (N_j - 1))
    res[:, 0] = 1.0
    # if np.any(N_j <= 1):
    #     return 0.0
    if weightings == 2:
        res[:, 0] = N_j
    elif weightings == 3:
        res[:, 0] = np.sqrt(N_j)
    elif weightings == 4:
        res[:, 0] = np.cumprod((1.0 - O_j / (N_j + 1)))
    elif weightings == 5:
        res[:, 0] = np.cumprod((1.0 - O_j / (N_j + 1)))
    logrank = np.power((res[:, 0] * res[:, 1]).sum(), 2) / ((res[:, 0] * res[:, 0] * res[:, 2]).sum())
    return logrank


In [53]:
dur_A = np.random.uniform(0, 10000, 10000)
cens_A = np.random.choice(2, 10000)
dur_B = np.random.uniform(0, 10000, 10000)
cens_B = np.random.choice(2, 10000)
weight_A = np.random.uniform(0, 1, 10000)
weight_B = np.random.uniform(0, 1, 10000)

In [49]:
weight_A

array([0.61787635, 0.87655056, 0.79484043, ..., 0.29210781, 0.38067468,
       0.29061483])

In [None]:
def optimal_criter_split(arr_nan, left, right, criterion):
    none_to = 0
    max_stat_val = 1.0
    if arr_nan.shape[1] > 0:
        left_and_nan = np.hstack([left, arr_nan])
        right_and_nan = np.hstack([right, arr_nan])
        a = criterion(left_and_nan[1], right[1], left_and_nan[0], right[0])
        b = criterion(left[1], right_and_nan[1], left[0], right_and_nan[0])
        # Nans move to a leaf with maximal statistical value
        none_to = int(a < b)
        max_stat_val = max(a, b)
    else:
        max_stat_val = criterion(left[1], right[1], left[0], right[0])
    return (max_stat_val, none_to)


def get_attrs(max_stat_val, values, none_to, l_sh, r_sh, nan_sh):
    attrs = dict()
    attrs["stat_val"] = max_stat_val
    attrs["values"] = values
    if none_to:
        attrs["pos_nan"] = [0, 1]
        attrs["min_split"] = min(l_sh, r_sh+nan_sh)
    else:
        attrs["pos_nan"] = [1, 0]
        attrs["min_split"] = min(l_sh+nan_sh, r_sh)
    return attrs


def get_cont_attrs(uniq_set, arr_notnan, arr_nan, min_samples_leaf, criterion, 
                   signif_val, thres_cont_bin_max):
    if uniq_set.shape[0] > thres_cont_bin_max:
        uniq_set = np.quantile(arr_notnan[0], [i/float(thres_cont_bin_max) for i in range(1, thres_cont_bin_max)])
    else:  # Set intermediate points
        uniq_set = (uniq_set[:-1] + uniq_set[1:])*0.5
    uniq_set = list(set(np.round(uniq_set, 3)))
    attr_dicts = []
    for value in uniq_set:
        # Filter by attr value
        ind = arr_notnan[0] >= value
        left = arr_notnan[1:, np.where(ind)[0]].astype(np.int32)
        right = arr_notnan[1:, np.where(~ind)[0]].astype(np.int32)
        if min(left.shape[1], right.shape[1]) <= min_samples_leaf:
            continue
        max_stat_val, none_to = optimal_criter_split(arr_nan, left, right, criterion)
        if max_stat_val >= signif_val:
            attr_loc = get_attrs(max_stat_val, value, none_to,
                                 left.shape[1], right.shape[1], arr_nan.shape[1])
            attr_dicts.append(attr_loc)
    return attr_dicts


def get_categ_attrs(uniq_set, arr_notnan, arr_nan, min_samples_leaf, criterion, signif_val):
    attr_dicts = []
    pairs_uniq = power_set_nonover(uniq_set)
    for l, r in pairs_uniq:
        left = arr_notnan[1:, np.isin(arr_notnan[0], l)].astype(np.int32)
        right = arr_notnan[1:, np.isin(arr_notnan[0], r)].astype(np.int32)
        if min(left.shape[1], right.shape[1]) <= min_samples_leaf:
            continue
        max_stat_val, none_to = optimal_criter_split(arr_nan, left, right, criterion)
        if max_stat_val >= signif_val:
            attr_loc = get_attrs(max_stat_val, [list(l), list(r)], none_to,
                                 left.shape[1], right.shape[1], arr_nan.shape[1])
            attr_dicts.append(attr_loc)
    return attr_dicts


def hist_best_attr_split(arr, criterion="logrank", type_attr="cont", thres_cont_bin_max=100,
                         signif=1.0, signif_stat=0.0, min_samples_leaf=10, bonf=True, verbose=0, **kwargs):
    criterion = scrit.CRITERIA_DICT.get(criterion, None)
    best_attr = {"stat_val": signif_stat, "p_value": signif,
                 "sign_split": 0, "values": [], "pos_nan": [1, 0]}
    if arr.shape[1] < 2*min_samples_leaf:
        return best_attr

    ind = np.isnan(arr[0])
    arr_nan = arr[1:, np.where(ind)[0]].astype(np.int32)
    arr_notnan = arr[:, np.where(~ind)[0]]
    
    if type_attr == "woe":
        arr_notnan[0], descr_np = transform_woe(arr_notnan[0], arr_notnan[1])
        
    
    uniq_set = np.unique(arr_notnan[0])
    
    if type_attr == "categ" and uniq_set.shape[0] > 0:
        attr_dicts = get_categ_attrs(uniq_set, arr_notnan, arr_nan,
                                     min_samples_leaf, criterion, signif_stat)
    else:
        attr_dicts = get_cont_attrs(uniq_set, arr_notnan, arr_nan,
                                    min_samples_leaf, criterion, signif_stat, thres_cont_bin_max)
    
    if len(attr_dicts) == 0:
        return best_attr
    best_attr = max(attr_dicts, key=lambda x: x["stat_val"])
    best_attr["p_value"] = stats.chi2.sf(best_attr["stat_val"], df=1)
    best_attr["sign_split"] = len(attr_dicts)
    if best_attr["sign_split"] > 0:
        if type_attr == "cont":
            best_attr["values"] = [f" >= {best_attr['values']}", f" < {best_attr['values']}"]
        elif type_attr == "categ":
            best_attr["values"] = [f" in {e}" for e in best_attr["values"]]
        elif type_attr == "woe":
            ind = descr_np[1] >= best_attr["values"]
            l, r = list(descr_np[0, np.where(ind)[0]]), list(descr_np[0, np.where(~ind)[0]])
            best_attr["values"] = [f" in {e}" for e in [l, r]]
        if bonf:
            best_attr["p_value"] *= best_attr["sign_split"]
        if verbose > 0:
            print(best_attr["p_value"], len(uniq_set))
    return best_attr


In [60]:
dur = np.random.uniform(0, 10000, 10000)
cens = np.random.choice(2, 10000)
vals = np.random.uniform(0, 10000, 10000)

times = np.unique(dur)
if times.shape[0] > 100:
    bins = np.quantile(times, [i/float(100) for i in range(1, 100)])
else:
    bins = (times[:-1] + times[1:])*0.5
        
times_range = np.array([1, times.shape[0]], dtype=np.int32)

hist_vals = bincount(inds)
np.histogram(dur, bins=bins, range=times_range)[0]

inds = np.digitize(dur_A, bins)
bincount()

0