Fixed log-rank split by using scikit-survival's implementation
anthonycarbone committed Mar 21, 2023
1 parent 13ff3ec commit 5c6767a
Showing 3 changed files with 42 additions and 142 deletions.
138 changes: 29 additions & 109 deletions surpyval/regression/forest/log_rank_split.py
@@ -1,8 +1,6 @@
-from math import sqrt
 from typing import Iterable
 
 import numpy as np
-import pytest
 from numpy.typing import NDArray
 from sksurv.compare import compare_survival
@@ -52,6 +50,9 @@ def log_rank_split(
     Remembering, the return split is for the left child's feature
     :math:`u^* \leq v^*`, and right child :math:`u^* > v^*`.
 
+    Note: It wraps scikit-survival's compare_survival() function. In the
+    future this should be a native implementation.
 
     Parameters
     ----------
@@ -69,35 +70,15 @@
     be (-1, -Inf) if insufficient samples were provided to satisfy the
     min_leaf_failures constraint.
     """
-    # Sort x, Z, and c, in x, making sure Z is 2d (required for future calcs)
-    sort_idxs = np.argsort(x)
-    x = x[sort_idxs]
-    c = c[sort_idxs]
+    # Transform to scikit-survival form
     if Z.ndim == 1:
-        Z = np.reshape(Z[sort_idxs], (1, -1)).transpose()
-    else:
-        Z[sort_idxs]
-
-    # Calculate the d vector, where d[j] is the number of distinct deaths
-    # at t[j]
-    death_indices = np.where(c == 0)[0]  # Uncensored => deaths
-    death_xs = x[death_indices]  # Death times
-
-    # The 't' and 'd' vectors, that is t[j] is the j-th distinct (uncensored)
-    # death time, and d[j] is the number of distinct deaths at that time
-    t, d = np.unique(death_xs, return_counts=True)
-    m = len(t)  # How many unique times in the samples there are
-
-    # Now the Y vector needs to be calculated
-    # Y[j] = the number of people still 'at risk' at time t[j], that is the
-    # sum of samples whose survival times are greater than t[j] (irrespective
-    # of censorship)
-    Y = np.array([len(x[x >= t_j]) for t_j in t])
-
-    # Now let's find the best (u, v) pair
-    max_log_rank_magnitude = float("-inf")
-    best_u = -1  # Placeholder value
-    best_v = -float("inf")  # Placeholder value
+        Z = np.reshape(Z, (1, -1)).transpose()
+    X, y = surv_xZc_to_sksurv_Xy(x=x, Z=Z, c=c)
 
+    # Best values
+    best_feature_index = -1
+    best_feature_value = float("-inf")
+    best_log_rank = float("-inf")
 
     # Inner function used in ensuring the min_leaf_failures constraint is
     # respected
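
Aside: the hunk above hands the converted data to surv_xZc_to_sksurv_Xy, whose body is not shown in this diff. As a rough, hypothetical sketch of what such a conversion involves (the helper name and signature here are illustrative only; the c == 0 "observed death" convention matches the comments elsewhere in this file):

    import numpy as np
    from sksurv.util import Surv

    def xZc_to_Xy_sketch(x, Z, c):
        # One row per sample, one column per feature
        X = Z.reshape(-1, 1) if Z.ndim == 1 else Z
        # scikit-survival wants a structured array of (event, time) pairs,
        # where event=True means the death was observed (c == 0)
        y = Surv.from_arrays(event=(c == 0), time=x)
        return X, y
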
@@ -115,87 +96,26 @@ def breaks_min_leaf_failures_constraint():
             return True
         return False
 
-    for u in feature_indices_in:
-        for v in np.unique(Z[:, u])[:-1]:
-            # Discard the (u, v) pair if it means a leaf will
-            # have < min_leaf_failures samples
+    # Loop over features
+    for u in range(X.shape[1]):
+        possible_feature_values = np.unique(X[:, u])
+
+        # If there's <2 unique values to consider, move on to the next feature
+        if len(possible_feature_values) < 2:
+            continue
+
+        # Else, go over each possible feature value (all but the last, since
+        # each candidate split is the midpoint between adjacent values)
+        for i, v in enumerate(possible_feature_values[:-1]):
             if breaks_min_leaf_failures_constraint():
                 continue
 
-            abs_log_rank = abs(log_rank(u, v, x, Z, c, t, d, Y, m))
-
-            if assert_reference:
-                reference_abs_log_rank, _ = compare_survival(
-                    surv_xZc_to_sksurv_Xy(x, Z, c)[1],
-                    (Z[:, u] <= v).astype(int),
-                )
-                try:
-                    assert (
-                        pytest.approx(abs_log_rank) == reference_abs_log_rank
-                    )
-                except AssertionError:
-                    raise AssertionError(
-                        f"abs_log_rank={abs_log_rank:.3f} != "
-                        f"reference_abs_log_rank={reference_abs_log_rank:.3f}"
-                    )
-
-            if abs_log_rank > max_log_rank_magnitude:
-                max_log_rank_magnitude = abs_log_rank
-                best_u = u
-                best_v = v
-
-    return best_u, best_v
-
-
-def log_rank(
-    u: int,
-    v: float,
-    x: NDArray,
-    Z: NDArray,
-    c: NDArray,
-    t: NDArray,
-    d: NDArray,
-    Y: NDArray,
-    m: int,
-) -> float:
-    """Returns L(u, v)."""
-
-    # Define the d_L and Y_L vectors
-    d_L = np.zeros(m)
-    Y_L = np.zeros(m)
-
-    # Get sample-indices (i) of those that would end up in the left child
-    left_child_indices = np.where(Z[:, u] <= v)[0]
-    left_child_x = x[left_child_indices]
-    left_child_c = c[left_child_indices]
-
-    for j in range(m):
-        # Number of uncensored deaths at t[j]
-        d_L[j] = sum(left_child_x[left_child_c == 0] == t[j])
-
-        # Number 'at risk', that is those still alive + have an event (death
-        # or censor) at t[j]
-        Y_L[j] = sum(left_child_x >= t[j])
-
-    # Perform the j-sums
-    numerator = 0
-    denominator_inside_sqrt = 0  # Must sqrt() after j loop
-
-    for j in range(m):
-        if not Y[j] >= 2:
-            # Denominator contribution would be undefined
-            # (Y[j] - 1 <= 0 -> bad!)
-            continue
-        numerator += d_L[j] - Y_L[j] * d[j] / Y[j]
-        denominator_inside_sqrt += (
-            Y_L[j]
-            / Y[j]
-            * (1 - Y_L[j] / Y[j])
-            * (Y[j] - d[j])
-            / (Y[j] - 1)
-            * d[j]
-        )
-
-    L_u_v_return = numerator / sqrt(denominator_inside_sqrt)
-
-    return L_u_v_return
+            split = (v + possible_feature_values[i + 1]) * 0.5
+
+            groups = (X[:, u] <= split).astype(int)
+            log_rank_u_v, _ = compare_survival(y, groups)
+            if log_rank_u_v > best_log_rank:
+                best_feature_index = u
+                best_feature_value = split
+                best_log_rank = log_rank_u_v
+
+    return best_feature_index, best_feature_value
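
With that, the new log_rank_split reduces to: for each feature, try the midpoint between every pair of adjacent unique values as a threshold, score the resulting two-group partition with scikit-survival's log-rank test, and keep the strongest split. A self-contained toy version of that loop (data values invented for illustration):

    import numpy as np
    from sksurv.compare import compare_survival
    from sksurv.util import Surv

    # Six samples, one covariate, all deaths observed
    X = np.array([[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]])
    y = Surv.from_arrays(event=np.ones(6, dtype=bool),
                         time=np.array([1.0, 2.0, 3.0, 7.0, 8.0, 9.0]))

    best_feature, best_threshold, best_stat = -1, float("-inf"), float("-inf")
    for u in range(X.shape[1]):
        values = np.unique(X[:, u])
        for lo, hi in zip(values[:-1], values[1:]):
            threshold = (lo + hi) * 0.5  # midpoint between adjacent values
            groups = (X[:, u] <= threshold).astype(int)
            chisq, _ = compare_survival(y, groups)  # log-rank chi-squared
            if chisq > best_stat:
                best_feature, best_threshold, best_stat = u, threshold, chisq

    print(best_feature, best_threshold)  # 0 0.5
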
14 changes: 5 additions & 9 deletions surpyval/tests/forest/test_log_rank_split.py
@@ -18,14 +18,13 @@ def test_log_rank_split_one_binary_feature():
         c,
         min_leaf_failures=6,
         feature_indices_in=[0],
-        assert_reference=True,
     )
 
     # Assert feature 0 (the only feature) is returned
     assert lrs[0] == 0
 
-    # Assert feature 0 value 0 (left children have Z_0 <= 0)
-    assert lrs[1] == 0
+    # Assert feature 0 value 0.5 (left children have Z_0 <= 0)
+    assert lrs[1] == 0.5
 
 
 def test_log_rank_split_one_feature_four_samples():
@@ -39,11 +38,10 @@ def test_log_rank_split_one_feature_four_samples():
         c,
         min_leaf_failures=1,
         feature_indices_in=[0],
-        assert_reference=True,
     )
 
     assert lrs[0] == 0
-    assert lrs[1] == 0
+    assert lrs[1] == 0.1
 
 
 def test_log_rank_split_two_features_two_samples():
@@ -61,11 +59,10 @@
         c,
         min_leaf_failures=1,
         feature_indices_in=[0, 1],
-        assert_reference=True,
     )
 
     assert lrs[0] == 1
-    assert lrs[1] == 1
+    assert lrs[1] == 2
 
 
 def test_log_rank_split_min_leaf_failures():
@@ -85,10 +82,9 @@
         c_A,
         min_leaf_failures=min_leaf_failures,
         feature_indices_in=[0],
-        assert_reference=True,
     )
     assert lrsA[0] == 0
-    assert lrsA[1] == 0.1
+    assert lrsA[1] == 1.55
 
     # Case B: all samples are censored, a split is not possible
     c_B = np.array([1] * len(x))
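
The updated expected values in these tests follow directly from the new thresholding scheme: the returned split is now the midpoint between two adjacent unique covariate values rather than one of the values itself. For the binary-feature case, for instance:

    import numpy as np

    values = np.unique(np.array([0.0, 0.0, 1.0, 1.0]))  # binary feature
    midpoints = (values[:-1] + values[1:]) * 0.5
    print(midpoints)  # [0.5] -- hence `assert lrs[1] == 0.5` above
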
32 changes: 8 additions & 24 deletions surpyval/tests/forest/test_tree.py
@@ -164,20 +164,9 @@ def dfs_assert_trees_equal(
         == sklearn_tree.feature[sksurv_curr_node]
     )
 
-    # Assert surpyval_value <= scikit_value < next_biggest_value
-    # as these would represent/result in the same split. It seems
-    # that scikit-survival doesn't use samples exactly but may get a median
-    next_biggest_value = min(
-        surv_curr_node.Z[
-            surv_curr_node.Z[:, surv_curr_node.split_feature_index]
-            > surv_curr_node.split_feature_value,
-            surv_curr_node.split_feature_index,
-        ]
-    )
     assert (
-        surv_curr_node.split_feature_value
-        <= sklearn_tree.threshold[sksurv_curr_node]
-        < next_biggest_value
+        pytest.approx(surv_curr_node.split_feature_value)
+        == sklearn_tree.threshold[sksurv_curr_node]
     )
 
     # And continue the DFS
@@ -248,11 +237,6 @@ def test_tree_reference_split_one_split_two_features():


 def test_tree_reference_splits_gbsg2():
-    """
-    Scikit-survival's SurvivalTree is a reference implementation for surpyval's
-    Tree. Here the log-rank splitter is tested to make sure the decision tree
-    is identical between the two.
-    """
     # Prep data input
     X, y = load_gbsg2()
@@ -271,9 +255,9 @@
         min_samples_split=2,
         min_samples_leaf=15,
         max_features=None,
-        max_depth=1,
+        max_depth=2,
     )
-    sksurv_tree.fit(np.array(Xt[:100]["tgrade"], ndmin=2).transpose(), y[:100])
+    sksurv_tree.fit(Xt, y)

     # Prep and fit a surpyval Tree
     def sksurv_Xy_to_surv_xZc(X: pd.DataFrame, y: NDArray):
@@ -282,12 +266,12 @@ def sksurv_Xy_to_surv_xZc(X: pd.DataFrame, y: NDArray):
     x, Z, c = sksurv_Xy_to_surv_xZc(Xt, y)
 
     surv_tree = Tree(
-        x=x[:100],
-        Z=Z[:100, 0],
-        c=c[:100],
+        x=x,
+        Z=Z,
+        c=c,
         n_features_split="all",
         min_leaf_failures=15,
-        max_depth=1,
+        max_depth=2,
     )
 
     assert_trees_equal(surv_tree, sksurv_tree)
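
The reworked gbsg2 test fits both trees on the full one-hot-encoded dataset at depth 2 and requires thresholds to agree with scikit-survival's up to floating-point tolerance (pytest.approx) instead of the earlier loose interval check. The scikit-survival half can be reproduced standalone; assuming the test's sklearn_tree is the fitted estimator's tree_ attribute, these are the feature/threshold arrays that dfs_assert_trees_equal walks:

    from sksurv.datasets import load_gbsg2
    from sksurv.preprocessing import OneHotEncoder
    from sksurv.tree import SurvivalTree

    X, y = load_gbsg2()
    Xt = OneHotEncoder().fit_transform(X)  # all-numeric covariates

    tree = SurvivalTree(
        min_samples_split=2,
        min_samples_leaf=15,
        max_features=None,
        max_depth=2,
    )
    tree.fit(Xt, y)

    # Chosen feature indices and thresholds, one entry per node
    print(tree.tree_.feature[:3])
    print(tree.tree_.threshold[:3])
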
