In [1]:
import numpy as np
from xgbsurv.models.utils import transform, transform_back
from xgbsurv.models.eh_final import eh_likelihood, eh_gradient
from xgbsurv.models.eh_ah_final import ah_likelihood, ah_objective
from xgbsurv.models.eh_aft_final import aft_likelihood, aft_objective
import sys
sys.path.append('/Users/JUSC/Documents/xgbsurv_benchmarking/deep_learning/')
from loss_functions_pytorch import eh_likelihood_torch, eh_likelihood_torch_2, aft_likelihood_torch, ah_likelihood_torch
import torch
import math
torch.set_printoptions(precision=10)
from torch.autograd.functional import hessian

In [2]:
# create data function

def aft_data(type='np'):
    #h2==0 scenario
    if type=='np':
        linear_predictor = np.array([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038]) #.reshape(1,10)
        y = np.array([1, -3, -3, -4, -7,  8,  9,  -11,  13,  16],dtype=np.float32) #.reshape(1,10)
        time = np.array([[ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16]])
        event = np.array([[1, 0, 0, 0, 0, 1, 1, 0, 1, 1]],dtype=np.float32)
    if type=='torch':
        linear_predictor = torch.tensor([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038], requires_grad=True).reshape(10,1)
        y = torch.tensor([[1, -3, -3, -4, -7,  8,  9,  -11,  13,  16]]).reshape(10,1)
        time = torch.tensor([[ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16]],dtype=torch.float32)
        event = torch.tensor([[1, 0, 0, 0, 0, 1, 1, 0, 1, 1]],dtype=torch.float32).reshape(10,1)

    return y, linear_predictor, time, event

In [3]:
y, linear_predictor, time, event = aft_data(type='torch')

## Structure

- Compare loss to original function

## 1. Compare loss to original function



In [4]:
# EH loss from paper

def eaftloss(out, time, delta): ##loss function for AFT or EH
    ia, ib = out.size()
    if ib == 1: ###loss function for AFT
        n = len(delta)
        #print(n)
        h = 1.30*math.pow(n,-0.2)
        #h 1.304058*math.pow(n,-0.2)  ## 1.304058*n^(-1/5) or 1.587401*math.pow(n,-0.333333) 1.587401*n^(-1/3)
        time = time.view(n,1)
        delta = delta.view(n,1)
        
        # R = g(Xi) + log(Oi)
        R = torch.add(out,torch.log(time)) 
        
        # Rj - Ri
        rawones = torch.ones([1,n], dtype = out.dtype)
        R1 = torch.mm(R,rawones)
        R2 = torch.mm(torch.t(rawones),torch.t(R))
        DR = R1 - R2 
        
        # K[(Rj-Ri)/h]
        K = normal_density(DR/h)
        Del = torch.mm(delta, rawones)
        DelK = Del*K 
        
        # (1/nh) *sum_j Deltaj * K[(Rj-Ri)/h]
        Dk = torch.sum(DelK, dim=0)/(n*h)
        
        # log {(1/nh) * Deltaj * K[(Rj-Ri)/h]}    
        log_Dk = torch.log(Dk)     
        A = torch.t(delta)*log_Dk/n   
        S1 = A.sum()  
        
        ncdf=torch.distributions.normal.Normal(torch.tensor([0.0], dtype = out.dtype), torch.tensor([1.0], dtype = out.dtype)).cdf
        P = ncdf(DR/h)
        CDF_sum = torch.sum(P, dim=0)/n
        Q = torch.log(CDF_sum)
        S2 = -(delta*Q.view(n,1)).sum()/n
             
        S0 = -(delta*torch.log(time)).sum()/n
        
        S = S0 + S1 + S2 
        S = -S
    else: ### loss function for Extended hazard model
        n = len(out[:,0])
        h = 1.30*math.pow(n,-0.2)  ## or 1.59*n^(-1/3)
        time = time.view(n,1)
        delta = delta.view(n,1)
        g1 = out[:,0].view(n,1)
        g2 = out[:,1].view(n,1)
        
        # R = g(Xi) + log(Oi)
        R = torch.add(g1,torch.log(time)) 
        
        S1 =  (delta*g2).sum()/n
        S2 = -(delta*R).sum()/n
        
        # Rj - Ri
        rawones = torch.ones(1,n)
        R1 = torch.mm(R,rawones)
        R2 = torch.mm(torch.t(rawones),torch.t(R))
        DR = R1 - R2 
        
        # K[(Rj-Ri)/h]
        K = normal_density(DR/h)
        Del = torch.mm(delta, rawones)
        DelK = Del*K 
        
        # (1/nh) *sum_j Deltaj * K[(Rj-Ri)/h]
        Dk = torch.sum(DelK, dim=0)/(n*h)  ## Dk would be zero as learning rate too large!
        
        # log {(1/nh) * Deltaj * K[(Rj-Ri)/h]}    
        log_Dk = torch.log(Dk)    
        
        S3 = (torch.t(delta)*log_Dk).sum()/n    
        
        # Phi((Rj-Ri)/h)
        ncdf=torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])).cdf
        P = ncdf(DR/h) 
        L = torch.exp(g2-g1)
        LL = torch.mm(L,rawones)
        LP_sum = torch.sum(LL*P, dim=0)/n
        Q = torch.log(LP_sum)
        
        S4 = -(delta*Q.view(n,1)).sum()/n
        
        S = S1 + S2 + S3 + S4  
        S = -S
    return S

