In [78]:
from lifelines.utils import concordance_index
from numba import njit, jit

In [3]:
import numpy as np
from lifelines.utils.btree import _BTree

def concordance_index(event_times, predicted_scores, event_observed=None) -> float:
    """
    Compute the concordance index (C-index) of predicted scores against observed times.

    Parameters
    ----------
    event_times : array-like
        Observed exit time for each subject.
    predicted_scores : array-like
        Predicted value for each subject, aligned with ``event_times``.
    event_observed : array-like, optional
        Non-zero where the event was observed, zero where the subject was
        censored. Defaults to all ones (every event observed).

    Returns
    -------
    float
        ``(num_correct + num_tied / 2) / num_pairs``, using the counts returned
        by ``_concordance_summary_statistics``.

    Raises
    ------
    ZeroDivisionError
        If there is no comparable pair in the data.
    """
    # Coerce up front so plain lists/tuples are accepted: this function reads
    # ``.shape`` and the helper uses ``.astype`` and boolean-mask indexing,
    # all of which need real ndarrays.
    event_times = np.asarray(event_times)
    predicted_scores = np.asarray(predicted_scores)
    if event_observed is None:
        event_observed = np.ones(event_times.shape[0], dtype=float)
    else:
        event_observed = np.asarray(event_observed)

    num_correct, num_tied, num_pairs = _concordance_summary_statistics(
        event_times, predicted_scores, event_observed
    )

    if num_pairs == 0:
        raise ZeroDivisionError("No admissable pairs in the dataset.")
    # A tied prediction counts as half a correct pair.
    return (num_correct + num_tied / 2) / num_pairs


def _concordance_summary_statistics(event_times, predicted_event_times, event_observed):
    if np.logical_not(event_observed).all():
        return (0, 0, 0)

    died_mask = event_observed.astype(bool)
    died_truth = event_times[died_mask]
    ix = np.argsort(died_truth)
    died_truth = died_truth[ix]
    died_pred = predicted_event_times[died_mask][ix]

    censored_truth = event_times[~died_mask]
    ix = np.argsort(censored_truth)
    censored_truth = censored_truth[ix]
    censored_pred = predicted_event_times[~died_mask][ix]

    censored_ix = 0
    died_ix = 0
    times_to_compare = _BTree(np.unique(died_pred))
    print(np.unique(died_pred), times_to_compare)
    num_pairs = np.int64(0)
    num_correct = np.int64(0)
    num_tied = np.int64(0)

    # we iterate through cases sorted by exit time:
    # - First, all cases that died at time t0. We add these to the sortedlist of died times.
    # - Then, all cases that were censored at time t0. We DON'T add these since they are NOT
    #   comparable to subsequent elements.
    while True:
        has_more_censored = censored_ix < len(censored_truth)
        has_more_died = died_ix < len(died_truth)
        # Should we look at some censored indices next, or died indices?
        if has_more_censored and (not has_more_died or died_truth[died_ix] > censored_truth[censored_ix]):
            pairs, correct, tied, next_ix = _handle_pairs(censored_truth, censored_pred, censored_ix, times_to_compare)
            censored_ix = next_ix
        elif has_more_died and (not has_more_censored or died_truth[died_ix] <= censored_truth[censored_ix]):
            pairs, correct, tied, next_ix = _handle_pairs(died_truth, died_pred, died_ix, times_to_compare)
            for pred in died_pred[died_ix:next_ix]:
                times_to_compare.insert(pred)
            died_ix = next_ix
        else:
            assert not (has_more_died or has_more_censored)
            break

        num_pairs += pairs
        num_correct += correct
        num_tied += tied

    return (num_correct, num_tied, num_pairs)


def _handle_pairs(truth, pred, first_ix, times_to_compare):
    """
    Handle all pairs that exited at the same time as truth[first_ix].

    Returns
    -------
      (pairs, correct, tied, next_ix)
      new_pairs: The number of new comparisons performed
      new_correct: The number of comparisons correctly predicted
      next_ix: The next index that needs to be handled
    """
    next_ix = first_ix
    while next_ix < len(truth) and truth[next_ix] == truth[first_ix]:
        next_ix += 1
    pairs = len(times_to_compare) * (next_ix - first_ix)
    correct = np.int64(0)
    tied = np.int64(0)
    for i in range(first_ix, next_ix):
        rank, count = times_to_compare.rank(pred[i])
        correct += rank
        tied += count

    return (pairs, correct, tied, next_ix)

In [22]:
def _concordance_index(risk, T, E, include_ties=True):
    N = len(risk)
    censored_survival = []
    C = 0
    w = 0
    weightedPairs = 0
    weightedConcPairs = 0

    print(T, E, risk)
    for i in range(N):
        if E[i] == 1:
            for j in range(i + 1, N):
                if T[i] < T[j] or (T[i] == T[j] and E[j] == 0):
                    weightedPairs += 1
                    if risk[i] > risk[j]:
                        weightedConcPairs += 1
                    elif include_ties:
                        weightedConcPairs += 1 / 2
    C = weightedConcPairs / weightedPairs
    C = max(C, 1 - C)

    return {
        'C': C,
        'nb_pairs': 2 * weightedPairs,
        'nb_concordant_pairs': 2 * weightedConcPards
    }


