In [1]:
import numpy as np
import pandas as pd
from xgbsurv.models.utils import transform
from xgbsurv.models.breslow_final import breslow_likelihood, breslow_objective, transform_back
from scipy.special import logsumexp
import jax.numpy as jnp
from jax import jit, grad, hessian
import jax.scipy.special as jsp
import time as t
from math import log

In [2]:
## Load Simulation Data

In [3]:
#hazard = pd.read_csv('/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data/survival_simulation_preds_1000.csv').to_numpy()
hazard = log_hazard = np.random.normal(0, 1, 1000)
df = pd.read_csv('/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data/survival_simulation_1000.csv')
#df.event = 1
df.sort_values(by='time', inplace=True)
df.head(2)

Unnamed: 0,x_1,x_2,x_3,x_4,x_5,time,event
104,10.782665,6.321241,0.154621,19.20857,17.674405,0.00033,1.0
638,4.839507,9.959547,0.141726,19.185224,12.58512,0.000471,1.0


In [4]:
#y = transform(df.time.to_numpy(), df.event.to_numpy())

def breslow_likelihood(log_partial_hazard, time, event):

    # Assumes times have been sorted beforehand.
    partial_hazard = np.exp(log_partial_hazard)
    n_events = np.sum(event)
    n_samples = time.shape[0]
    #print(n_samples)
    previous_time = time[0]
    risk_set_sum = 0
    likelihood = 0
    set_count = 0
    accumulated_sum = 0

    for i in range(n_samples):
        risk_set_sum += partial_hazard[i]

    for k in range(n_samples):
        current_time = time[k]
        if current_time > previous_time:
            # correct set-count, have to go back to set the different hazards for the ties
            likelihood -= set_count * log(risk_set_sum)
            risk_set_sum -= accumulated_sum
            set_count = 0
            accumulated_sum = 0

        if event[k]:
            set_count += 1
            likelihood += log_partial_hazard[k]

        previous_time = current_time
        accumulated_sum += partial_hazard[k]
    #print(likelihood)
    final_likelihood = -likelihood / n_events #n_samples
    return final_likelihood

In [5]:
## Function
def get_risk_matrix(time):
    return (np.outer(time, time) >= jnp.square(time)).astype(int).T

def cox_ph_loss(log_partial_hazard, time, event):
    # this order seems to be required, albeit not working with check_grad
    #print('log_partial_hazard',log_partial_hazard)
    #print('y',y)
    risk_matrix = get_risk_matrix(time)
    hazard_risk = log_partial_hazard*risk_matrix
    inp = event*(log_partial_hazard - jsp.logsumexp(hazard_risk, b=risk_matrix, axis=1))
    # logsumexp numerically more stable than numpy
    loss = np.sum(inp)
    #print('loss', loss)
    # take negative loss to align with negative gradient
    return -loss/jnp.sum(event)

time = jnp.array(df.time.to_numpy())
event = jnp.array(df.event.to_numpy())
#print(cox_ph_loss(hazard,time, event))
#grad(cox_ph_loss)(hazard,time, event)
#hess = hessian(cox_ph_loss)(hazard,time, event)
#np.diag(hess) # this is the correct solution!

In [6]:
## Compare times

In [7]:
def function1(hazard,time, event):
    gradient = grad(cox_ph_loss)(hazard,time, event)
    hess = hessian(cox_ph_loss)(hazard,time, event)
    return gradient, np.diag(hess)
    

def function2(hazard,time, event):
    return breslow_likelihood(hazard,time, event)

path = '/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data'
def comparison(num_runs = 10, size=1000):
    hazard = log_hazard = np.random.normal(0, 1, size)
    df = pd.read_csv(path+'/survival_simulation_'+str(size)+'.csv')
    df.sort_values(by='time', inplace=True)
    time = jnp.array(df.time.to_numpy())
    event = jnp.array(df.event.to_numpy())
    # Empty list to store the execution times
    function1_times = []
    function2_times = []

    # Loop to run each function and record the execution times
    for i in range(num_runs):
        start_time = t.time()
        function1(hazard,time, event)
        end_time = t.time()
        function1_times.append(end_time - start_time)

        start_time = t.time()
        function2(hazard,time, event)
        end_time = t.time()
        function2_times.append(end_time - start_time)

    # Calculate the mean and standard deviation of the execution times for each function
    function1_mean = sum(function1_times) / len(function1_times)
    function1_std = pd.Series(function1_times).std()
    function2_mean = sum(function2_times) / len(function2_times)
    function2_std = pd.Series(function2_times).std()

    # Create a Pandas dataframe to display the results
    df = pd.DataFrame({
        'Function': ['Breslow_Jax', 'Breslow_Explicit'],
        'Mean': [function1_mean, function2_mean],
        'Standard Deviation': [function1_std, function2_std],
        'Sample Size': [size, size],
        'Number Repetitions': [num_runs, num_runs]
    })
    return df

df_1000 = comparison(num_runs = 50, size=1000)
print(df_1000.to_latex(index=False))
df_1000.to_csv(path+'/results/breslow_jax_comparison.csv', index=False)
df_1000

  print(df_1000.to_latex(index=False))


\begin{tabular}{lrrrr}
\toprule
        Function &      Mean &  Standard Deviation &  Sample Size &  Number Repetitions \\
\midrule
     Breslow\_Jax & 22.208190 &            1.330617 &         1000 &                  50 \\
Breslow\_Explicit &  0.292885 &            0.005615 &         1000 &                  50 \\
\bottomrule
\end{tabular}



NameError: name 'dff' is not defined

In [None]:
df_10000 = comparison(num_runs = 50, size=10000)
df_10000

: 

: 

In [None]:
dff = pd.concat([df_1000,df_10000])
dff

In [None]:
print(df_1000.to_latex(index=False))

In [None]:
dff.to_csv(path+'/results/breslow_jax_comparison.csv', index=False)