RandomSurvivalForest performance optimisations
- Parallelised Tree training
- Made Weibull fitting of terminal nodes lazy
anthonycarbone committed Mar 31, 2023
1 parent 5c6767a commit 71f967e
Showing 6 changed files with 303 additions and 156 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -23,7 +23,8 @@ module = [
'sphinx_rtd_theme',
'setuptools',
'sklearn.*',
'sksurv.*'
'sksurv.*',
'joblib.*'
]
ignore_missing_imports = true

64 changes: 46 additions & 18 deletions surpyval/regression/forest/forest.py
@@ -1,4 +1,5 @@
import numpy as np
from joblib import Parallel, delayed
from numpy.typing import ArrayLike, NDArray

from surpyval.regression.forest.tree import Tree
@@ -34,28 +35,55 @@ def __init__(
self.n_trees = n_trees
self.bootstrap = bootstrap

# Create trees
self.trees = []
bootstrap_indices = np.array(range(len(self.x))) # Defaults to normal
for _ in range(self.n_trees):
if self.bootstrap:
bootstrap_indices = np.random.choice(
len(self.x), len(self.x), replace=True
)
self.trees.append(
Tree(
x=self.x[bootstrap_indices],
Z=self.Z[bootstrap_indices],
c=self.c[bootstrap_indices],
max_depth=max_depth,
min_leaf_failures=min_leaf_failures,
n_features_split=n_features_split,
)
# Create Trees
if self.bootstrap:
bootstrap_indices = [
np.random.choice(len(self.x), len(self.x), replace=True)
for _ in range(self.n_trees)
]
else:
bootstrap_indices = [np.array(range(len(self.x)))] * self.n_trees

self.trees = Parallel(prefer="threads", verbose=1)( # Parallelise
delayed(Tree)(
x=self.x[bootstrap_indices[i]],
Z=self.Z[bootstrap_indices[i]],
c=self.c[bootstrap_indices[i]],
max_depth=max_depth,
min_leaf_failures=min_leaf_failures,
n_features_split=n_features_split,
)
for i in range(self.n_trees)
)
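
prefer="threads" is a good fit here: each Tree is fitted on NumPy arrays, and NumPy releases the GIL during its heavy operations, so threads overlap without copying the training data between processes. A minimal, self-contained sketch of the same pattern (FakeTree and the array sizes are invented for illustration, not part of surpyval):

import numpy as np
from joblib import Parallel, delayed

class FakeTree:
    """Stand-in for surpyval's Tree; the work is NumPy-bound."""
    def __init__(self, x):
        self.split_value = np.median(np.sort(x))  # GIL released here

x = np.random.rand(1_000_000)
indices = [np.random.choice(len(x), len(x), replace=True) for _ in range(8)]

# Same call shape as above: one delayed constructor per bootstrap sample
trees = Parallel(prefer="threads", verbose=1)(
    delayed(FakeTree)(x[idx]) for idx in indices
)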

def sf(
self, x: int | float | ArrayLike, Z: ArrayLike | NDArray
self,
x: int | float | ArrayLike,
Z: ArrayLike | NDArray,
ensemble_method: str = "sf",
) -> NDArray:
"""Returns the ensemble survival function
Parameters
----------
x : int | float | ArrayLike
Time samples
Z : ArrayLike | NDArray
Covariant matrix
ensemble_method : str, optional
Determines whether to average across terminal nodes the terminal
node survival functions or cumulative hazard functions.
For these respectively, ensemble_method must be "sf" or
"Hf". Defaults to "sf".
Returns
-------
NDArray
Survival function of x as 1D array
"""
if ensemble_method == "Hf":
Hf = self._apply_model_function_to_trees("Hf", x, Z)
return np.exp(-Hf)
return self._apply_model_function_to_trees("sf", x, Z)
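
The "Hf" path relies on the identity S(t) = exp(-H(t)): the per-tree cumulative hazards are averaged first, then converted back to a survival probability. The two ensemble methods generally give different answers, as this small sketch shows (survival values invented):

import numpy as np

tree_sf = np.array([0.9, 0.5, 0.7])  # Per-tree S(t) at one time point

sf_avg = tree_sf.mean()              # ensemble_method="sf" -> 0.70
Hf_avg = (-np.log(tree_sf)).mean()   # Average H(t) = -ln S(t)
sf_from_Hf = np.exp(-Hf_avg)         # ensemble_method="Hf" -> ~0.68,
                                     # the geometric mean of tree_sf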

def ff(
183 changes: 149 additions & 34 deletions surpyval/regression/forest/log_rank_split.py
@@ -1,12 +1,14 @@
from math import sqrt
from typing import Iterable

import numpy as np
from numpy.typing import NDArray
from sksurv.compare import compare_survival

from surpyval.utils.surv_sksurv_transformations import ( # sksurv's log-rank
surv_xZc_to_sksurv_Xy,
)
# from sksurv.compare import compare_survival

# from surpyval.utils.surv_sksurv_transformations import ( # sksurv's log-rank
# surv_xZc_to_sksurv_Xy,
# )


def log_rank_split(
@@ -15,7 +17,6 @@ def log_rank_split(
c: NDArray,
min_leaf_failures: int,
feature_indices_in: Iterable[int],
assert_reference: bool = False,
) -> tuple[int, float]:
r"""
Returns the best feature index and value according to the Log-Rank split
@@ -50,9 +51,6 @@ def log_rank_split(
Remembering, the returned split is such that the left child takes
:math:`u^* \leq v^*`, and the right child :math:`u^* > v^*`.
Note: It wraps scikit-survival's compare_survival() function. In the future
this should be a native implementation.
Parameters
----------
@@ -70,15 +68,36 @@
be (-1, -Inf) if insufficient samples were provided to satisfy the
min_leaf_failures constraint.
"""
# Transform to scikit-survival form
if Z.ndim == 1:
Z = np.reshape(Z, (1, -1)).transpose()
X, y = surv_xZc_to_sksurv_Xy(x=x, Z=Z, c=c)

# Best values
best_feature_index = -1
best_feature_value = float("-inf")
best_log_rank = float("-inf")
# Sort x, Z, and c by x, making sure Z is 2d (required for later calcs)
sort_idxs = np.argsort(x)
x = x[sort_idxs]
c = c[sort_idxs]
if Z.ndim == 1:
Z = np.reshape(Z[sort_idxs], (1, -1)).transpose()
else:
Z = Z[sort_idxs]  # Assignment, else the sort is silently discarded

# Calculate the d vector, where d[j] is the number of deaths at t[j]
death_indices = np.where(c == 0)[0]  # Uncensored => deaths
death_xs = x[death_indices]  # Death times

# The 't' and 'd' vectors: t[j] is the j-th distinct (uncensored)
# death time, and d[j] is the number of deaths at that time
t, d = np.unique(death_xs, return_counts=True)
m = len(t)  # Number of distinct death times in the samples

# Now the Y vector needs to be calculated
# Y[j] = the number of samples still 'at risk' at time t[j], that is
# the number of samples whose survival times are >= t[j] (regardless
# of censorship)
Y = np.array([len(x[x >= t_j]) for t_j in t])
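
# A toy illustration (data invented): with x = [1, 2, 2, 3, 4] and
# c = [0, 0, 1, 0, 1] (0 = death), death_xs = [1, 2, 3], so
# t = [1, 2, 3] and d = [1, 1, 1] (the second sample at t=2 is
# censored, hence not a death), and Y = [5, 4, 2].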

# Now let's find the best (u, v) pair
max_log_rank_magnitude = float("-inf")
best_u = -1 # Placeholder value
best_v = -float("inf") # Placeholder value

# Inner function used to ensure the min_leaf_failures constraint is
# respected
@@ -96,26 +115,122 @@ def breaks_min_leaf_failures_constraint():
return True
return False

# Loop over features
for u in range(X.shape[1]):
possible_feature_values = np.unique(X[:, u])

# If there's <2 unique values to consider, move on to the next feature
if len(possible_feature_values) < 2:
continue

# Else, go over each possible feature value
for i, v in enumerate(possible_feature_values):
for u in feature_indices_in:
for v in np.unique(Z[:, u])[:-1]:
# Discard the (u, v) pair if it would leave a child with
# fewer than min_leaf_failures failures
if breaks_min_leaf_failures_constraint():
continue

split = (v + possible_feature_values[i + 1]) * 0.5
abs_log_rank = abs(log_rank(u, v, x, Z, c, t, d, Y, m))

if abs_log_rank > max_log_rank_magnitude:
max_log_rank_magnitude = abs_log_rank
best_u = u
best_v = v

groups = (X[:, u] <= split).astype(int)
log_rank_u_v, _ = compare_survival(y, groups)
if log_rank_u_v > best_log_rank:
best_feature_index = u
best_feature_value = split
best_log_rank = log_rank_u_v
return best_u, best_v
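
Usage mirrors the updated tests further down: best_v is now an observed feature value (the left child takes Z[:, u] <= v) rather than a midpoint between observed values. A quick sketch (data invented):

import numpy as np

x = np.array([10, 12, 8, 9])        # Survival times
Z = np.array([[0], [0], [1], [1]])  # One binary covariate
c = np.array([0, 0, 0, 0])          # All uncensored

u, v = log_rank_split(
    x=x, Z=Z, c=c, min_leaf_failures=1, feature_indices_in=[0]
)
# u == 0 (the only feature), v == 0 (left child takes Z_0 <= 0)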


def log_rank(
u: int,
v: float,
x: NDArray,
Z: NDArray,
c: NDArray,
t: NDArray,
d: NDArray,
Y: NDArray,
m: int,
) -> float:
"""Returns L(u, v)."""

# Define the d_L and Y_L vectors
d_L = np.zeros(m)
Y_L = np.zeros(m)

# Get sample-indices (i) of those that would end up in the left child
left_child_indices = np.where(Z[:, u] <= v)[0]
left_child_x = x[left_child_indices]
left_child_c = c[left_child_indices]

for j in range(m):
# Number of uncensored deaths at t[j]
d_L[j] = np.sum(left_child_x[left_child_c == 0] == t[j])

# Number 'at risk' at t[j]: left-child samples whose event or
# censoring times are >= t[j]
Y_L[j] = np.sum(left_child_x >= t[j])

# Perform the j-sums
numerator = 0
denominator_inside_sqrt = 0 # Must sqrt() after j loop

for j in range(m):
if Y[j] < 2:
# Skip: the variance term divides by Y[j] - 1, which would be 0
continue
numerator += d_L[j] - Y_L[j] * d[j] / Y[j]
denominator_inside_sqrt += (
Y_L[j]
/ Y[j]
* (1 - Y_L[j] / Y[j])
* (Y[j] - d[j])
/ (Y[j] - 1)
* d[j]
)

return best_feature_index, best_feature_value
L_u_v_return = numerator / sqrt(denominator_inside_sqrt)

return L_u_v_return
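
L(u, v) is the standardised two-group log-rank statistic, so its square should agree with the chi-squared value returned by the sksurv wrapper this commit replaces. A sanity check one could run (toy data invented; assumes log_rank is importable from this module and scikit-survival is installed):

import numpy as np
from sksurv.compare import compare_survival
from sksurv.util import Surv

x = np.array([3.0, 5.0, 7.0, 2.0, 4.0, 6.0])
c = np.array([0, 0, 1, 0, 0, 0])  # 0 = death, 1 = censored
Z = np.array([[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]])

# Prerequisites, built exactly as in log_rank_split above
sort_idxs = np.argsort(x)
x, c, Z = x[sort_idxs], c[sort_idxs], Z[sort_idxs]
t, d = np.unique(x[c == 0], return_counts=True)
Y = np.array([np.sum(x >= t_j) for t_j in t])

L = log_rank(u=0, v=0.0, x=x, Z=Z, c=c, t=t, d=d, Y=Y, m=len(t))

groups = (Z[:, 0] <= 0.0).astype(int)
chisq, _ = compare_survival(Surv.from_arrays(event=(c == 0), time=x), groups)
assert np.isclose(L**2, chisq)  # Two-group chi-squared == L**2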

# # Transform to scikit-survival form
# if Z.ndim == 1:
# Z = np.reshape(Z, (1, -1)).transpose()
# X, y = surv_xZc_to_sksurv_Xy(x=x, Z=Z, c=c)

# # Best values
# best_feature_index = -1
# best_feature_value = float("-inf")
# best_log_rank = float("-inf")

# # Inner function used in ensuring the min_leaf_failures constraint is
# # respected
# def breaks_min_leaf_failures_constraint():
# left_child_samples = len(
# np.where(np.logical_and(Z[:, u] <= v, c == 0))[0]
# )
# right_child_samples = len(
# np.where(np.logical_and(Z[:, u] > v, c == 0))[0]
# )
# if (
# left_child_samples < min_leaf_failures
# or right_child_samples < min_leaf_failures
# ):
# return True
# return False

# # Loop over features
# for u in range(X.shape[1]):
# possible_feature_values = np.unique(X[:, u])

# # If there's <2 unique values to consider, move on to the next feature
# if len(possible_feature_values) < 2:
# continue

# # Else, go over each possible feature value
# for i, v in enumerate(possible_feature_values):
# if breaks_min_leaf_failures_constraint():
# continue

# split = (v + possible_feature_values[i + 1]) * 0.5

# groups = (X[:, u] <= split).astype(int)
# log_rank_u_v, _ = compare_survival(y, groups)
# if log_rank_u_v > best_log_rank:
# best_feature_index = u
# best_feature_value = split
# best_log_rank = log_rank_u_v

# return best_feature_index, best_feature_value
8 changes: 7 additions & 1 deletion surpyval/regression/forest/node.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from functools import cached_property

import numpy as np
from numpy.typing import ArrayLike, NDArray
@@ -79,7 +80,12 @@ def apply_model_function(

class TerminalNode(Node):
def __init__(self, x: NDArray, c: NDArray):
self.model = Weibull.fit(x, c)
self.x = x
self.c = c

@cached_property
def model(self):
return Weibull.fit(self.x, self.c)
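
cached_property is what makes the fitting lazy: TerminalNode.__init__ now only stores the data, and Weibull.fit runs once, on first access to .model, with the result memoised for every later call. The pattern in isolation (LazyNode and its toy "fit" are illustrative, not surpyval code):

from functools import cached_property

class LazyNode:
    def __init__(self, data):
        self.data = data  # Cheap: just hold the inputs

    @cached_property
    def model(self):
        print("fitting...")  # Executes once, on first access
        return sum(self.data) / len(self.data)  # Stand-in for Weibull.fit

node = LazyNode([1.0, 2.0, 3.0])
node.model  # Prints "fitting...", returns 2.0
node.model  # Served from the cache; no refit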

def apply_model_function(
self,
10 changes: 5 additions & 5 deletions surpyval/tests/forest/test_log_rank_split.py
@@ -23,8 +23,8 @@ def test_log_rank_split_one_binary_feature():
# Assert feature 0 (the only feature) is returned
assert lrs[0] == 0

# Assert feature 0 value 0.5 (left children have Z_0 <= 0)
assert lrs[1] == 0.5
# Assert feature 0 value 0 (left children have Z_0 <= 0)
assert lrs[1] == 0


def test_log_rank_split_one_feature_four_samples():
@@ -41,7 +41,7 @@
)

assert lrs[0] == 0
assert lrs[1] == 0.1
assert lrs[1] == 0


def test_log_rank_split_two_features_two_samples():
@@ -62,7 +62,7 @@
)

assert lrs[0] == 1
assert lrs[1] == 2
assert lrs[1] == 1


def test_log_rank_split_min_leaf_failures():
@@ -84,7 +84,7 @@
feature_indices_in=[0],
)
assert lrsA[0] == 0
assert lrsA[1] == 1.55
assert lrsA[1] == 0.1

# Case B: all samples are censored, a split is not possible
c_B = np.array([1] * len(x))