In [31]:
import numpy as np
import torch
from xgbsurv.models.utils import transform_back



In [32]:
time_np = np.array([1.0,1.0,1.0,3.0,5.0,5.0,5.0,8.0])
event_np = np.array([1.0,1.0,1.0,1.0,1.0,1.0,0.0,1.0])
time_torch = torch.tensor(time_np)
event_torch = torch.tensor(event_np)

In [33]:
def KaplanMeier(time: np.array, event: np.array, 
                cens_dist: bool = False
) -> tuple[np.array, np.array] | tuple[np.array, np.array, np.array]:
    """_summary_

    Parameters
    ----------
    time : npt.NDArray[float]
        _description_
    event : npt.NDArray[int]
        _description_
    cens_dist : bool, optional
        _description_, by default False

    Returns
    -------
    tuple[npt.NDArray[float], npt.NDArray[float]] | tuple[npt.NDArray[float], npt.NDArray[float], npt.NDArray[int]]
        _description_
    
    References
    ----------
    .. [1] Kaplan, E. L. and Meier, P., "Nonparametric estimation from incomplete observations",
           Journal of The American Statistical Association, vol. 53, pp. 457-481, 1958.
    .. [2] S. Pölsterl, “scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn,”
           Journal of Machine Learning Research, vol. 21, no. 212, pp. 1–6, 2020.
    """
    # similar approach to sksurv, but no loops
    # even and censored is other way round in sksurv ->clarify
    #time, event = transform_back(y)
    # order, remove later
    is_sorted = lambda a: np.all(a[:-1] <= a[1:])
    
    if is_sorted(time) == False:
        order = np.argsort(time, kind="mergesort")
        time = time[order]
        event = event[order]
    
    times = np.unique(time)
    idx = np.digitize(time, np.unique(time))
    # numpy diff nth discrete difference over index, add 1 at the beginning
    breaks = np.flatnonzero(np.concatenate(([1], np.diff(idx))))

    # flatnonzero return indices that are nonzero in flattened version
    n_events = np.add.reduceat(event, breaks, axis=0)
    n_at_risk = np.sum(np.unique((np.outer(time,time)>=np.square(time)).astype(int).T,axis=0),axis=1)[::-1]
    
    # censoring distribution for ipcw estimation
    #n_censored a vector, with 1 at censoring position, zero elsewhere
    n_events = n_events.astype(np.float32)
    n_at_risk = n_at_risk.astype(np.float32)
    if cens_dist:
        n_at_risk -= n_events
        # for each unique time step how many observations are censored
        censored = 1-event
        n_censored = np.add.reduceat(censored, breaks, axis=0)
        mask = (n_events != 0)
        vals = 1-np.divide(
        n_censored, n_at_risk,
        out=np.zeros(times.shape[0], dtype=float),
        where=n_censored != 0,
    )
        estimates = np.cumprod(vals)
        return times, estimates, n_censored


    else:
        vals = 1-np.divide(
        n_events, n_at_risk,
        out=np.zeros(times.shape[0], dtype=float),
        where=n_events != 0,
        )
        estimates = np.cumprod(vals)
        return times, estimates

In [34]:
KaplanMeier(time_np, event_np, cens_dist=True)

(array([1., 3., 5., 8.]), array([1. , 1. , 0.5, 0.5]), array([0., 0., 1., 0.]))

In [35]:
time_np

array([1., 1., 1., 3., 5., 5., 5., 8.])

