In [137]:
import numpy as np
import torch



In [138]:
time_np = np.array([1.0,1.0,1.0,3.0,5.0,5.0,5.0,8.0])
event_np = np.array([1.0,1.0,1.0,1.0,1.0,1.0,0.0,1.0])
time_torch = torch.tensor(time_np)
event_torch = torch.tensor(event_np)

In [323]:
def KaplanMeier(time: np.array, event: np.array, 
                cens_dist: bool = False
) -> tuple[np.array, np.array] | tuple[np.array, np.array, np.array]:
    """_summary_

    Parameters
    ----------
    time : npt.NDArray[float]
        _description_
    event : npt.NDArray[int]
        _description_
    cens_dist : bool, optional
        _description_, by default False

    Returns
    -------
    tuple[npt.NDArray[float], npt.NDArray[float]] | tuple[npt.NDArray[float], npt.NDArray[float], npt.NDArray[int]]
        _description_
    
    References
    ----------
    .. [1] Kaplan, E. L. and Meier, P., "Nonparametric estimation from incomplete observations",
           Journal of The American Statistical Association, vol. 53, pp. 457-481, 1958.
    .. [2] S. Pölsterl, “scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn,”
           Journal of Machine Learning Research, vol. 21, no. 212, pp. 1–6, 2020.
    """
    # similar approach to sksurv, but no loops
    # even and censored is other way round in sksurv ->clarify
    #time, event = transform_back(y)
    # order, remove later
    is_sorted = lambda a: np.all(a[:-1] <= a[1:])
    
    if is_sorted(time) == False:
        order = np.argsort(time, kind="mergesort")
        time = time[order]
        event = event[order]
    
    times = np.unique(time)
    idx = np.digitize(time, np.unique(time))
    # numpy diff nth discrete difference over index, add 1 at the beginning
    breaks = np.flatnonzero(np.concatenate(([1], np.diff(idx))))

    # flatnonzero return indices that are nonzero in flattened version
    n_events = np.add.reduceat(event, breaks, axis=0)
    n_at_risk = np.sum(np.unique((np.outer(time,time)>=np.square(time)).astype(int).T,axis=0),axis=1)[::-1]
    print(n_at_risk)
    
    # censoring distribution for ipcw estimation
    #n_censored a vector, with 1 at censoring position, zero elsewhere
    n_events = n_events.astype(np.float32)
    n_at_risk = n_at_risk.astype(np.float32)
    if cens_dist:
        n_at_risk -= n_events
        # for each unique time step how many observations are censored
        censored = 1-event
        n_censored = np.add.reduceat(censored, breaks, axis=0)
        mask = (n_events != 0)
        vals = 1-np.divide(
        n_censored, n_at_risk,
        out=np.zeros(times.shape[0], dtype=float),
        where=n_censored != 0,
    )
        estimates = np.cumprod(vals)
        return times, estimates, n_censored


    else:
        vals = 1-np.divide(
        n_events, n_at_risk,
        out=np.zeros(times.shape[0], dtype=float),
        where=n_events != 0,
        )
        estimates = np.cumprod(vals)
        return times, estimates

In [324]:
KaplanMeier(time_np, event_np, cens_dist=True)

[8 5 4 1]


(array([1., 3., 5., 8.]), array([1. , 1. , 0.5, 0.5]), array([0., 0., 1., 0.]))

In [325]:
time_np

array([1., 1., 1., 3., 5., 5., 5., 8.])

