Skip to content

Commit

Permalink
Fixed RSF score
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonycarbone committed Mar 31, 2023
1 parent e668129 commit 32c643b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
69 changes: 55 additions & 14 deletions surpyval/regression/forest/forest.py
@@ -1,3 +1,4 @@
import math
from itertools import combinations

import numpy as np
Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
else:
bootstrap_indices = [np.array(range(len(self.x)))] * self.n_trees

self.trees = Parallel(prefer="threads", verbose=1)( # Parallelise
self.trees: list[Tree] = Parallel(prefer="threads", verbose=1)(
delayed(Tree)(
x=self.x[bootstrap_indices[i]],
Z=self.Z[bootstrap_indices[i]],
Expand Down Expand Up @@ -116,16 +117,27 @@ def _apply_model_function_to_trees(
) -> NDArray:
# Prep input - make sure numpy array
x = np.array(x, ndmin=1)
Z = np.array(Z, ndmin=1)
Z = np.array(Z, ndmin=2)

# If there are multiple covariate vectors but x is 1-D, reshape x into a column so it has one row per covariate vector
if Z.shape[0] > 1 and x.ndim == 1:
x = np.array(x, ndmin=2).transpose()

res = np.zeros_like(x, dtype=float)
for tree in self.trees:
res += tree.apply_model_function(function_name, x, Z)
for i_covariant_vector in range(Z.shape[0]):
for tree in self.trees:
res[i_covariant_vector] += tree.apply_model_function(
function_name, x[i_covariant_vector], Z[i_covariant_vector]
)
return res / self.n_trees

def score(
self, x: ArrayLike, Z: ArrayLike | NDArray, c: ArrayLike
) -> float:
self,
x: ArrayLike,
Z: ArrayLike | NDArray,
c: ArrayLike,
tie_tol: float = 1e-8,
) -> dict:
"""Returns the concordance index of the model
Parameters
Expand Down Expand Up @@ -158,11 +170,11 @@ def score(
Z = np.array(Z, ndmin=2)

# Package xcZ together
xcZ = []
ixcZ = []
for i in range(len(x)):
xcZ.append((i, x[i], c[i], Z[i]))
ixcZ.append((i, x[i], c[i], Z[i]))

pairs = combinations(xcZ, 2)
pairs = combinations(ixcZ, 2)

def predict(i, x, Z):
"""Inner function to get memoised prediction if available,
Expand All @@ -178,18 +190,26 @@ def predict(i, x, Z):
memoised_predictions = {i: None for i in range(len(x))}
concordance = 0.0
n_permissible_pairs = 0
n_concordant_pairs = 0
n_tied_predictions = 0
n_discordant_pairs = 0
n_tied_time_samples = 0

for tup_1, tup_2 in pairs:
# Get right ordering
if tup_1[1] > tup_1[1]:
if tup_1[1] > tup_2[1]:
tup_1, tup_2 = tup_2, tup_1

# Unpack tuple
i_1, x_1, c_1, Z_1 = tup_1
i_2, x_2, c_2, Z_2 = tup_2

# Omit pair if x_1 is censored
if c_1 == 1:
if c_1 == 1 and x_1 < x_2:
continue

# Omit pair if x_1 == x_2 and both samples are censored
if x_1 == x_2 and c_1 == c_2 == 1:
continue

n_permissible_pairs += 1
Expand All @@ -200,12 +220,33 @@ def predict(i, x, Z):
if x_1 < x_2:
if x_hat_1 < x_hat_2:
concordance += 1
elif x_hat_1 == x_hat_2:
n_concordant_pairs += 1
elif math.isclose(x_hat_1, x_hat_2, abs_tol=tie_tol):
concordance += 0.5
n_tied_predictions += 1
else:
n_discordant_pairs += 1
elif c_1 == c_2 == 0:
if x_hat_1 == x_hat_2:
n_tied_time_samples += 1
if math.isclose(x_hat_1, x_hat_2, abs_tol=tie_tol):
concordance += 1
else:
concordance += 0.5
else:
# x_1 == x_2 and one is a death
if (c_1 == 0 and x_hat_1 < x_hat_2) or (
c_2 == 0 and x_hat_2 < x_hat_1
):
concordance += 1
else:
concordance += 0.5

return concordance / n_permissible_pairs
return {
"c_index": concordance / n_permissible_pairs,
"n_concordant_pairs": n_concordant_pairs,
"n_discordant_pairs": n_discordant_pairs,
"n_tied_predictions": n_tied_predictions,
"n_tied_time_samples": n_tied_time_samples,
"concordance": concordance,
"n_permissible_pairs": n_permissible_pairs,
}
2 changes: 1 addition & 1 deletion surpyval/tests/forest/test_forest.py
Expand Up @@ -86,7 +86,7 @@ def test_forest_sf_scalar_x(

sf_100 = forest.sf(x=100, Z=[0.5, 0.5])
assert isinstance(sf_100, np.ndarray)
assert pytest.approx(sf_100, abs=0.1) == np.array([0.1])
assert pytest.approx(sf_100, abs=0.15) == np.array([0.1])
# (Veeery approximate) ^^^^^^^


Expand Down

0 comments on commit 32c643b

Please sign in to comment.