def concordance_index(true_time, pred_time, event, include_ties=True, additional_results=False, **kwargs):
    """
    Concordance index computed with the pairwise ``_concordance_index`` helper.

    Parameters
    ----------
    true_time : array-like
        Observed time per subject.
    pred_time : array-like
        Predicted value per subject; it is forwarded as the helper's ``risk``
        argument.
    event : array-like
        Event indicator per subject (1 = event, 0 = censored).
    include_ties : bool, default True
        Forwarded to the helper.
    additional_results : bool, default False
        When true, return the helper's full result dict instead of the score.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    float or dict
        The ``'C'`` entry of the helper's result, or the whole dict
        (``'C'``, ``'nb_pairs'``, ``'nb_concordant_pairs'``) when
        ``additional_results`` is true.
    """
    # Accept lists/tuples as well as arrays; fancy indexing below needs ndarrays.
    true_time = np.asarray(true_time)
    pred_time = np.asarray(pred_time)
    event = np.asarray(event)

    # Order by ascending observed time. The helper looks for partners with
    # T[i] < T[j]; a descending order would put every such partner *before*
    # index i, where a forward scan never sees it.
    order = np.argsort(true_time)
    pred_time = pred_time[order]
    true_time = true_time[order]
    event = event[order]

    # Calculating the c-index
    results = _concordance_index(pred_time, true_time, event, include_ties)

    if not additional_results:
        # The helper returns a dict keyed by name, so index by 'C', not 0.
        return results['C']
    return results

In [79]:
@njit
def concordance_index_self(T, P, E):
    """
    Calculates the concordance index (C-index) for survival analysis.

    A pair (i, j) is comparable when subject i had the event (E[i] == 1) and
    either T[i] < T[j], or T[i] == T[j] and subject j is censored. Two events
    at the same time are not comparable. A comparable pair scores 1 when
    P[i] < P[j] and 0.5 when P[i] == P[j].

    Args:
    T: Array of true event times.
    P: Array of predicted event times.
    E: Array of event indicators (1 if event occurred, 0 if censored).

    Returns:
    The concordance index as a float, or 0.0 when there is no comparable pair.
    """
    n = len(T)
    # Float accumulator so the variable has a single type throughout.
    concordant_pairs = 0.0
    total_pairs = 0
    for i in range(n):
        if E[i] != 1:
            continue
        # Scan every subject rather than only later ones in sorted order, so
        # the handling of tied times does not depend on the order in which a
        # sort places them. j == i never matches because E[i] == 1.
        for j in range(n):
            if T[i] < T[j] or (T[i] == T[j] and E[j] != 1):
                total_pairs += 1
                if P[i] < P[j]:
                    concordant_pairs += 1.0
                elif P[i] == P[j]:
                    concordant_pairs += 0.5
    if total_pairs == 0:
        return 0.0
    return concordant_pairs / total_pairs


In [80]:
# Small hand-written fixture: (true times, predicted times, event flags).
a = (
    np.array([10, 20, 30, 40]),
    np.array([20, 19, 29, 39]),
    np.array([1, 0, 1, 0]),
)
# Both implementations must agree exactly on the fixture.
assert concordance_index(*a) == concordance_index_self(*a)

In [71]:
# Randomized cross-check of the two implementations on 1000 draws of 100
# subjects. The generator is seeded so that a failing draw can be reproduced.
rng = np.random.default_rng(0)
for _ in range(1000):
    a = rng.random(100) * 100
    b = rng.random(100) * 100
    e = np.round(rng.random(100))  # 0.0 / 1.0 event flags

    assert concordance_index(a, b, e) == concordance_index_self(a, b, e), f"{a}, {b}, {e}"

In [81]:
def test_f(f):
    a = 0
    for i in range(1000):
        a = np.random.rand(100)*100
        b = np.random.rand(100)*100
        e = np.round(np.random.rand(100))
        a += f(a, b, e)
    return a

In [82]:
# Time the test_f harness with `concordance_index` as the scored function.
# NOTE(review): `concordance_index` is imported once and defined twice in this
# notebook, so which one is timed depends on cell execution order — confirm.
%timeit test_f(concordance_index)

672 ms ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [83]:
# Time the same harness with the @njit-decorated concordance_index_self.
%timeit test_f(concordance_index_self)

22.9 ms ± 197 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [84]:
# Speed-up factor: ratio of the two mean timings reported above (672 ms vs 22.9 ms).
672/22.9

29.344978165938866

In [52]:
# Display the current values of a, b and e
# (presumably the inputs of a failing comparison — confirm).
a, b, e

(array([43.54496772, 76.8872837 , 62.34868269, 76.74680946]),
 array([44.25585109,  1.58026784, 94.81912392,  9.99180181]),
 array([1., 1., 1., 1.]))