In [235]:
def KaplanMeier_torch(time: np.array, event: np.array, 
                cens_dist: bool = False
) -> tuple[np.array, np.array] | tuple[np.array, np.array, np.array]:
    """_summary_

    Parameters
    ----------
    time : npt.NDArray[float]
        _description_
    event : npt.NDArray[int]
        _description_
    cens_dist : bool, optional
        _description_, by default False

    Returns
    -------
    tuple[npt.NDArray[float], npt.NDArray[float]] | tuple[npt.NDArray[float], npt.NDArray[float], npt.NDArray[int]]
        _description_
    
    References
    ----------
    .. [1] Kaplan, E. L. and Meier, P., "Nonparametric estimation from incomplete observations",
           Journal of The American Statistical Association, vol. 53, pp. 457-481, 1958.
    .. [2] S. Pölsterl, “scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn,”
           Journal of Machine Learning Research, vol. 21, no. 212, pp. 1–6, 2020.
    """
    # similar approach to sksurv, but no loops
    # even and censored is other way round in sksurv ->clarify
    #time, event = transform_back(y)
    # order, remove later
    #is_sorted = lambda a: np.all(a[:-1] <= a[1:])
    
    # if is_sorted(time) == False:
    #     order = np.argsort(time, kind="mergesort")
    #     time = time[order]
    #     event = event[order]
    
    times = torch.unique(time)
    #idx = np.digitize(time, np.unique(time))
    # numpy diff nth discrete difference over index, add 1 at the beginning
    #breaks = np.flatnonzero(np.concatenate(([1], np.diff(idx))))

    # flatnonzero return indices that are nonzero in flattened version
    #n_events = np.add.reduceat(event, breaks, axis=0)
    unique_times, inverse_indices = time.unique(return_inverse=True)

    # Prepare a tensor for the event counts
    event_counts = torch.zeros_like(unique_times)

    # Add up the events for each unique time using scatter_add_
    event_counts.scatter_add_(0, inverse_indices, event)
    n_events = event_counts
    n_at_risk = torch.unique((torch.outer(time,time)>=np.square(time)).int().T, dim=0).sum(axis=1).flip(0)
    print('n_at_risk', n_at_risk)
    n_at_risk = n_at_risk.float()
    n_events = n_events.float()
    # censoring distribution for ipcw estimation
    #n_censored a vector, with 1 at censoring position, zero elsewhere
    print('n_at_risk shape',n_at_risk.shape)
    print('n_events shape',n_events.shape)
    if cens_dist:
        n_at_risk -= n_events
        # for each unique time step how many observations are censored
        censored = 1-event
        n_censored = torch.zeros_like(unique_times)
        #n_censored = np.add.reduceat(censored, breaks, axis=0)
        n_censored.scatter_add_(0, inverse_indices, censored)
        mask = (n_censored != 0)
        c = torch.zeros_like(times)
        # apply the division operation only where mask is True
        c[mask] = n_censored[mask] / n_at_risk[mask] 
        vals = 1-c
        estimates = torch.cumprod(vals, dim=0)
        return times, estimates, n_censored


    else:
        mask = (n_events != 0)
        vals = 1-torch.divide(
        n_events[mask], n_at_risk[mask],
        #rounding_mode=None,
        out=torch.zeros(times.shape[0]),
        #where=mask,
        )
        print(vals)
        estimates = torch.cumprod(vals, dim=0)
        return times, estimates

In [236]:
#KaplanMeier(time_np, event_np)

In [237]:
#KaplanMeier_torch(time_torch, event_torch)

In [238]:
KaplanMeier(time_np.astype(np.float32), event_np.astype(np.float32),cens_dist=True)

(array([1., 3., 5., 8.], dtype=float32),
 array([1. , 1. , 0.5, 0.5]),
 array([0., 0., 1., 0.], dtype=float32))

In [239]:
KaplanMeier_torch(time_torch, event_torch,cens_dist=True)

n_at_risk tensor([8, 5, 4, 1])
n_at_risk shape torch.Size([4])
n_events shape torch.Size([4])


(tensor([1., 3., 5., 8.], dtype=torch.float64),
 tensor([1.0000, 1.0000, 0.5000, 0.5000], dtype=torch.float64),
 tensor([0., 0., 1., 0.], dtype=torch.float64))

In [240]:
## IPCW Comparison

In [241]:
def ipcw_estimate(time: np.array, event: np.array) -> tuple[np.array, np.array]:

    unique_time, cens_dist, n_censored = KaplanMeier(time, event, cens_dist=True) 
    #print(cens_dist)
    # similar approach to sksurv
    idx = np.searchsorted(unique_time, time)
    est = 1.0/cens_dist[idx] # improve as divide by zero
    est[n_censored[idx]!=0] = 0
    # in R mboost there is a maxweight of 5
    est[est>5] = 5
    return unique_time, est

In [242]:
def ipcw_estimate_torch(time: np.array, event: np.array) -> tuple[np.array, np.array]:

    print(time.shape, event.shape)
    unique_time, cens_dist, n_censored = KaplanMeier_torch(time, event, cens_dist=True) 
    #print(cens_dist)
    # similar approach to sksurv
    idx = torch.searchsorted(unique_time, time)
    est = 1.0/cens_dist[idx] # improve as divide by zero
    est[n_censored[idx]!=0] = 0
    # in R mboost there is a maxweight of 5
    est[est>5] = 5
    return unique_time, est

In [243]:
ipcw_estimate(time_np, event_np)

(array([1., 3., 5., 8.]), array([1., 1., 1., 1., 0., 0., 0., 2.]))

In [244]:
ipcw_estimate_torch(time_torch, event_torch)

torch.Size([8]) torch.Size([8])
n_at_risk tensor([8, 5, 4, 1])
n_at_risk shape torch.Size([4])
n_events shape torch.Size([4])


(tensor([1., 3., 5., 8.], dtype=torch.float64),
 tensor([1., 1., 1., 1., 0., 0., 0., 2.], dtype=torch.float64))