In [36]:
def KaplanMeier_torch(time: np.array, event: np.array, 
                cens_dist: bool = False
) -> tuple[np.array, np.array] | tuple[np.array, np.array, np.array]:
    """_summary_

    Parameters
    ----------
    time : npt.NDArray[float]
        _description_
    event : npt.NDArray[int]
        _description_
    cens_dist : bool, optional
        _description_, by default False

    Returns
    -------
    tuple[npt.NDArray[float], npt.NDArray[float]] | tuple[npt.NDArray[float], npt.NDArray[float], npt.NDArray[int]]
        _description_
    
    References
    ----------
    .. [1] Kaplan, E. L. and Meier, P., "Nonparametric estimation from incomplete observations",
           Journal of The American Statistical Association, vol. 53, pp. 457-481, 1958.
    .. [2] S. Pölsterl, “scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn,”
           Journal of Machine Learning Research, vol. 21, no. 212, pp. 1–6, 2020.
    """
    # similar approach to sksurv, but no loops
    # even and censored is other way round in sksurv ->clarify
    #time, event = transform_back(y)
    # order, remove later
    #is_sorted = lambda a: np.all(a[:-1] <= a[1:])
    
    # if is_sorted(time) == False:
    #     order = np.argsort(time, kind="mergesort")
    #     time = time[order]
    #     event = event[order]
    
    times = torch.unique(time)
    #idx = np.digitize(time, np.unique(time))
    # numpy diff nth discrete difference over index, add 1 at the beginning
    #breaks = np.flatnonzero(np.concatenate(([1], np.diff(idx))))

    # flatnonzero return indices that are nonzero in flattened version
    #n_events = np.add.reduceat(event, breaks, axis=0)
    unique_times, inverse_indices = time.unique(return_inverse=True)

    # Prepare a tensor for the event counts
    event_counts = torch.zeros_like(unique_times)

    # Add up the events for each unique time using scatter_add_
    event_counts.scatter_add_(0, inverse_indices, event)
    n_events = event_counts
    n_at_risk = torch.unique((torch.outer(time,time)>=torch.square(time)).int().T, dim=0).sum(axis=1).flip(0)
    #print('n_at_risk', n_at_risk)
    n_at_risk = n_at_risk.float()
    n_events = n_events.float()
    # censoring distribution for ipcw estimation
    #n_censored a vector, with 1 at censoring position, zero elsewhere
    #print('n_at_risk shape',n_at_risk.shape)
    #print('n_events shape',n_events.shape)
    if cens_dist:
        n_at_risk -= n_events
        # for each unique time step how many observations are censored
        censored = 1-event
        n_censored = torch.zeros_like(unique_times)
        #n_censored = np.add.reduceat(censored, breaks, axis=0)
        n_censored.scatter_add_(0, inverse_indices, censored)
        mask = (n_censored != 0)
        c = torch.zeros_like(times)
        # apply the division operation only where mask is True
        c[mask] = n_censored[mask] / n_at_risk[mask] 
        vals = 1-c
        estimates = torch.cumprod(vals, dim=0)
        return times, estimates, n_censored


    else:
        mask = (n_events != 0)
        vals = 1-torch.divide(
        n_events[mask], n_at_risk[mask],
        #rounding_mode=None,
        out=torch.zeros(times.shape[0]),
        #where=mask,
        )
        #print(vals)
        estimates = torch.cumprod(vals, dim=0)
        return times, estimates

In [37]:
#KaplanMeier(time_np, event_np)

In [38]:
#KaplanMeier_torch(time_torch, event_torch)

In [39]:
KaplanMeier(time_np.astype(np.float32), event_np.astype(np.float32),cens_dist=True)

(array([1., 3., 5., 8.], dtype=float32),
 array([1. , 1. , 0.5, 0.5]),
 array([0., 0., 1., 0.], dtype=float32))

In [40]:
KaplanMeier_torch(time_torch, event_torch,cens_dist=True)

(tensor([1., 3., 5., 8.], dtype=torch.float64),
 tensor([1.0000, 1.0000, 0.5000, 0.5000], dtype=torch.float64),
 tensor([0., 0., 1., 0.], dtype=torch.float64))

In [41]:
## IPCW Comparison

In [42]:
def ipcw_estimate(time: np.array, event: np.array) -> tuple[np.array, np.array]:

    unique_time, cens_dist, n_censored = KaplanMeier(time, event, cens_dist=True) 
    #print(cens_dist)
    # similar approach to sksurv
    idx = np.searchsorted(unique_time, time)
    est = 1.0/cens_dist[idx] # improve as divide by zero
    est[n_censored[idx]!=0] = 0
    # in R mboost there is a maxweight of 5
    est[est>5] = 5
    return unique_time, est

In [43]:
def ipcw_estimate_torch(time: np.array, event: np.array) -> tuple[np.array, np.array]:

    #print(time.shape, event.shape)
    unique_time, cens_dist, n_censored = KaplanMeier_torch(time, event, cens_dist=True) 
    #print(cens_dist)
    # similar approach to sksurv
    idx = torch.searchsorted(unique_time, time)

    mask1 = cens_dist[idx] != 0.0

    est = torch.ones_like(cens_dist[idx])
    est[mask1] = 1.0/cens_dist[idx][mask1]
    #est = 1.0/cens_dist[idx] # improve as divide by zero

    est[n_censored[idx]!=0] = 0
    # in R mboost there is a maxweight of 5
    est[est>5] = 5
    return unique_time, est

In [44]:
ipcw_estimate(time_np, event_np)

(array([1., 3., 5., 8.]), array([1., 1., 1., 1., 0., 0., 0., 2.]))

In [45]:
ipcw_estimate_torch(time_torch, event_torch)

(tensor([1., 3., 5., 8.], dtype=torch.float64),
 tensor([1., 1., 1., 1., 0., 0., 0., 2.], dtype=torch.float64))

## Loss Comparison