def normal_density(a):  
    b = 0.3989423*torch.exp(-0.5*torch.pow(a,2.0))
    return b

### AFT Loss Original Paper

In [5]:
y, linear_predictor, time, event = aft_data(type='torch')
aft_loss_paper = eaftloss(linear_predictor, time, event)
print(eaftloss(linear_predictor, time, event))

tensor(1.5339863300, grad_fn=<NegBackward0>)


### AFT Loss Pytorch

check bandwidth, check dims vector for aft, go through code step by step

In [7]:
y, linear_predictor, time, event = aft_data(type='torch')
aft_loss_own_torch = aft_likelihood_torch(linear_predictor.reshape(-1),y.reshape(-1))
print(aft_likelihood_torch(linear_predictor.reshape(-1),y.reshape(-1)))

print(linear_predictor.reshape(-1).shape,y.reshape(-1).shape)

tensor(-1.5339864492, grad_fn=<MulBackward0>)
torch.Size([10]) torch.Size([10])


In [None]:
y.shape, linear_predictor.shape

(torch.Size([10, 1]), torch.Size([10, 1]))

## AFT Loss Numba

In [None]:
y, linear_predictor, time, event = aft_data(type='np')
aft_loss_own_numba = aft_likelihood(
     y, linear_predictor
)
aft_likelihood(
     y, linear_predictor
)

1.533986284161934

### Loss Function Comparison

In [None]:
np.allclose(aft_loss_own_torch.detach().numpy() ,aft_loss_own_torch.detach().numpy(),aft_loss_own_numba)

True

## Gradient Comparison

In [None]:

def aft_data(type='np'):
    #h2==0 scenario
    if type=='np':
        linear_predictor = np.array([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038]) #.reshape(1,10)
        y = np.array([1, -3, -3, -4, -7,  8,  9,  -11,  13,  16],dtype=np.float32) #.reshape(1,10)
        time = np.array([[ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16]])
        event = np.array([[1, 0, 0, 0, 0, 1, 1, 0, 1, 1]],dtype=np.float32)
    if type=='torch':
        linear_predictor = torch.tensor([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038], requires_grad=True)
        y = torch.tensor([1., -3., -3., -4., -7.,  8.,  9.,  -11.,  13.,  16.], requires_grad=True)
        time = torch.tensor([ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16])
        event = torch.tensor([1, 0, 0, 0, 0, 1, 1, 0, 1, 1],dtype=torch.float32)
        print('y shape', y.shape)
        print('linear_predictor', linear_predictor.shape)

    return y, linear_predictor, time, event

### Gradient Own Pytorch



In [None]:
y, linear_predictor, time, event = aft_data(type='torch')
aft_loss_own_torch = aft_likelihood_torch(linear_predictor,y)
print('loss', aft_loss_own_torch)
aft_loss_own_torch.backward()
grad_torch = linear_predictor.grad.numpy()
print('grad_torch', grad_torch)

y shape torch.Size([10])
linear_predictor torch.Size([10])
loss tensor(-1.5339864492, grad_fn=<MulBackward0>)
grad_torch [ 0.06427045 -0.03018909 -0.01806675 -0.04240712 -0.05263754  0.10352549
  0.0149176  -0.04637673  0.03045879 -0.02349511]


In [None]:
y, linear_predictor, time, event = aft_data(type='np')
grad_own, hess_own = aft_objective(
    y, linear_predictor
)
grad_own

array([-0.06427046,  0.0301891 ,  0.01806675,  0.04240714,  0.05263756,
       -0.1035255 , -0.01491763,  0.04637674, -0.03045881,  0.0234951 ])

### Gradient Function Comparison

In [None]:
np.allclose(-grad_own, grad_torch)

True

## Hessian Function Comparison

In [10]:
y, linear_predictor, time, event = aft_data(type='np')
grad_own, hess_own = aft_objective(
    y, linear_predictor
)
hess_own

array([0.04928495, 0.03086066, 0.02215578, 0.02689906, 0.00218753,
       0.01922084, 0.10613931, 0.04928495, 0.10754386, 0.07927256])

In [11]:
from torch.autograd.functional import hessian as hess_torch
y, linear_predictor, time, event = aft_data(type='torch')
print(y.shape, linear_predictor.shape)
hessian_matrix = hess_torch(aft_likelihood_torch, (linear_predictor, y), create_graph=True)
diag_hessian = hessian_matrix[0][0].diag()
diag_hessian

