In [1]:
import numpy as np
from xgbsurv.models.utils import transform, transform_back
from xgbsurv.models.eh_final import eh_likelihood, eh_gradient
from xgbsurv.models.eh_ah_final import ah_likelihood, ah_objective
from xgbsurv.models.utils import sort_X_y, transform_back
import sys
#from loss_functions_pytorch import deephit_likelihood_1_torch
import torch
import math
torch.set_printoptions(precision=10)
from torch.autograd.functional import hessian
from xgbsurv.datasets import load_metabric

In [2]:
def transform_back_torch(y: torch.torch.torch.torch.torch.torch.torch.torch.torch.tensor) -> tuple[torch.torch.tensor, torch.torch.tensor]:
    """Transforms XGBoost digestable format variable y into time and event.

    Parameters
    ----------
    y : npt.NDArray[float]
        Array containing survival time and event where negative value is taken as censored event.

    Returns
    -------
    tuple[npt.NDArray[float],npt.NDArray[int]]
        Survival time and event.
    """
    time = torch.abs(y)
    event = (torch.abs(y) == y)
    event = event # for numba
    return time.to(torch.float32), event.to(torch.float32)

def transform_back_torch_deephit(y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Transforms XGBoost digestable format variable y into time and event.

    Parameters
    ----------
    y : npt.NDArray[float]
        Array containing survival time and event where negative value is taken as censored event.

    Returns
    -------
    tuple[npt.NDArray[float],npt.NDArray[int]]
        Survival time and event.
    """
    #TODO: Build conditions and combine transform functions
    y = y[:,0]
    time = torch.abs(y)
    event = (torch.abs(y) == y)
    event = event # for numba
    return time.to(torch.float32), event.to(torch.float32)

## Load Data

In [3]:
phi = torch.torch.tensor([[0.56061679, 0.67018677, 0.830002  , 0.54249826, 0.06989779,
        0.37876414, 0.03824207, 0.39794583, 0.3478378 , 0.46833452,
        0.32292692, 0.70704927, 0.56557495, 0.58150488, 0.38871381,
        0.34848768, 0.39602859, 0.7649664 , 0.97185072],
       [0.95566696, 0.48104482, 0.75679969, 0.99386694, 0.91718885,
        0.49509238, 0.29452671, 0.29319748, 0.12167261, 0.34684764,
        0.20933687, 0.56328841, 0.04492321, 0.72173335, 0.79000414,
        0.50887818, 0.18630897, 0.53458442, 0.28050238],
       [0.15593853, 0.02782267, 0.01185559, 0.71230657, 0.81883045,
        0.4969608 , 0.74693123, 0.87947765, 0.31456862, 0.12872293,
        0.23845144, 0.64749992, 0.97677924, 0.57487316, 0.35761132,
        0.97724796, 0.8684253 , 0.68128472, 0.0869709 ],
       [0.99632921, 0.56681704, 0.95074056, 0.92120083, 0.76054228,
        0.27696215, 0.62860255, 0.70248436, 0.92183456, 0.70921057,
        0.76649902, 0.64107988, 0.24368451, 0.51152253, 0.80426748,
        0.74221419, 0.27668871, 0.17919406, 0.42808766],
       [0.27972894, 0.56866687, 0.6539161 , 0.5319978 , 0.99621816,
        0.33599839, 0.37512062, 0.14721801, 0.53298744, 0.50357236,
        0.38652909, 0.26062667, 0.77156399, 0.84954073, 0.68774531,
        0.47792277, 0.06867723, 0.76415557, 0.49115389],
       [0.79273291, 0.92759044, 0.81326807, 0.04226328, 0.98768654,
        0.07537023, 0.55331409, 0.22540424, 0.71398715, 0.7250366 ,
        0.32643635, 0.29281341, 0.46330328, 0.41149408, 0.13778508,
        0.03694444, 0.30238445, 0.44078415, 0.19661292],
       [0.75843055, 0.99731981, 0.94590159, 0.06033445, 0.20452619,
        0.76126846, 0.54285514, 0.7893368 , 0.38247081, 0.85430444,
        0.52043418, 0.18891051, 0.63758591, 0.37400696, 0.7804225 ,
        0.47204052, 0.56156778, 0.33566019, 0.3990119 ],
       [0.5306947 , 0.49757042, 0.69553678, 0.50663494, 0.44598667,
        0.21937834, 0.90704425, 0.65035354, 0.1341808 , 0.76539241,
        0.58519312, 0.98323395, 0.85668469, 0.02645247, 0.71319072,
        0.59216107, 0.21904393, 0.80283269, 0.38052726],
       [0.33536991, 0.56913116, 0.42259562, 0.68180137, 0.73365645,
        0.66044068, 0.97128302, 0.20019211, 0.73096269, 0.31609701,
        0.67428402, 0.48398141, 0.3843597 , 0.38371562, 0.36886752,
        0.0794369 , 0.03408211, 0.75023623, 0.68362319],
       [0.65184423, 0.48534751, 0.08064047, 0.78674504, 0.06255098,
        0.07844527, 0.196811  , 0.52218442, 0.81096049, 0.23212956,
        0.06336911, 0.14451082, 0.59268832, 0.44494552, 0.30831269,
        0.18509292, 0.7571426 , 0.78059063, 0.83033576],
       [0.1657391 , 0.98890883, 0.59125266, 0.80830763, 0.19778208,
        0.75399619, 0.71727879, 0.4851396 , 0.15589286, 0.73861217,
        0.511602  , 0.16323839, 0.66437389, 0.18004531, 0.81873861,
        0.07826153, 0.8947921 , 0.61635483, 0.43455631],
       [0.395005  , 0.97093084, 0.33062423, 0.23574601, 0.51486585,
        0.42856454, 0.94329152, 0.70344321, 0.57317513, 0.94632976,
        0.63672714, 0.99563958, 0.57327768, 0.03910801, 0.07771027,
        0.48566129, 0.85488311, 0.6057723 , 0.16950561],
       [0.59508459, 0.95045607, 0.92343267, 0.96820986, 0.06633364,
        0.83614061, 0.67577314, 0.74443865, 0.59352958, 0.51259308,
        0.03491607, 0.04434695, 0.85837897, 0.71040344, 0.02438701,
        0.99425333, 0.79017128, 0.12834676, 0.3988308 ],
       [0.68108959, 0.28834384, 0.81592558, 0.17091973, 0.21407878,
        0.23098784, 0.42014381, 0.85523219, 0.56613711, 0.61387742,
        0.55994703, 0.61032793, 0.65646759, 0.1812232 , 0.26260767,
        0.35116788, 0.64650932, 0.98678635, 0.52668915],
       [0.28992255, 0.40149962, 0.671999  , 0.87495982, 0.41206135,
        0.47103786, 0.83976069, 0.99689313, 0.53925724, 0.26320269,
        0.22911652, 0.97733889, 0.16904251, 0.42565559, 0.25421109,
        0.05060954, 0.51099776, 0.076741  , 0.71688287],
       [0.43050308, 0.08717614, 0.33684441, 0.7273216 , 0.48634413,
        0.7530504 , 0.52148584, 0.49021319, 0.23076744, 0.46344035,
        0.45206929, 0.56467496, 0.49647443, 0.4774734 , 0.90622309,
        0.57040588, 0.49141411, 0.28491756, 0.26073103],
       [0.44798068, 0.89615949, 0.30552967, 0.33349176, 0.95502655,
        0.91813918, 0.62428386, 0.84984507, 0.86836687, 0.78609441,
        0.47261591, 0.80318656, 0.13896811, 0.90963732, 0.05839439,
        0.83315826, 0.28216467, 0.1452794 , 0.70495897],
       [0.00463247, 0.51736001, 0.08070175, 0.11176686, 0.49550875,
        0.39094751, 0.65327167, 0.56788782, 0.73013999, 0.25124108,
        0.31757687, 0.58068344, 0.56020449, 0.69263596, 0.56677569,
        0.0296664 , 0.88003492, 0.21268497, 0.78623878],
       [0.3782029 , 0.81943188, 0.81412233, 0.52370451, 0.79590929,
        0.92546203, 0.59783995, 0.2973661 , 0.84875957, 0.64143599,
        0.45042338, 0.916713  , 0.39235494, 0.94565556, 0.82454687,
        0.89168315, 0.18688824, 0.97046952, 0.3591719 ],
       [0.88820103, 0.22469503, 0.09393979, 0.90702622, 0.2599165 ,
        0.38592207, 0.05689468, 0.74827846, 0.61047844, 0.97566285,
        0.75526081, 0.50752004, 0.94641458, 0.51176464, 0.42993403,
        0.86378897, 0.86438584, 0.12020468, 0.5326653 ]])


def deephit_data(phi=phi):
    nrows = 20
    data = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False)
    data, target = sort_X_y(data.data, data.target)
    data = data[:nrows,:]
    target = target[:nrows]
    data = torch.torch.tensor(data, dtype=torch.float32)
    time, event = transform_back(target)
    ncols = len(np.unique(np.abs(target)))
    target = np.tile(target, (ncols,1)).T
    #print(ncols)
    #phi = np.random.rand(20, ncols)
    target = torch.torch.tensor(target, dtype=torch.float32)
    time = torch.torch.tensor(time, dtype=torch.float32)
    event = torch.torch.tensor(event, dtype=torch.float32)
    print(type(target))
    return time, event, target, phi

In [4]:
x = np.array([0.2, 6.4, 3.0, 1.6])

bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])

inds = np.digitize(x, bins)

inds

array([1, 4, 3, 2])

In [5]:
x = torch.torch.tensor([0.2, 6.4, 3.0, 1.6])

bins = torch.torch.tensor([0.0, 1.0, 2.5, 4.0, 10.0])

inds = torch.bucketize(x, bins)

inds

tensor([1, 4, 3, 2])

In [6]:
def pad_col(input, val=0, where='end'):
    """Addes a column of `val` at the start of end of `input`."""
    if len(input.shape) != 2:
        raise ValueError(f"Only works for `phi` torch.tensor that is 2-D.")
    pad = torch.zeros_like(input[:, :1])
    if val != 0:
        pad = pad + val
    if where == 'end':
        return torch.cat([input, pad], dim=1)
    elif where == 'start':
        return torch.cat([pad, input], dim=1)
    raise ValueError(f"Need `where` to be 'start' or 'end', got {where}")

def _reduction(loss: torch.tensor, reduction: str = 'mean') -> torch.tensor:
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")

def nll_pmf(phi: torch.tensor, idx_durations: torch.tensor, events: torch.tensor, reduction: str = 'mean',
            epsilon: float = 1e-7) -> torch.tensor:
    """Negative log-likelihood for the PMF parametrized model [1].
    
    Arguments:
        phi {torch.torch.tensor} -- Estimates in (-inf, inf), where pmf = somefunc(phi).
        idx_durations {torch.torch.tensor} -- Event times represented as indices.
        events {torch.torch.tensor} -- Indicator of event (1.) or censoring (0.).
            Same length as 'idx_durations'.
        reduction {string} -- How to reduce the loss.
            'none': No reduction.
            'mean': Mean of torch.tensor.
            'sum: sum.
    
    Returns:
        torch.torch.tensor -- The negative log-likelihood.

    References:
    [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
        with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
        https://arxiv.org/pdf/1910.06724.pdf
    """
    if phi.shape[1] <= idx_durations.max():
        raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
                         f" Need at least `phi.shape[1] = {idx_durations.max().item()+1}`,"+
                         f" but got `phi.shape[1] = {phi.shape[1]}`")
    if events.dtype is torch.bool:
        events = events.float()
    events = events.view(-1)
    idx_durations = idx_durations.view(-1, 1)
    phi = pad_col(phi)
    print('phi shape', phi.shape)
    gamma = phi.max(1)[0]
    cumsum = phi.sub(gamma.view(-1, 1)).exp().cumsum(1)
    sum_ = cumsum[:, -1]
    print('shapes', idx_durations.shape, phi.shape, gamma.shape, events.shape)
    part1 = phi.gather(1, idx_durations).view(-1).sub(gamma).mul(events)
    part2 = - sum_.relu().add(epsilon).log()
    part3 = sum_.sub(cumsum.gather(1, idx_durations).view(-1)).relu().add(epsilon).log().mul(1. - events)
    # need relu() in part3 (and possibly part2) because cumsum on gpu has some bugs and we risk getting negative numbers.
    loss = - part1.add(part2).add(part3)
    return _reduction(loss, reduction)


In [7]:
time, events, y, phi = deephit_data()

<class 'torch.Tensor'>


In [22]:
y.shape

torch.Size([20, 19])

In [8]:
bins = torch.unique(time)
idx_durations = (torch.bucketize(time, bins))
idx_durations

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 10, 11, 12, 13, 14, 15, 16,
        17, 18])

In [9]:
nll_pmf(phi, idx_durations, events, reduction='sum')

phi shape torch.Size([20, 20])
shapes torch.Size([20, 1]) torch.Size([20, 20]) torch.Size([20]) torch.Size([20])


tensor(31.7249526978)

In [20]:

def deephit_likelihood_1_torch(y, phi):
    time, events = transform_back_torch_deephit(y)
    #time = time.reshape(time.shape[0],1)
    #print('shape idx dur befoer', idx_durations.shape)
    bins = torch.unique(time)
    idx_durations = (torch.bucketize(time, bins))
    idx_durations = idx_durations.view(-1, 1)
    #print('shape idx dur after', idx_durations.shape)
    # epsilon 
    epsilon = np.finfo(float).eps
    # pad phi as in pycox
    pad = torch.zeros_like(phi[:,:1])
    phi = torch.cat([phi, pad],axis=1)
    print('phi shape', phi.shape)
    # create durations index
    bins = torch.unique(time)

    if phi.shape[1] <= idx_durations.max():
        raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
                         f" Need at least `phi.shape[1] = {idx_durations.max().item()+1}`,"+
                         f" but got `phi.shape[1] = {phi.shape[1]}`")
    if events.dtype is torch.bool:
        events = events.float()
    #events = events.view(-1)
    #idx_durations = idx_durations.view(-1, 1)
    #phi = utils.pad_col(phi)

    gamma = phi.max(1)[0]
    print('shapes', idx_durations.shape, phi.shape, gamma.shape, events.shape)
    cumsum = phi.sub(gamma.view(-1, 1)).exp().cumsum(1)
    sum_ = cumsum[:, -1]
    part1 = phi.gather(1, idx_durations).view(-1).sub(gamma).mul(events)
    part2 = - sum_.relu().add(epsilon).log()
    part3 = sum_.sub(cumsum.gather(1, idx_durations).view(-1)).relu().add(epsilon).log().mul(1. - events)
    # need relu() in part3 (and possibly part2) because cumsum on gpu has some bugs and we risk getting negative numbers.
    loss = - part1.add(part2).add(part3)
    return torch.sum(loss)


In [21]:
deephit_likelihood_1_torch(y, phi)

phi shape torch.Size([20, 20])
shapes torch.Size([20, 1]) torch.Size([20, 20]) torch.Size([20]) torch.Size([20])


tensor(31.7249526978)

In [29]:
torch.unique(time)

tensor([0.1000000015, 0.7666666508, 1.2333333492, 1.2666666508, 1.4333332777,
        1.7666666508, 2.0000000000, 2.2999999523, 2.4000000954, 2.5000000000,
        2.5333333015, 3.3666665554, 3.5000000000, 3.7666666508, 4.1666665077,
        4.4333333969, 4.8666667938, 5.0666666031, 5.4333333969])

In [26]:

def deephit_likelihood_1_torch(y, phi):
    #time, events = transform_back_torch_deephit(y)
    time, events = transform_back_torch(y)
    #time = time.reshape(time.shape[0],1)
    #print('shape idx dur befoer', idx_durations.shape)
    bins = torch.unique(time)
    idx_durations = (torch.bucketize(time, bins))
    idx_durations = idx_durations.view(-1, 1)
    print('idx_durations',idx_durations)
    #print('shape idx dur after', idx_durations.shape)
    # epsilon 
    epsilon = np.finfo(float).eps
    # pad phi as in pycox
    pad = torch.zeros_like(phi[:,:1])
    phi = torch.cat([phi, pad],axis=1)
    print('phi shape', phi.shape)
    # create durations index
    bins = torch.unique(time)

    if phi.shape[1] <= idx_durations.max():
        raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
                         f" Need at least `phi.shape[1] = {idx_durations.max().item()+1}`,"+
                         f" but got `phi.shape[1] = {phi.shape[1]}`")
    if events.dtype is torch.bool:
        events = events.float()
    #events = events.view(-1)
    #idx_durations = idx_durations.view(-1, 1)
    #phi = utils.pad_col(phi)

    gamma = phi.max(1)[0]
    print('shapes', idx_durations.shape, phi.shape, gamma.shape, events.shape)
    cumsum = phi.sub(gamma.view(-1, 1)).exp().cumsum(1)
    sum_ = cumsum[:, -1]
    part1 = phi.gather(1, idx_durations).view(-1).sub(gamma).mul(events)
    part2 = - sum_.relu().add(epsilon).log()
    part3 = sum_.sub(cumsum.gather(1, idx_durations).view(-1)).relu().add(epsilon).log().mul(1. - events)
    # need relu() in part3 (and possibly part2) because cumsum on gpu has some bugs and we risk getting negative numbers.
    loss = - part1.add(part2).add(part3)
    return torch.sum(loss)


In [27]:
deephit_likelihood_1_torch(y[:,0], phi)

idx_durations tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18]])
phi shape torch.Size([20, 20])
shapes torch.Size([20, 1]) torch.Size([20, 20]) torch.Size([20]) torch.Size([20])


tensor(31.7249526978)

## 1. Compare loss to original function



In [None]:
# EH loss from paper

def eaftloss(out, time, delta): ##loss function for AFT or EH
    ia, ib = out.size()
    if ib == 1: ###loss function for AFT
        n = len(delta)
        #print(n)
        h = 1.30*math.pow(n,-0.2)
        #h 1.304058*math.pow(n,-0.2)  ## 1.304058*n^(-1/5) or 1.587401*math.pow(n,-0.333333) 1.587401*n^(-1/3)
        time = time.view(n,1)
        delta = delta.view(n,1)
        
        # R = g(Xi) + log(Oi)
        R = torch.add(out,torch.log(time)) 
        
        # Rj - Ri
        rawones = torch.ones([1,n], dtype = out.dtype)
        R1 = torch.mm(R,rawones)
        R2 = torch.mm(torch.t(rawones),torch.t(R))
        DR = R1 - R2 
        
        # K[(Rj-Ri)/h]
        K = normal_density(DR/h)
        Del = torch.mm(delta, rawones)
        DelK = Del*K 
        
        # (1/nh) *sum_j Deltaj * K[(Rj-Ri)/h]
        Dk = torch.sum(DelK, dim=0)/(n*h)
        
        # log {(1/nh) * Deltaj * K[(Rj-Ri)/h]}    
        log_Dk = torch.log(Dk)     
        A = torch.t(delta)*log_Dk/n   
        S1 = A.sum()  
        
        ncdf=torch.distributions.normal.Normal(torch.torch.tensor([0.0], dtype = out.dtype), torch.torch.tensor([1.0], dtype = out.dtype)).cdf
        P = ncdf(DR/h)
        CDF_sum = torch.sum(P, dim=0)/n
        Q = torch.log(CDF_sum)
        S2 = -(delta*Q.view(n,1)).sum()/n
             
        S0 = -(delta*torch.log(time)).sum()/n
        
        S = S0 + S1 + S2 
        S = -S
    else: ### loss function for Extended hazard model
        n = len(out[:,0])
        h = 1.30*math.pow(n,-0.2)  ## or 1.59*n^(-1/3)
        time = time.view(n,1)
        delta = delta.view(n,1)
        g1 = out[:,0].view(n,1)
        g2 = out[:,1].view(n,1)
        
        # R = g(Xi) + log(Oi)
        R = torch.add(g1,torch.log(time)) 
        
        S1 =  (delta*g2).sum()/n
        S2 = -(delta*R).sum()/n
        
        # Rj - Ri
        rawones = torch.ones(1,n)
        R1 = torch.mm(R,rawones)
        R2 = torch.mm(torch.t(rawones),torch.t(R))
        DR = R1 - R2 
        
        # K[(Rj-Ri)/h]
        K = normal_density(DR/h)
        Del = torch.mm(delta, rawones)
        DelK = Del*K 
        
        # (1/nh) *sum_j Deltaj * K[(Rj-Ri)/h]
        Dk = torch.sum(DelK, dim=0)/(n*h)  ## Dk would be zero as learning rate too large!
        
        # log {(1/nh) * Deltaj * K[(Rj-Ri)/h]}    
        log_Dk = torch.log(Dk)    
        
        S3 = (torch.t(delta)*log_Dk).sum()/n    
        
        # Phi((Rj-Ri)/h)
        ncdf=torch.distributions.normal.Normal(torch.torch.tensor([0.0]), torch.torch.tensor([1.0])).cdf
        P = ncdf(DR/h) 
        L = torch.exp(g2-g1)
        LL = torch.mm(L,rawones)
        LP_sum = torch.sum(LL*P, dim=0)/n
        Q = torch.log(LP_sum)
        
        S4 = -(delta*Q.view(n,1)).sum()/n
        
        S = S1 + S2 + S3 + S4  
        S = -S
    return S

def normal_density(a):  
    b = 0.3989423*torch.exp(-0.5*torch.pow(a,2.0))
    return b

### AFT Loss Original Paper

In [None]:
y, linear_predictor, time, event = aft_data(type='torch')
aft_loss_paper = eaftloss(linear_predictor, time, event)
print(eaftloss(linear_predictor, time, event))

tensor(1.5339863300, grad_fn=<NegBackward0>)


### AFT Loss Pytorch

check bandwidth, check dims vector for aft, go through code step by step

In [None]:
y, linear_predictor, time, event = aft_data(type='torch')
aft_loss_own_torch = aft_likelihood_torch(linear_predictor.reshape(-1),y.reshape(-1))
print(aft_likelihood_torch(linear_predictor.reshape(-1),y.reshape(-1)))

print(linear_predictor.reshape(-1).shape,y.reshape(-1).shape)

tensor(-1.5339864492, grad_fn=<MulBackward0>)
torch.Size([10]) torch.Size([10])


In [None]:
y.shape, linear_predictor.shape

(torch.Size([10, 1]), torch.Size([10, 1]))

## AFT Loss Numba

In [None]:
y, linear_predictor, time, event = aft_data(type='np')
aft_loss_own_numba = aft_likelihood(
     y, linear_predictor
)
aft_likelihood(
     y, linear_predictor
)

1.533986284161934

### Loss Function Comparison

In [None]:
np.allclose(aft_loss_own_torch.detach().numpy() ,aft_loss_own_torch.detach().numpy(),aft_loss_own_numba)

True

## Gradient Comparison

In [None]:

def aft_data(type='np'):
    #h2==0 scenario
    if type=='np':
        linear_predictor = np.array([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038]) #.reshape(1,10)
        y = np.array([1, -3, -3, -4, -7,  8,  9,  -11,  13,  16],dtype=np.float32) #.reshape(1,10)
        time = np.array([[ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16]])
        event = np.array([[1, 0, 0, 0, 0, 1, 1, 0, 1, 1]],dtype=np.float32)
    if type=='torch':
        linear_predictor = torch.torch.tensor([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038], requires_grad=True)
        y = torch.torch.tensor([1., -3., -3., -4., -7.,  8.,  9.,  -11.,  13.,  16.], requires_grad=True)
        time = torch.torch.tensor([ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16])
        event = torch.torch.tensor([1, 0, 0, 0, 0, 1, 1, 0, 1, 1],dtype=torch.float32)
        print('y shape', y.shape)
        print('linear_predictor', linear_predictor.shape)

    return y, linear_predictor, time, event

### Gradient Own Pytorch



In [None]:
y, linear_predictor, time, event = aft_data(type='torch')
aft_loss_own_torch = aft_likelihood_torch(linear_predictor,y)
print('loss', aft_loss_own_torch)
aft_loss_own_torch.backward()
grad_torch = linear_predictor.grad.numpy()
print('grad_torch', grad_torch)

y shape torch.Size([10])
linear_predictor torch.Size([10])
loss tensor(-1.5339864492, grad_fn=<MulBackward0>)
grad_torch [ 0.06427045 -0.03018909 -0.01806675 -0.04240712 -0.05263754  0.10352549
  0.0149176  -0.04637673  0.03045879 -0.02349511]


In [None]:
y, linear_predictor, time, event = aft_data(type='np')
grad_own, hess_own = aft_objective(
    y, linear_predictor
)
grad_own

array([-0.06427046,  0.0301891 ,  0.01806675,  0.04240714,  0.05263756,
       -0.1035255 , -0.01491763,  0.04637674, -0.03045881,  0.0234951 ])

### Gradient Function Comparison

In [None]:
np.allclose(-grad_own, grad_torch)

True

## Hessian Function Comparison

In [None]:
y, linear_predictor, time, event = aft_data(type='np')
grad_own, hess_own = aft_objective(
    y, linear_predictor
)
hess_own

array([0.04928495, 0.03086066, 0.02215578, 0.02689906, 0.00218753,
       0.01922084, 0.10613931, 0.04928495, 0.10754386, 0.07927256])

In [None]:
from torch.autograd.functional import hessian as hess_torch
y, linear_predictor, time, event = aft_data(type='torch')
print(y.shape, linear_predictor.shape)
hessian_matrix = hess_torch(aft_likelihood_torch, (linear_predictor, y), create_graph=True)
diag_hessian = hessian_matrix[0][0].diag()
diag_hessian

y shape torch.Size([10])
linear_predictor torch.Size([10])
torch.Size([10]) torch.Size([10])


tensor([ 0.1198979765, -0.0308606531, -0.0221557729, -0.0268990565,
        -0.0021875498, -0.0192208346, -0.1061393693,  0.0348668434,
        -0.1075438932, -0.0792726427], grad_fn=<DiagBackward0>)

In [None]:
np.isclose(hess_own,-diag_hessian.detach().numpy())

array([False,  True,  True,  True,  True,  True,  True, False,  True,
        True])

In [None]:
event

tensor([1., 0., 0., 0., 0., 1., 1., 0., 1., 1.])

In [None]:
def transform_back_torch(y: torch.torch.tensor) -> tuple[torch.torch.tensor, torch.torch.tensor]:
    """Transforms XGBoost digestable format variable y into time and event.

    Parameters
    ----------
    y : npt.NDArray[float]
        Array containing survival time and event where negative value is taken as censored event.

    Returns
    -------
    tuple[npt.NDArray[float],npt.NDArray[int]]
        Survival time and event.
    """
    time = torch.abs(y)
    event = (torch.abs(y) == y)
    event = event # for numba
    return time, event

def eaftloss(out, y): ##loss function for AFT or EH
    time, delta = transform_back_torch(y)
    print(time.shape, delta.shape)
    delta = delta.float()
    time = time.float()
    ###loss function for AFT
    n = len(delta)
    print('aft')
    h = 1.30*math.pow(n,-0.2)
    #h 1.304058*math.pow(n,-0.2)  ## 1.304058*n^(-1/5) or 1.587401*math.pow(n,-0.333333) 1.587401*n^(-1/3)
    time = time.view(n,1)
    delta = delta.view(n,1)
    
    # R = g(Xi) + log(Oi)
    R = torch.add(out,torch.log(time)) 
    
    # Rj - Ri
    rawones = torch.ones([1,n], dtype = out.dtype)
    R1 = torch.mm(R,rawones)
    R2 = torch.mm(torch.t(rawones),torch.t(R))
    DR = R1 - R2 
    
    # K[(Rj-Ri)/h]
    K = normal_density(DR/h)
    Del = torch.mm(delta, rawones)
    DelK = Del*K 
    
    # (1/nh) *sum_j Deltaj * K[(Rj-Ri)/h]
    Dk = torch.sum(DelK, dim=0)/(n*h)
    
    # log {(1/nh) * Deltaj * K[(Rj-Ri)/h]}    
    log_Dk = torch.log(Dk)     
    A = torch.t(delta)*log_Dk/n   
    S1 = A.sum()  
    
    ncdf=torch.distributions.normal.Normal(torch.torch.tensor([0.0], dtype = out.dtype), torch.torch.tensor([1.0], dtype = out.dtype)).cdf
    P = ncdf(DR/h)
    CDF_sum = torch.sum(P, dim=0)/n
    Q = torch.log(CDF_sum)
    S2 = -(delta*Q.view(n,1)).sum()/n
            
    S0 = -(delta*torch.log(time)).sum()/n
    
    S = S0 + S1 + S2 
    S = -S
    return S

def normal_density(a):  
    b = 0.3989423*torch.exp(-0.5*torch.pow(a,2.0))
    return b
def aft_data(type='np'):
    #h2==0 scenario
    if type=='np':
        linear_predictor = np.array([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038]) #.reshape(1,10)
        y = np.array([1, -3, -3, -4, -7,  8,  9,  -11,  13,  16],dtype=np.float32) #.reshape(1,10)
        time = np.array([[ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16]])
        event = np.array([[1, 0, 0, 0, 0, 1, 1, 0, 1, 1]],dtype=np.float32)
    if type=='torch':
        linear_predictor = torch.torch.tensor([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038], requires_grad=True)
        y = torch.torch.tensor([1., -3., -3., -4., -7.,  8.,  9.,  -11.,  13.,  16.], requires_grad=True)
        time = torch.torch.tensor([ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16])
        event = torch.torch.tensor([1, 0, 0, 0, 0, 1, 1, 0, 1, 1],dtype=torch.float32)
        print('y shape', y.shape)
        print('linear_predictor', linear_predictor.shape)

    return y, linear_predictor, time, event
from torch.autograd.functional import hessian as hess_torch
y, linear_predictor, time, event = aft_data(type='torch')
print(y.shape, linear_predictor.shape)
hessian_matrix = hess_torch(eaftloss, (linear_predictor, y), create_graph=True)
diag_hessian = hessian_matrix[0][0].diag()
diag_hessian

y shape torch.Size([10])
linear_predictor torch.Size([10])
torch.Size([10]) torch.Size([10])
torch.Size([10]) torch.Size([10])
aft


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x10 and 1x10)

In [None]:
y.size()

torch.Size([10])