In [46]:
def transform_torch(time: torch.Tensor, event: torch.Tensor) -> torch.Tensor:
    """Transforms time, event into XGBoost digestable format.

    Parameters
    ----------
    time : npt.NDArray[float]
        Survival time.
    event : npt.NDArray[int]
        Boolean event indicator. Zero value is taken as censored event.

    Returns
    -------
    y : npt.NDArray[float]
        Transformed array containing survival time and event where negative value is taken as censored event.
    """
    #if isinstance(time, pd.Series):
    #    time = time.to_numpy()
    #    event = event.to_numpy()
    event_mod = event.clone()
    event_mod[event_mod==0] = -1
    if (time==0).any():
        raise RuntimeError('Data contains zero time value!')
        # alternative: time[time==0] = np.finfo(float).eps
    y = event_mod*time
    return y.to(torch.float32)


def transform_back_torch(y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Transforms XGBoost digestable format variable y into time and event.

    Parameters
    ----------
    y : npt.NDArray[float]
        Array containing survival time and event where negative value is taken as censored event.

    Returns
    -------
    tuple[npt.NDArray[float],npt.NDArray[int]]
        Survival time and event.
    """
    time = torch.abs(y)
    event = (torch.abs(y) == y)
    event = event # for numba
    return time.to(torch.float32), event.to(torch.float32)

In [47]:
def compute_weights(y, approach: str='paper') -> np.array:
    """_summary_

    Parameters
    ----------
    y : npt.NDArray[float]
        Sorted array containing survival time and event where negative value is taken as censored event.
    approach : str, optional
        Choose mboost implementation or paper implementation of c-boosting, by default 'paper'.

    Returns
    -------
    npt.NDArray[float]
        Array of weights.

    References
    ----------
    .. [1] 1. Mayr, A. & Schmid, M. Boosting the concordance index for survival data–a unified framework to derive and evaluate biomarker combinations. 
       PloS one 9, e84483 (2014).

    """
    time, event = transform_back(y) 
    n = event.shape[0]

    _, ipcw_new = ipcw_estimate(time, event)

    ipcw = ipcw_new #ipcw_old
    survtime = time
    wweights = np.full((n,n), np.square(ipcw)).T # good here

    weightsj = np.full((n,n), survtime).T

    weightsk = np.full((n,n), survtime) #byrow = TRUE in R, in np automatic no T required

    if approach == 'mboost':
        # implementing   weightsI <- ifelse(weightsj == weightsk, .5, (weightsj < weightsk) + 0) - diag(.5, n,n)
        # from mboost github repo
        weightsI = np.empty((n,n))
        weightsI[weightsj == weightsk] = 0.5
        weightsI = (weightsj < weightsk).astype(int)
        weightsI = weightsI - np.diag(0.5*np.ones(n))
    if approach == 'paper':
        weightsI = (weightsj < weightsk).astype(int) 

    wweights = wweights * weightsI 
    
    wweights = wweights / np.sum(wweights)

    return wweights

In [48]:
def compute_weights_torch(y, approach: str='paper') -> np.array:
    """_summary_

    Parameters
    ----------
    y : npt.NDArray[float]
        Sorted array containing survival time and event where negative value is taken as censored event.
    approach : str, optional
        Choose mboost implementation or paper implementation of c-boosting, by default 'paper'.

    Returns
    -------
    npt.NDArray[float]
        Array of weights.

    References
    ----------
    .. [1] 1. Mayr, A. & Schmid, M. Boosting the concordance index for survival data–a unified framework to derive and evaluate biomarker combinations. 
       PloS one 9, e84483 (2014).

    """
    time, event = transform_back_torch(y) 
    #print('time shape', time.shape)
    #print('event shape', event.shape)
    n = event.shape[0]

    _, ipcw_new = ipcw_estimate_torch(time, event)

    ipcw = ipcw_new #ipcw_old consider copy
    #survtime = time

    fill = torch.square(ipcw)
    #fill.repeat(n,1).T
    wweights = fill.unsqueeze(1).expand(-1, n)

    #survtime.repeat(n,1).T
    weightsj = time.unsqueeze(1).expand(-1, n)
   
    #survtime.repeat(n,1)
    weightsk = time.unsqueeze(1).expand(-1, n).T
    
    if approach == 'mboost':
        # implementing   weightsI <- ifelse(weightsj == weightsk, .5, (weightsj < weightsk) + 0) - diag(.5, n,n)
        # from mboost github repo
        weightsI = torch.empty((n,n))
        weightsI[weightsj == weightsk] = 0.5
        weightsI = (weightsj < weightsk).astype(int)
        weightsI = weightsI - torch.diag(0.5*np.ones(n))
    if approach == 'paper':
        weightsI = (weightsj < weightsk).int()

    wweights = wweights * weightsI 
    del weightsI, weightsk, weightsj
    wweights = wweights / torch.sum(wweights)

    return wweights

In [49]:
import pandas as pd
from xgbsurv.models.utils import transform
predictor = log_hazard = np.random.normal(0, 1, 1000)
df = pd.read_csv('/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data/survival_simulation_1000.csv')
y = transform(df.time.to_numpy(), df.event.to_numpy())
y_torch = torch.tensor(y)
predictor_torch = torch.tensor(predictor)

In [50]:
wweights0 = compute_weights(y, approach='paper')

In [51]:
wweights1 = compute_weights_torch(y_torch, approach='paper')

In [52]:
np.allclose(wweights0, wweights1)

True

In [53]:
from xgbsurv.models.utils import transform_back

def cind_loss(y, predictor, sigma = 0.1) ->np.array:
    # f corresponds to predictor in paper
    time, _ = transform_back(y)
    n = time.shape[0]
    etaj = np.full((n,n), predictor)
    etak = np.full((n,n), predictor).T
    x = (etak - etaj) 
    weights_out = compute_weights(y)
    c_loss = 1/(1+np.exp(x/sigma))*weights_out
    return -np.sum(c_loss)

In [54]:
def cind_loss_torch(y: np.array, predictor: np.array, sigma: np.array = 0.1) -> np.array:
    # f corresponds to predictor in paper
    time, _ = transform_back_torch(y)
    n = time.shape[0]

    #predictor.repeat(n,1)
    etaj = predictor.unsqueeze(1).expand(-1, n).T 
    #predictor.repeat(n,1).T
    etak = predictor.unsqueeze(1).expand(-1, n) 
    x = (etak - etaj) 
    weights_out = compute_weights_torch(y)
    print('weights_out.shape',weights_out.shape)
    c_loss = 1/(1+torch.exp(x/sigma))*weights_out
    return -torch.sum(c_loss)

In [55]:
np_loss = cind_loss(y, predictor, sigma = 0.1)

In [56]:
torch_loss = cind_loss_torch(y_torch,predictor_torch, sigma = 0.1)

weights_out.shape torch.Size([1000, 1000])


In [57]:
np.allclose(np_loss, torch_loss)

True

In [58]:
print(np_loss, torch_loss)

-0.5114959825198104 tensor(-0.5115, dtype=torch.float64)


In [59]:
import sys
sys.path.append('/Users/JUSC/Documents/xgbsurv/experiments/deep_learning')
from loss_functions_pytorch import cind_likelihood_torch
predictor_torch = torch.tensor(predictor, requires_grad=True)
time_torch = torch.tensor(time_np, requires_grad=True)
event_torch = torch.tensor(event_np, requires_grad=True)
d = cind_likelihood_torch(predictor_torch, y_torch,  sigma = 0.1) # different order
print(d)
d.backward()
predictor_torch.grad

weights_out.shape torch.Size([1000, 1000])
tensor(-0.5115, dtype=torch.float64, grad_fn=<NegBackward0>)


tensor([-6.4807e-04, -1.1727e-04, -2.2686e-04,  4.7333e-04,  7.6656e-04,
        -3.9806e-04,  1.6378e-04,  1.1481e-03, -2.4710e-04,  6.2454e-04,
         4.3906e-04,  6.8778e-04, -4.0097e-04, -2.5393e-04, -3.7270e-04,
        -5.8260e-04,  3.3661e-04,  3.2856e-04,  5.0804e-04,  3.2912e-04,
        -1.3909e-04, -3.1582e-04, -4.6944e-04, -5.6578e-04, -4.8168e-04,
         4.5051e-04, -4.5098e-04, -2.5004e-04, -3.7393e-04,  7.2480e-04,
        -3.2129e-04, -5.2864e-04,  9.9898e-04, -4.5997e-04, -5.0943e-04,
        -1.5171e-04, -2.7090e-04, -4.0733e-04,  1.0960e-03, -4.5737e-04,
         6.6007e-04,  1.2874e-03,  3.7944e-04,  8.0060e-04, -3.8922e-04,
        -2.0678e-04, -2.1773e-04,  8.2418e-04, -3.7993e-04,  8.9226e-04,
         3.7400e-04, -4.4584e-04, -5.9111e-04,  3.4975e-04, -3.4092e-04,
        -2.3846e-04, -4.7695e-04,  1.1400e-03,  6.4417e-04,  4.3762e-04,
         6.7918e-04, -4.6274e-04, -4.6625e-04, -4.5748e-04,  5.7386e-04,
        -2.4963e-04, -2.8732e-04, -6.4055e-04, -4.4