y shape torch.Size([10])
linear_predictor torch.Size([10])
torch.Size([10]) torch.Size([10])


tensor([ 0.1198979765, -0.0308606531, -0.0221557729, -0.0268990565,
        -0.0021875498, -0.0192208346, -0.1061393693,  0.0348668434,
        -0.1075438932, -0.0792726427], grad_fn=<DiagBackward0>)

In [16]:
np.isclose(hess_own,-diag_hessian.detach().numpy())

array([False,  True,  True,  True,  True,  True,  True, False,  True,
        True])

In [17]:
event

tensor([1., 0., 0., 0., 0., 1., 1., 0., 1., 1.])

In [7]:
def transform_back_torch(y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Transforms XGBoost digestable format variable y into time and event.

    Parameters
    ----------
    y : npt.NDArray[float]
        Array containing survival time and event where negative value is taken as censored event.

    Returns
    -------
    tuple[npt.NDArray[float],npt.NDArray[int]]
        Survival time and event.
    """
    time = torch.abs(y)
    event = (torch.abs(y) == y)
    event = event # for numba
    return time, event

def eaftloss(out, y): ##loss function for AFT or EH
    time, delta = transform_back_torch(y)
    print(time.shape, delta.shape)
    delta = delta.float()
    time = time.float()
    ###loss function for AFT
    n = len(delta)
    print('aft')
    h = 1.30*math.pow(n,-0.2)
    #h 1.304058*math.pow(n,-0.2)  ## 1.304058*n^(-1/5) or 1.587401*math.pow(n,-0.333333) 1.587401*n^(-1/3)
    time = time.view(n,1)
    delta = delta.view(n,1)
    
    # R = g(Xi) + log(Oi)
    R = torch.add(out,torch.log(time)) 
    
    # Rj - Ri
    rawones = torch.ones([1,n], dtype = out.dtype)
    R1 = torch.mm(R,rawones)
    R2 = torch.mm(torch.t(rawones),torch.t(R))
    DR = R1 - R2 
    
    # K[(Rj-Ri)/h]
    K = normal_density(DR/h)
    Del = torch.mm(delta, rawones)
    DelK = Del*K 
    
    # (1/nh) *sum_j Deltaj * K[(Rj-Ri)/h]
    Dk = torch.sum(DelK, dim=0)/(n*h)
    
    # log {(1/nh) * Deltaj * K[(Rj-Ri)/h]}    
    log_Dk = torch.log(Dk)     
    A = torch.t(delta)*log_Dk/n   
    S1 = A.sum()  
    
    ncdf=torch.distributions.normal.Normal(torch.tensor([0.0], dtype = out.dtype), torch.tensor([1.0], dtype = out.dtype)).cdf
    P = ncdf(DR/h)
    CDF_sum = torch.sum(P, dim=0)/n
    Q = torch.log(CDF_sum)
    S2 = -(delta*Q.view(n,1)).sum()/n
            
    S0 = -(delta*torch.log(time)).sum()/n
    
    S = S0 + S1 + S2 
    S = -S
    return S

def normal_density(a):  
    b = 0.3989423*torch.exp(-0.5*torch.pow(a,2.0))
    return b
def aft_data(type='np'):
    #h2==0 scenario
    if type=='np':
        linear_predictor = np.array([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038]) #.reshape(1,10)
        y = np.array([1, -3, -3, -4, -7,  8,  9,  -11,  13,  16],dtype=np.float32) #.reshape(1,10)
        time = np.array([[ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16]])
        event = np.array([[1, 0, 0, 0, 0, 1, 1, 0, 1, 1]],dtype=np.float32)
    if type=='torch':
        linear_predictor = torch.tensor([0.67254923,
        0.86077982,
        0.43557393,
        0.94059047,
        0.8446509 ,
        0.23657039,
        0.74629685,
        0.99700768,
        0.28182768,
        0.44495038], requires_grad=True)
        y = torch.tensor([1., -3., -3., -4., -7.,  8.,  9.,  -11.,  13.,  16.], requires_grad=True)
        time = torch.tensor([ 1,  3,  3,  4,  7,  8,  9, 11, 13, 16])
        event = torch.tensor([1, 0, 0, 0, 0, 1, 1, 0, 1, 1],dtype=torch.float32)
        print('y shape', y.shape)
        print('linear_predictor', linear_predictor.shape)

    return y, linear_predictor, time, event
from torch.autograd.functional import hessian as hess_torch
y, linear_predictor, time, event = aft_data(type='torch')
print(y.shape, linear_predictor.shape)
hessian_matrix = hess_torch(eaftloss, (linear_predictor, y), create_graph=True)
diag_hessian = hessian_matrix[0][0].diag()
diag_hessian

y shape torch.Size([10])
linear_predictor torch.Size([10])
torch.Size([10]) torch.Size([10])
torch.Size([10]) torch.Size([10])
aft


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x10 and 1x10)

In [8]:
y.size()

torch.Size([10])