## Loss Comparison


In [245]:
def transform_torch(time: torch.Tensor, event: torch.Tensor) -> torch.Tensor:
    """Transforms time, event into XGBoost digestable format.

    Parameters
    ----------
    time : npt.NDArray[float]
        Survival time.
    event : npt.NDArray[int]
        Boolean event indicator. Zero value is taken as censored event.

    Returns
    -------
    y : npt.NDArray[float]
        Transformed array containing survival time and event where negative value is taken as censored event.
    """
    #if isinstance(time, pd.Series):
    #    time = time.to_numpy()
    #    event = event.to_numpy()
    event_mod = event.clone()
    event_mod[event_mod==0] = -1
    if (time==0).any():
        raise RuntimeError('Data contains zero time value!')
        # alternative: time[time==0] = np.finfo(float).eps
    y = event_mod*time
    return y.to(torch.float32)


def transform_back_torch(y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Transforms XGBoost digestable format variable y into time and event.

    Parameters
    ----------
    y : npt.NDArray[float]
        Array containing survival time and event where negative value is taken as censored event.

    Returns
    -------
    tuple[npt.NDArray[float],npt.NDArray[int]]
        Survival time and event.
    """
    time = torch.abs(y)
    event = (torch.abs(y) == y)
    event = event # for numba
    return time.to(torch.float32), event.to(torch.float32)

In [309]:
def compute_weights(y, approach: str='paper') -> np.array:
    """_summary_

    Parameters
    ----------
    y : npt.NDArray[float]
        Sorted array containing survival time and event where negative value is taken as censored event.
    approach : str, optional
        Choose mboost implementation or paper implementation of c-boosting, by default 'paper'.

    Returns
    -------
    npt.NDArray[float]
        Array of weights.

    References
    ----------
    .. [1] 1. Mayr, A. & Schmid, M. Boosting the concordance index for survival data–a unified framework to derive and evaluate biomarker combinations. 
       PloS one 9, e84483 (2014).

    """
    time, event = transform_back(y) 
    n = event.shape[0]

    _, ipcw_new = ipcw_estimate(time, event)

    ipcw = ipcw_new #ipcw_old
    survtime = time
    wweights = np.full((n,n), np.square(ipcw)).T # good here

    weightsj = np.full((n,n), survtime).T

    weightsk = np.full((n,n), survtime) #byrow = TRUE in R, in np automatic no T required

    if approach == 'mboost':
        # implementing   weightsI <- ifelse(weightsj == weightsk, .5, (weightsj < weightsk) + 0) - diag(.5, n,n)
        # from mboost github repo
        weightsI = np.empty((n,n))
        weightsI[weightsj == weightsk] = 0.5
        weightsI = (weightsj < weightsk).astype(int)
        weightsI = weightsI - np.diag(0.5*np.ones(n))
    if approach == 'paper':
        weightsI = (weightsj < weightsk).astype(int) 

    wweights = wweights * weightsI 
    
    wweights = wweights / np.sum(wweights)

    return wweights

In [310]:
def compute_weights_torch(y, approach: str='paper') -> np.array:
    """_summary_

    Parameters
    ----------
    y : npt.NDArray[float]
        Sorted array containing survival time and event where negative value is taken as censored event.
    approach : str, optional
        Choose mboost implementation or paper implementation of c-boosting, by default 'paper'.

    Returns
    -------
    npt.NDArray[float]
        Array of weights.

    References
    ----------
    .. [1] 1. Mayr, A. & Schmid, M. Boosting the concordance index for survival data–a unified framework to derive and evaluate biomarker combinations. 
       PloS one 9, e84483 (2014).

    """
    time, event = transform_back_torch(y) 
    print('time shape', time.shape)
    print('event shape', event.shape)
    n = event.shape[0]

    _, ipcw_new = ipcw_estimate_torch(time, event)

    ipcw = ipcw_new #ipcw_old consider copy
    survtime = time

    fill = np.square(ipcw)
    wweights = fill.repeat(n,1).T
    # wweights = torch.full((n,n), np.square(ipcw)).T # good here

    # weightsj = torch.full((n,n), survtime).T
    weightsj = survtime.repeat(n,1).T
    # weightsk = torch.full((n,n), survtime) #byrow = TRUE in R, in np automatic no T required
    weightsk = survtime.repeat(n,1)
    
    if approach == 'mboost':
        # implementing   weightsI <- ifelse(weightsj == weightsk, .5, (weightsj < weightsk) + 0) - diag(.5, n,n)
        # from mboost github repo
        weightsI = torch.empty((n,n))
        weightsI[weightsj == weightsk] = 0.5
        weightsI = (weightsj < weightsk).astype(int)
        weightsI = weightsI - torch.diag(0.5*np.ones(n))
    if approach == 'paper':
        weightsI = (weightsj < weightsk).int()

    wweights = wweights * weightsI 
    
    wweights = wweights / torch.sum(wweights)

    return wweights

In [311]:
import pandas as pd
from xgbsurv.models.utils import transform
predictor = log_hazard = np.random.normal(0, 1, 1000)
df = pd.read_csv('/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data/survival_simulation_1000.csv')
y = transform(df.time.to_numpy(), df.event.to_numpy())
y_torch = torch.tensor(y)
predictor_torch = torch.tensor(predictor)

In [312]:
wweights0 = compute_weights(y, approach='paper')

  est = 1.0/cens_dist[idx] # improve as divide by zero


In [313]:
wweights1 = compute_weights_torch(y_torch, approach='paper')

time shape torch.Size([1000])
event shape torch.Size([1000])
torch.Size([1000]) torch.Size([1000])
n_at_risk tensor([1000,  999,  998,  997,  996,  995,  994,  993,  992,  991,  990,  989,
         988,  987,  986,  985,  984,  983,  982,  981,  980,  979,  978,  977,
         976,  975,  974,  973,  972,  971,  970,  969,  968,  967,  966,  965,
         964,  963,  962,  961,  960,  959,  958,  957,  956,  955,  954,  953,
         952,  951,  950,  949,  948,  947,  946,  945,  944,  943,  942,  941,
         940,  939,  938,  937,  936,  935,  934,  933,  932,  931,  930,  929,
         928,  927,  926,  925,  924,  923,  922,  921,  920,  919,  918,  917,
         916,  915,  914,  913,  912,  911,  910,  909,  908,  907,  906,  905,
         904,  903,  902,  901,  900,  899,  898,  897,  896,  895,  894,  893,
         892,  891,  890,  889,  888,  887,  886,  885,  884,  883,  882,  881,
         880,  879,  878,  877,  876,  875,  874,  873,  872,  871,  870,  869,
         86

In [314]:
np.allclose(wweights0, wweights1)

True

In [315]:
from xgbsurv.models.utils import transform_back

def cind_loss(y, predictor, sigma = 0.1) ->np.array:
    # f corresponds to predictor in paper
    time, _ = transform_back(y)
    n = time.shape[0]
    etaj = np.full((n,n), predictor)
    etak = np.full((n,n), predictor).T
    x = (etak - etaj) 
    weights_out = compute_weights(y)
    c_loss = 1/(1+np.exp(x/sigma))*weights_out
    return -np.sum(c_loss)

In [318]:
def cind_loss_torch(y: np.array, predictor: np.array, sigma: np.array = 0.1) -> np.array:
    # f corresponds to predictor in paper
    time, _ = transform_back_torch(y)
    n = time.shape[0]
    #etaj = np.full((n,n), predictor)
    etaj = predictor.repeat(n,1)
    #etak = np.full((n,n), predictor).T
    etak = predictor.repeat(n,1).T
    x = (etak - etaj) 
    weights_out = compute_weights_torch(y)
    c_loss = 1/(1+torch.exp(x/sigma))*weights_out
    return -torch.sum(c_loss)

In [319]:
cind_loss(y, predictor, sigma = 0.1)

  est = 1.0/cens_dist[idx] # improve as divide by zero


-0.5122171753396544

In [321]:
cind_loss_torch(y_torch, predictor_torch, sigma = 0.1)

time shape torch.Size([1000])
event shape torch.Size([1000])
torch.Size([1000]) torch.Size([1000])
n_at_risk tensor([1000,  999,  998,  997,  996,  995,  994,  993,  992,  991,  990,  989,
         988,  987,  986,  985,  984,  983,  982,  981,  980,  979,  978,  977,
         976,  975,  974,  973,  972,  971,  970,  969,  968,  967,  966,  965,
         964,  963,  962,  961,  960,  959,  958,  957,  956,  955,  954,  953,
         952,  951,  950,  949,  948,  947,  946,  945,  944,  943,  942,  941,
         940,  939,  938,  937,  936,  935,  934,  933,  932,  931,  930,  929,
         928,  927,  926,  925,  924,  923,  922,  921,  920,  919,  918,  917,
         916,  915,  914,  913,  912,  911,  910,  909,  908,  907,  906,  905,
         904,  903,  902,  901,  900,  899,  898,  897,  896,  895,  894,  893,
         892,  891,  890,  889,  888,  887,  886,  885,  884,  883,  882,  881,
         880,  879,  878,  877,  876,  875,  874,  873,  872,  871,  870,  869,
         86

tensor(-0.5122, dtype=torch.float64)