From 71f967e3d6bd216f4e6c8fabd5a646c56efc42ed Mon Sep 17 00:00:00 2001 From: Anthony Carbone Date: Tue, 28 Mar 2023 14:46:41 +1100 Subject: [PATCH] RandomSurvivalForest performance optimisations - Parallelised Tree training - Made Weibull fitting of terminal nodes lazy --- pyproject.toml | 3 +- surpyval/regression/forest/forest.py | 64 +++++-- surpyval/regression/forest/log_rank_split.py | 183 ++++++++++++++---- surpyval/regression/forest/node.py | 8 +- surpyval/tests/forest/test_log_rank_split.py | 10 +- surpyval/tests/forest/test_tree.py | 191 +++++++++---------- 6 files changed, 303 insertions(+), 156 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2a142c0..7ba4866 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ module = [ 'sphinx_rtd_theme', 'setuptools', 'sklearn.*', - 'sksurv.*' + 'sksurv.*', + 'joblib.*' ] ignore_missing_imports = true diff --git a/surpyval/regression/forest/forest.py b/surpyval/regression/forest/forest.py index 12e1961..14854c1 100644 --- a/surpyval/regression/forest/forest.py +++ b/surpyval/regression/forest/forest.py @@ -1,4 +1,5 @@ import numpy as np +from joblib import Parallel, delayed from numpy.typing import ArrayLike, NDArray from surpyval.regression.forest.tree import Tree @@ -34,28 +35,55 @@ def __init__( self.n_trees = n_trees self.bootstrap = bootstrap - # Create trees - self.trees = [] - bootstrap_indices = np.array(range(len(self.x))) # Defaults to normal - for _ in range(self.n_trees): - if self.bootstrap: - bootstrap_indices = np.random.choice( - len(self.x), len(self.x), replace=True - ) - self.trees.append( - Tree( - x=self.x[bootstrap_indices], - Z=self.Z[bootstrap_indices], - c=self.c[bootstrap_indices], - max_depth=max_depth, - min_leaf_failures=min_leaf_failures, - n_features_split=n_features_split, - ) + # Create Trees + if self.bootstrap: + bootstrap_indices = [ + np.random.choice(len(self.x), len(self.x), replace=True) + for _ in range(self.n_trees) + ] + else: + bootstrap_indices 
= [np.array(range(len(self.x)))] * self.n_trees + + self.trees = Parallel(prefer="threads", verbose=1)( # Parallelise + delayed(Tree)( + x=self.x[bootstrap_indices[i]], + Z=self.Z[bootstrap_indices[i]], + c=self.c[bootstrap_indices[i]], + max_depth=max_depth, + min_leaf_failures=min_leaf_failures, + n_features_split=n_features_split, ) + for i in range(self.n_trees) + ) def sf( - self, x: int | float | ArrayLike, Z: ArrayLike | NDArray + self, + x: int | float | ArrayLike, + Z: ArrayLike | NDArray, + ensemble_method: str = "sf", ) -> NDArray: + """Returns the ensemble survival function + + Parameters + ---------- + x : int | float | ArrayLike + Time samples + Z : ArrayLike | NDArray + Covariant matrix + ensemble_method : str, optional + Determines whether to average across terminal nodes the terminal + node survival functions or cumulative hazard functions. + For these respectively, ensemble_method must be "sf" or + "Hf". Defaults to "sf". + + Returns + ------- + NDArray + Survival function of x as 1D array + """ + if ensemble_method == "Hf": + Hf = self._apply_model_function_to_trees("Hf", x, Z) + return np.exp(-Hf) return self._apply_model_function_to_trees("sf", x, Z) def ff( diff --git a/surpyval/regression/forest/log_rank_split.py b/surpyval/regression/forest/log_rank_split.py index c68e109..d61c95a 100644 --- a/surpyval/regression/forest/log_rank_split.py +++ b/surpyval/regression/forest/log_rank_split.py @@ -1,12 +1,14 @@ +from math import sqrt from typing import Iterable import numpy as np from numpy.typing import NDArray -from sksurv.compare import compare_survival -from surpyval.utils.surv_sksurv_transformations import ( # sksurv's log-rank - surv_xZc_to_sksurv_Xy, -) +# from sksurv.compare import compare_survival + +# from surpyval.utils.surv_sksurv_transformations import ( # sksurv's log-rank +# surv_xZc_to_sksurv_Xy, +# ) def log_rank_split( @@ -15,7 +17,6 @@ def log_rank_split( c: NDArray, min_leaf_failures: int, feature_indices_in: Iterable[int], - 
assert_reference: bool = False, ) -> tuple[int, float]: r""" Returns the best feature index and value according to the Log-Rank split @@ -50,9 +51,6 @@ def log_rank_split( Remembering, the return split is for the left childs feature :math:`u^* \leq v^*`, and right child :math:`u^* > v^*`. - Note: It wraps scikit-survival's compare_survival() function. In the future - this should be a native implementation. - Parameters ---------- @@ -70,15 +68,36 @@ def log_rank_split( be (-1, -Inf) if insufficient samples were provided to satisfy the min_leaf_failures constraint. """ - # Transform to scikit-survival form - if Z.ndim == 1: - Z = np.reshape(Z, (1, -1)).transpose() - X, y = surv_xZc_to_sksurv_Xy(x=x, Z=Z, c=c) - # Best values - best_feature_index = -1 - best_feature_value = float("-inf") - best_log_rank = float("-inf") + # Sort x, Z, and c by x, making sure Z is 2d (required for future calcs) + sort_idxs = np.argsort(x) + x = x[sort_idxs] + c = c[sort_idxs] + if Z.ndim == 1: + Z = np.reshape(Z[sort_idxs], (1, -1)).transpose() + else: + Z = Z[sort_idxs] + + # Calculate the d vector, where d[j] is the number of distinct deaths + # at t[j] + death_indices = np.where(c == 0)[0] # Uncensored => deaths + death_xs = x[death_indices] # Death times + + # The 't' and 'd' vectors, that is t[j] is the j-th distinct (uncensored) + # death time, and d[j] is the number of distinct deaths at that time + t, d = np.unique(death_xs, return_counts=True) + m = len(t) # How many unique times in the samples there are + + # Now the Y vector needs to be calculated + # Y[j] = the number of people still 'at risk' at time t[j], that is the + # sum of samples whose survival times are at least t[j] (irrespective + # of censoring) + Y = np.array([len(x[x >= t_j]) for t_j in t]) + + # Now let's find the best (u, v) pair + max_log_rank_magnitude = float("-inf") + best_u = -1 # Placeholder value + best_v = -float("inf") # Placeholder value # Inner function used in ensuring the min_leaf_failures 
constraint is # respected @@ -96,26 +115,122 @@ def breaks_min_leaf_failures_constraint(): return True return False - # Loop over features - for u in range(X.shape[1]): - possible_feature_values = np.unique(X[:, u]) - - # If there's <2 unique values to consider, move on to the next feature - if len(possible_feature_values) < 2: - continue - - # Else, go over each possible feature value - for i, v in enumerate(possible_feature_values): + for u in feature_indices_in: + for v in np.unique(Z[:, u])[:-1]: + # Discard the (u, v) pair if it means a leaf will + # have < min_leaf_failures samples if breaks_min_leaf_failures_constraint(): continue - split = (v + possible_feature_values[i + 1]) * 0.5 + abs_log_rank = abs(log_rank(u, v, x, Z, c, t, d, Y, m)) + + if abs_log_rank > max_log_rank_magnitude: + max_log_rank_magnitude = abs_log_rank + best_u = u + best_v = v - groups = (X[:, u] <= split).astype(int) - log_rank_u_v, _ = compare_survival(y, groups) - if log_rank_u_v > best_log_rank: - best_feature_index = u - best_feature_value = split - best_log_rank = log_rank_u_v + return best_u, best_v + + +def log_rank( + u: int, + v: float, + x: NDArray, + Z: NDArray, + c: NDArray, + t: NDArray, + d: NDArray, + Y: NDArray, + m: int, +) -> float: + """Returns L(u, v).""" + + # Define the d_L and Y_L vectors + d_L = np.zeros(m) + Y_L = np.zeros(m) + + # Get sample-indices (i) of those that would end up in the left child + left_child_indices = np.where(Z[:, u] <= v)[0] + left_child_x = x[left_child_indices] + left_child_c = c[left_child_indices] + + for j in range(m): + # Number of uncensored deaths at t[j] + d_L[j] = np.sum(left_child_x[left_child_c == 0] == t[j]) + + # Number 'at risk', that is those still alive + have an event (death + # or censor) at t[j] + Y_L[j] = np.sum(left_child_x >= t[j]) + + # Perform the j-sums + numerator = 0 + denominator_inside_sqrt = 0 # Must sqrt() after j loop + + for j in range(m): + if not Y[j] >= 2: + # Denominator contribution would be 
undefined + # (Y[j] - 1 <= 0 -> bad!) + continue + numerator += d_L[j] - Y_L[j] * d[j] / Y[j] + denominator_inside_sqrt += ( + Y_L[j] + / Y[j] + * (1 - Y_L[j] / Y[j]) + * (Y[j] - d[j]) + / (Y[j] - 1) + * d[j] + ) - return best_feature_index, best_feature_value + L_u_v_return = numerator / sqrt(denominator_inside_sqrt) + + return L_u_v_return + + # # Transform to scikit-survival form + # if Z.ndim == 1: + # Z = np.reshape(Z, (1, -1)).transpose() + # X, y = surv_xZc_to_sksurv_Xy(x=x, Z=Z, c=c) + + # # Best values + # best_feature_index = -1 + # best_feature_value = float("-inf") + # best_log_rank = float("-inf") + + # # Inner function used in ensuring the min_leaf_failures constraint is + # # respected + # def breaks_min_leaf_failures_constraint(): + # left_child_samples = len( + # np.where(np.logical_and(Z[:, u] <= v, c == 0))[0] + # ) + # right_child_samples = len( + # np.where(np.logical_and(Z[:, u] > v, c == 0))[0] + # ) + # if ( + # left_child_samples < min_leaf_failures + # or right_child_samples < min_leaf_failures + # ): + # return True + # return False + + # # Loop over features + # for u in range(X.shape[1]): + # possible_feature_values = np.unique(X[:, u]) + + # #If there's <2 unique values to consider, move on to the next feature + # if len(possible_feature_values) < 2: + # continue + + # # Else, go over each possible feature value + # for i, v in enumerate(possible_feature_values): + # if breaks_min_leaf_failures_constraint(): + # continue + + # split = (v + possible_feature_values[i + 1]) * 0.5 + + # groups = (X[:, u] <= split).astype(int) + # log_rank_u_v, _ = compare_survival(y, groups) + # if log_rank_u_v > best_log_rank: + # best_feature_index = u + # best_feature_value = split + # best_log_rank = log_rank_u_v + + # return best_feature_index, best_feature_value diff --git a/surpyval/regression/forest/node.py b/surpyval/regression/forest/node.py index d837292..b3a3f42 100644 --- a/surpyval/regression/forest/node.py +++ 
b/surpyval/regression/forest/node.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from functools import cached_property import numpy as np from numpy.typing import ArrayLike, NDArray @@ -79,7 +80,12 @@ def apply_model_function( class TerminalNode(Node): def __init__(self, x: NDArray, c: NDArray): - self.model = Weibull.fit(x, c) + self.x = x + self.c = c + + @cached_property + def model(self): + return Weibull.fit(self.x, self.c) def apply_model_function( self, diff --git a/surpyval/tests/forest/test_log_rank_split.py b/surpyval/tests/forest/test_log_rank_split.py index 56567e1..bbc51f7 100644 --- a/surpyval/tests/forest/test_log_rank_split.py +++ b/surpyval/tests/forest/test_log_rank_split.py @@ -23,8 +23,8 @@ def test_log_rank_split_one_binary_feature(): # Assert feature 0 (the only feature) is returned assert lrs[0] == 0 - # Assert feature 0 value 0.5 (left children have Z_0 <= 0) - assert lrs[1] == 0.5 + # Assert feature 0 value 0 (left children have Z_0 <= 0) + assert lrs[1] == 0 def test_log_rank_split_one_feature_four_samples(): @@ -41,7 +41,7 @@ def test_log_rank_split_one_feature_four_samples(): ) assert lrs[0] == 0 - assert lrs[1] == 0.1 + assert lrs[1] == 0 def test_log_rank_split_two_features_two_samples(): @@ -62,7 +62,7 @@ def test_log_rank_split_two_features_two_samples(): ) assert lrs[0] == 1 - assert lrs[1] == 2 + assert lrs[1] == 1 def test_log_rank_split_min_leaf_failures(): @@ -84,7 +84,7 @@ def test_log_rank_split_min_leaf_failures(): feature_indices_in=[0], ) assert lrsA[0] == 0 - assert lrsA[1] == 1.55 + assert lrsA[1] == 0.1 # Case B: all samples are censored, a split is not possible c_B = np.array([1] * len(x)) diff --git a/surpyval/tests/forest/test_tree.py b/surpyval/tests/forest/test_tree.py index 364262f..246a9fc 100644 --- a/surpyval/tests/forest/test_tree.py +++ b/surpyval/tests/forest/test_tree.py @@ -1,12 +1,7 @@ import numpy as np -import pandas as pd import pytest -from numpy.typing import NDArray -from 
sklearn.preprocessing import OrdinalEncoder # For scikit-survival test -from sksurv.datasets import load_gbsg2 -from sksurv.preprocessing import OneHotEncoder from sksurv.tree.tree import SurvivalTree as sksurv_SurvivalTree from surpyval import Weibull @@ -183,95 +178,97 @@ def dfs_assert_trees_equal( dfs_assert_trees_equal(surv_curr_node, sksurv_curr_node) -def test_tree_reference_split_one_split_one_feature(): - # Samples - x = [10, 12, 8, 9, 11, 12, 13, 9] + [50, 60, 40, 45, 55, 60, 65, 45] - Z = [0] * 8 + [1] * 8 - c = [0] * len(x) - - # Surpyval - surv_tree = Tree(x=x, Z=Z, c=c, max_depth=1, n_features_split="all") - - # Scikit-survival - X = np.array(Z, ndmin=2).transpose() - y = np.array( - list(zip([True] * len(x), x)), - dtype=[("Status", bool), ("Survival", "