In [1]:
import numpy as np
import pandas as pd
from xgbsurv.models.utils import transform
from xgbsurv.models.breslow_final import breslow_likelihood, breslow_objective, transform_back
from scipy.special import logsumexp
import jax.numpy as jnp
from jax import  grad, hessian #jit,
import jax.scipy.special as jsp
import time as t
from math import log
import numba
from numba import jit #as njit

In [2]:
## Compare Timing of Loss Functions

In [3]:
# vectorized coxph

def get_risk_matrix(time):
    return (np.outer(time, time) >= np.square(time)).astype(int).T

def cox_ph_loss(log_partial_hazard, time, event):
    #time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    hazard_risk = log_partial_hazard*risk_matrix
    inp = event*(log_partial_hazard - logsumexp(hazard_risk, b=risk_matrix, axis=1))
    # logsumexp numerically more stable than numpy
    loss = np.sum(inp)
    return -loss/event.sum()

# sksurv approach

def cox_ph_loss_sksurv(y_pred, time, event):
    # sksurv approach without cython
    #time, event = transform_back(y)
    n_samples = event.shape[0]
    loss = 0

    for i in range(n_samples):
        at_risk = 0
        for j in range(n_samples):
            if time[j] >= time[i]:
                at_risk += np.exp(y_pred[j])
        loss += event[i] * (y_pred[i] - np.log(at_risk))
    return - loss/event.sum()

def breslow_likelihood(log_partial_hazard, time, event):

    # Assumes times have been sorted beforehand.
    partial_hazard = np.exp(log_partial_hazard)
    n_events = np.sum(event)
    n_samples = time.shape[0]
    #print(n_samples)
    previous_time = time[0]
    risk_set_sum = 0
    likelihood = 0
    set_count = 0
    accumulated_sum = 0

    for i in range(n_samples):
        risk_set_sum += partial_hazard[i]

    for k in range(n_samples):
        current_time = time[k]
        if current_time > previous_time:
            # correct set-count, have to go back to set the different hazards for the ties
            likelihood -= set_count * log(risk_set_sum)
            risk_set_sum -= accumulated_sum
            set_count = 0
            accumulated_sum = 0

        if event[k]:
            set_count += 1
            likelihood += log_partial_hazard[k]

        previous_time = current_time
        accumulated_sum += partial_hazard[k]
    #print(likelihood)
    final_likelihood = -likelihood / n_events #n_samples
    return final_likelihood

@jit(nopython=True)
def breslow_likelihood_numba(log_partial_hazard, time, event):

    # Assumes times have been sorted beforehand.
    partial_hazard = np.exp(log_partial_hazard)
    n_events = np.sum(event)
    n_samples = time.shape[0]
    #print(n_samples)
    previous_time = time[0]
    risk_set_sum = 0
    likelihood = 0
    set_count = 0
    accumulated_sum = 0

    for i in range(n_samples):
        risk_set_sum += partial_hazard[i]

    for k in range(n_samples):
        current_time = time[k]
        if current_time > previous_time:
            # correct set-count, have to go back to set the different hazards for the ties
            likelihood -= set_count * log(risk_set_sum)
            risk_set_sum -= accumulated_sum
            set_count = 0
            accumulated_sum = 0

        if event[k]:
            set_count += 1
            likelihood += log_partial_hazard[k]

        previous_time = current_time
        accumulated_sum += partial_hazard[k]
    #print(likelihood)
    final_likelihood = -likelihood / n_events #n_samples
    return final_likelihood

In [4]:
## Run comparison
def function1(hazard,time, event):
    return cox_ph_loss(hazard, time, event)

def function2(hazard,time, event):
    return breslow_likelihood(hazard,time, event)   

def function3(hazard,time, event):
    return breslow_likelihood_numba(hazard, time, event)

path = '/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data'
def comparison(num_runs = 10, size=1000):
    hazard = log_hazard = np.random.normal(0, 1, size)
    df = pd.read_csv(path+'/survival_simulation_'+str(size)+'.csv')
    df.sort_values(by='time', inplace=True)
    time = df.time.to_numpy()
    event = df.event.to_numpy()
    # Empty list to store the execution times
    function1_times = []
    function2_times = []
    function3_times = []

    # Loop to run each function and record the execution times
    for i in range(num_runs):
        print('Running Function 1')
        start_time = t.time()
        function1(hazard,time, event)
        end_time = t.time()
        function1_times.append(end_time - start_time)

        print('Running Function 2')
        start_time = t.time()
        function2(hazard,time, event)
        end_time = t.time()
        function2_times.append(end_time - start_time)

        print('Running Function 3')
        start_time = t.time()
        function3(hazard,time, event)
        end_time = t.time()
        function3_times.append(end_time - start_time)

    # Calculate the mean and standard deviation of the execution times for each function
    function1_mean = sum(function1_times) / len(function1_times)
    function1_std = pd.Series(function1_times).std()
    function2_mean = sum(function2_times) / len(function2_times)
    function2_std = pd.Series(function2_times).std()
    function3_mean = sum(function3_times) / len(function3_times)
    function3_std = pd.Series(function3_times).std()

    # Create a Pandas dataframe to display the results
    df = pd.DataFrame({
        'Function': ['Standard Vectorized CoxPH', 'Breslow', 'Breslow Numba'],
        'Mean': [function1_mean, function2_mean, function3_mean],
        'Standard Deviation': [function1_std, function2_std, function3_std],
        'Sample Size': [size, size, size],
        'Number Repetitions': [num_runs, num_runs, num_runs]
    })
    return df

df_1000 = comparison(num_runs = 100, size=1000)
df_10000 = comparison(num_runs = 100, size=10000)
df_100000 = comparison(num_runs = 100, size=100000)


Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Function 2
Running Function 3
Running Function 1
Running Func

: 

: 

In [None]:
dff = pd.concat([df_1000,df_10000, df_100000])
dff.to_csv(path+'/results/breslow_numba_comparison.csv', index=False)
print(dff.to_latex(index=False))

In [None]:
## Load Simulation Data

In [179]:
#hazard = pd.read_csv('/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data/survival_simulation_preds_1000.csv').to_numpy()
hazard = log_hazard = np.random.normal(0, 1, 1000)
df = pd.read_csv('/Users/JUSC/Documents/xgbsurv_benchmarking/implementation_testing/simulation_data/survival_simulation_1000.csv')
#df.event = 1
df.sort_values(by='time', inplace=True)
df.head(2)

Unnamed: 0,x_1,x_2,x_3,x_4,x_5,time,event
104,10.782665,6.321241,0.154621,19.20857,17.674405,0.00033,1.0
638,4.839507,9.959547,0.141726,19.185224,12.58512,0.000471,1.0


In [180]:
y = transform(df.time.to_numpy(), df.event.to_numpy())

# Verify loss function results are the same

In [181]:
def get_risk_matrix(time):
    return (np.outer(time, time) >= np.square(time)).astype(int).T

def cox_ph_loss(y, log_partial_hazard):
    # this order seems to be required, albeit not working with check_grad
    #print('log_partial_hazard',log_partial_hazard)
    #print('y',y)
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    hazard_risk = log_partial_hazard*risk_matrix
    inp = event*(log_partial_hazard - logsumexp(hazard_risk, b=risk_matrix, axis=1))
    # logsumexp numerically more stable than numpy
    loss = np.sum(inp)
    #print('loss', loss)
    # take negative loss to align with negative gradient
    return -loss/event.sum()

In [182]:
def cox_ph_loss_sksurv(y, y_pred):
    # sksurv approach
    time, event = transform_back(y)
    n_samples = event.shape[0]
    loss = 0


    for i in range(n_samples):
        at_risk = 0
        for j in range(n_samples):
            if time[j] >= time[i]:
                at_risk += np.exp(y_pred[j])
        loss += event[i] * (y_pred[i] - np.log(at_risk))
    #print(np.log(at_risk))
    return - loss/event.sum()

In [183]:
loss_breslow_0 = breslow_likelihood(y, hazard)
loss_breslow_0

7.063487791499207

In [184]:
loss_breslow_1 = cox_ph_loss(y, hazard)
loss_breslow_1

7.063487791499208

In [185]:
loss_breslow_2 = cox_ph_loss_sksurv(y, hazard)
loss_breslow_2

7.063487791499208

## Comparison of objective function results

In [186]:
# old way of risk matrix
def get_risk_matrix(time):
    return (np.outer(time, time) >= np.square(time)).astype(int).T

# new way, has to swapped in subsequent functions
def risk_sum_fast(log_hazard, time, event):
    if np.all(time[:-1] <= time[1:]) == False:
        order = np.argsort(time)
        time = time[order]
        event = event[order]
        log_hazard = log_hazard[order]
    h = np.exp(log_hazard)

    unique_times, ind, counts = np.unique(time, return_index=True, return_counts=True)
    n = event.shape[0]
    tie_rep = np.repeat(n-ind,counts)
    breaks = np.array([np.zeros(n), tie_rep]).T.flatten()
    breaks = breaks.astype('int64')
    # reverse order
    h = h[::-1]
    # add zero because closed brackets behaviour of reduceat
    to_sum = np.append(h,0)

    rh1 = np.add.reduceat(to_sum, breaks)[::2]
    return rh1

def cox_ph_loss(y, log_partial_hazard):
    # this order seems to be required, albeit not working with check_grad
    #print('log_partial_hazard',log_partial_hazard)
    #print('y',y)
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    hazard_risk = log_partial_hazard*risk_matrix
    inp = event*(log_partial_hazard - logsumexp(hazard_risk, b=risk_matrix, axis=1))
    # logsumexp numerically more stable than numpy
    loss = np.sum(inp)
    #print('loss', loss)
    # take negative loss to align with negative gradient
    return -loss/y.shape[0]

def cox_ph_denominator(log_partial_hazard, risk_matrix):
    return np.sum(
        risk_matrix * np.array([np.exp(log_partial_hazard)] * log_partial_hazard.shape[0]), 
        axis=1)

def cox_ph_gradient(y, log_partial_hazard):
    n = y.shape[0]
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    denominator = cox_ph_denominator(log_partial_hazard, risk_matrix)
    numerator = np.exp(log_partial_hazard)
    gradient = event - np.sum(
    (
        event.repeat(event.shape[0]).reshape((n, n)).T
        * get_risk_matrix(time).T
        * numerator.repeat(event.shape[0]).reshape((n, n))
    )
    / denominator.repeat(event.shape[0]).reshape((n, n)).T,
    axis=1,)
    return -gradient 

def cox_ph_denominator_hess(log_partial_hazard, time):
    risk_matrix = get_risk_matrix(time)
    denominator = np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1)
    return np.square(denominator)

def cox_ph_numerator_hess(log_partial_hazard, time):
    risk_matrix = get_risk_matrix(time)
    numerator = np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1) - np.sum(risk_matrix * np.square(np.exp(log_partial_hazard)),axis=1)
    return np.exp(log_partial_hazard)*numerator

def cox_ph_hessian(y, log_partial_hazard):
    n = y.shape[0]
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    numerator = cox_ph_numerator_hess(log_partial_hazard, time)
    denominator = cox_ph_denominator_hess(log_partial_hazard, time)
    hess = -np.sum(event[:,None]*risk_matrix*(numerator / denominator), axis=0)
    return hess

def cox_ph_objective(y, log_partial_hazard):
    # input order for xgb sklearn swapped
    gradient = cox_ph_gradient(y, log_partial_hazard)
    hessian = cox_ph_hessian(y, log_partial_hazard)
    return gradient, hessian

def get_risk_matrix(time):
    return (np.outer(time, time) >= np.square(time)).astype(int).T

def cox_ph_numerator_hess(log_partial_hazard, time):
    risk_matrix = get_risk_matrix(time)
    numerator = np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1) - np.sum(risk_matrix * np.square(np.exp(log_partial_hazard)),axis=1)
    #print('risk_part',np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1))
    #print(np.exp(log_partial_hazard)*numerator)
    return np.exp(log_partial_hazard)*numerator

def cox_ph_denominator_hess(log_partial_hazard, time):
    risk_matrix = get_risk_matrix(time)
    denominator = np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1)
    return np.square(denominator)

def cox_ph_hessian(y, log_partial_hazard):
    time, event = transform_back(y)
    n = event.shape[0]
    risk_matrix = get_risk_matrix(time)
    numerator = cox_ph_numerator_hess(log_partial_hazard, time)
    denominator = cox_ph_denominator_hess(log_partial_hazard, time)
    #print('num', numerator)
    #print('denominator', denominator)
    #print('(numerator / denominator)',(numerator / denominator))
    hess = -np.sum(event[:,None]*risk_matrix*(numerator / denominator), axis=0)
    return hess

def cox_ph_objective_2(y, log_partial_hazard):
    # input order for xgb sklearn swapped
    gradient = cox_ph_gradient(y, log_partial_hazard)
    hessian = cox_ph_hessian(y, log_partial_hazard)
    return gradient, hessian

# hessian nv version

def cox_ph_denominator_hess_nv(log_partial_hazard, time):
    denominator = np.zeros(time.shape[0])
    for j in range(time.shape[0]): 
        for k in range(time.shape[0]):
            denominator[j] += (time[k] >= time[j]) * np.exp(log_partial_hazard[k])
    return np.square(denominator)

def cox_ph_numerator_hess_nv(log_partial_hazard, time):
    numerator = np.zeros(time.shape[0])
    for j in range(time.shape[0]):
        for k in range(time.shape[0]):
            numerator[j] += (time[k] >= time[j]) * np.exp(log_partial_hazard[k]) - (time[k] >= time[j]) * np.square(np.exp(log_partial_hazard[k]))
    return np.exp(log_partial_hazard)*numerator

def cox_ph_hessian_nv(y, log_partial_hazard):
    time, event = transform_back(y)
    hess = np.zeros(time.shape[0])
    numerator = cox_ph_numerator_hess_nv(log_partial_hazard, time)
    denominator = cox_ph_denominator_hess_nv(log_partial_hazard, time)
    for i in range(time.shape[0]):
        for j in range(time.shape[0]):
            hess[i] -= event[j] * (time[i] >= time[j]) * numerator[i] / denominator[i]
    return hess

In [187]:
grad_0, hess_0 = breslow_objective(y,hazard)
grad_1, hess_1 = cox_ph_objective(y, hazard)
grad_2, hess_2 = cox_ph_objective_2(y, hazard)

In [188]:
# grad comparison
np.allclose(grad_0, grad_1, atol=1e-08)

True

In [189]:
# hessian comparison
np.allclose(hess_0, hess_1, atol=1e-04)

False

In [190]:
# hessian comparison
np.allclose(hess_1, hess_2, atol=1e-04)

True

In [191]:
hess_3 = cox_ph_hessian_nv(y, hazard) #/df.event.sum()
np.allclose(hess_2, hess_3, atol=1e-04)

True

In [192]:
def update_risk_sets_breslow(
    risk_set_sum, death_set_count, local_risk_set, local_risk_set_hessian
):
    local_risk_set += 1 / (risk_set_sum / death_set_count)
    local_risk_set_hessian += 1 / ((risk_set_sum**2) / death_set_count)
    return local_risk_set, local_risk_set_hessian


def calculate_sample_grad_hess(
    sample_partial_hazard, sample_event, local_risk_set, local_risk_set_hessian
):
    return (
        sample_partial_hazard * local_risk_set
    ) - sample_event, sample_partial_hazard * local_risk_set - local_risk_set_hessian * (
        sample_partial_hazard**2
    )


def breslow_numba(y, log_partial_hazard):
    time, event = transform_back(y)
    # Assumes times have been sorted beforehand.
    partial_hazard = np.exp(log_partial_hazard)
    samples = time.shape[0]
    risk_set_sum = 0

    for i in range(samples):
        risk_set_sum += partial_hazard[i]

    grad = np.empty(samples)
    hess = np.empty(samples)
    previous_time = time[0]
    local_risk_set = 0
    local_risk_set_hessian = 0
    death_set_count = 0
    censoring_set_count = 0
    accumulated_sum = 0

    for i in range(samples):
        sample_time = time[i]
        sample_event = event[i]
        sample_partial_hazard = partial_hazard[i]

        if previous_time < sample_time:
            if death_set_count:
                (
                    local_risk_set,
                    local_risk_set_hessian,
                ) = update_risk_sets_breslow(
                    risk_set_sum,
                    death_set_count,
                    local_risk_set,
                    local_risk_set_hessian,
                )
            for death in range(death_set_count + censoring_set_count):
                death_ix = i - 1 - death
                (grad[death_ix], hess[death_ix],) = calculate_sample_grad_hess(
                    partial_hazard[death_ix],
                    event[death_ix],
                    local_risk_set,
                    local_risk_set_hessian,
                )

            risk_set_sum -= accumulated_sum
            accumulated_sum = 0
            death_set_count = 0
            censoring_set_count = 0

        if sample_event:
            death_set_count += 1
        else:
            censoring_set_count += 1

        accumulated_sum += sample_partial_hazard
        previous_time = sample_time

    i += 1
    if death_set_count:
        local_risk_set, local_risk_set_hessian = update_risk_sets_breslow(
            risk_set_sum,
            death_set_count,
            local_risk_set,
            local_risk_set_hessian,
        )
    for death in range(death_set_count + censoring_set_count):
        death_ix = i - 1 - death
        (grad[death_ix], hess[death_ix],) = calculate_sample_grad_hess(
            partial_hazard[death_ix],
            event[death_ix],
            local_risk_set,
            local_risk_set_hessian,
        )
    return grad, hess


In [193]:
breslow_numba(y, hazard)[1]

array([5.44671097e-04, 6.22211221e-04, 1.25773136e-03, 9.10869074e-04,
       1.43205336e-02, 2.82956805e-03, 6.05908385e-03, 3.21009844e-03,
       5.36289520e-03, 1.54396532e-02, 7.76469402e-03, 3.69655093e-02,
       1.81949999e-03, 1.00688089e-01, 7.99907414e-03, 9.28839386e-03,
       3.88447671e-03, 1.28103657e-02, 7.82136787e-02, 3.50178334e-02,
       7.76343890e-03, 1.93525468e-02, 5.13030391e-03, 9.80878887e-02,
       2.23717606e-02, 1.04997004e-01, 5.21680922e-03, 1.71486759e-01,
       5.11456762e-02, 5.59801134e-02, 3.16392179e-02, 5.49021373e-03,
       1.80567480e-01, 3.26580754e-03, 4.41819739e-02, 3.99938600e-03,
       9.59955854e-02, 9.15345982e-02, 6.06423126e-02, 1.24390208e-02,
       2.06534570e-02, 1.11433753e-01, 4.46439528e-02, 6.02780284e-03,
       7.15909757e-02, 2.19445952e-02, 2.47683402e-02, 5.33494757e-02,
       4.35982756e-02, 2.19607412e-02, 9.27408761e-03, 4.88559303e-02,
       4.94941901e-02, 9.53766693e-03, 6.25102563e-02, 1.78176854e-02,
      

In [201]:
#@jax.jit
def get_risk_matrix(time):
    return (np.outer(time, time) >= jnp.square(time)).astype(int).T

#@jax.jit
def cox_ph_loss(log_partial_hazard, time, event):
    # this order seems to be required, albeit not working with check_grad
    #print('log_partial_hazard',log_partial_hazard)
    #print('y',y)
    risk_matrix = get_risk_matrix(time)
    hazard_risk = log_partial_hazard*risk_matrix
    inp = event*(log_partial_hazard - jsp.logsumexp(hazard_risk, b=risk_matrix, axis=1))
    # logsumexp numerically more stable than numpy
    loss = np.sum(inp)
    #print('loss', loss)
    # take negative loss to align with negative gradient
    return -loss/jnp.sum(event)

time = jnp.array(df.time.to_numpy())
event = jnp.array(df.event.to_numpy())
print(cox_ph_loss(hazard,time, event))
grad(cox_ph_loss)(hazard,time, event)
hess = hessian(cox_ph_loss)(hazard,time, event)
np.diag(hess) # this is the correct solution!

7.063488


array([1.47208391e-06, 1.68165218e-06, 3.39927351e-06, 2.46180844e-06,
       3.87041473e-05, 7.64748256e-06, 1.63759014e-05, 8.67594099e-06,
       1.44943097e-05, 4.17287883e-05, 2.09856571e-05, 9.99067925e-05,
       4.91756782e-06, 2.72129953e-04, 2.16191202e-05, 2.51037691e-05,
       1.04985875e-05, 3.46226152e-05, 2.11388295e-04, 9.46427899e-05,
       2.09822629e-05, 5.23041854e-05, 1.38656860e-05, 2.65102426e-04,
       6.04642191e-05, 2.83775677e-04, 1.40994825e-05, 4.63477714e-04,
       1.38231568e-04, 1.51297601e-04, 8.55114122e-05, 1.48384133e-05,
       4.88020247e-04, 8.82650602e-06, 1.19410739e-04, 1.08091535e-05,
       2.59447523e-04, 2.47390795e-04, 1.63898134e-04, 3.36189732e-05,
       5.58201500e-05, 3.01172317e-04, 1.20659330e-04, 1.62913548e-05,
       1.93489148e-04, 5.93097102e-05, 6.69414585e-05, 1.44187739e-04,
       1.17833180e-04, 5.93533478e-05, 2.50650974e-05, 1.32043046e-04,
       1.33768073e-04, 2.57774809e-05, 1.68946659e-04, 4.81559036e-05,
      

In [214]:
hess_1 #/event.sum()

array([ 1.28876886e-03,  1.47315298e-03,  2.97961709e-03,  2.15898206e-03,
        3.40451419e-02,  6.71738966e-03,  1.43972845e-02,  7.63280801e-03,
        1.27599087e-02,  3.68007916e-02,  1.85239553e-02,  8.84820763e-02,
        4.34532409e-03,  2.42133035e-01,  1.87658085e-02,  2.17970483e-02,
        9.11612506e-03,  3.00720221e-02,  1.84369145e-01,  8.21753028e-02,
        1.82189652e-02,  4.54410740e-02,  1.20492895e-02,  2.31230601e-01,
        5.24648372e-02,  2.47235424e-01,  1.22127790e-02,  4.03720661e-01,
        1.18489386e-01,  1.29869999e-01,  7.34334352e-02,  1.27430973e-02,
        4.21146914e-01,  7.51370450e-03,  1.01695050e-01,  9.20530140e-03,
        2.21363943e-01,  2.11233143e-01,  1.39976596e-01,  2.87141746e-02,
        4.76771452e-02,  2.57854186e-01,  1.03231880e-01,  1.39398062e-02,
        1.65718045e-01,  5.08134674e-02,  5.73685408e-02,  1.23684524e-01,
        1.01174723e-01,  5.09873690e-02,  2.15315848e-02,  1.13482282e-01,
        1.15086299e-01,  

In [139]:
def coxph_loss(y, y_pred):
    time, event = transform_back(y)
    n_samples = event.shape[0]
    loss = 0


    for i in range(n_samples):
        at_risk = 0
        for j in range(n_samples):
            if time[j] >= time[i]:
                at_risk += np.exp(y_pred[j])
        loss += event[i] * (y_pred[i] - np.log(at_risk))
    #print(np.log(at_risk))
    return - loss/y.shape[0]

In [140]:
coxph_loss(y, hazard)

2.6272164919452132

## Vectorized Breslow Model
Two ways of calculating risk matrix
- matrix comparison
- summing approach

In [141]:
## Vectorized Breslow Loss function
# old way of risk matrix
def get_risk_matrix(time):
    return (np.outer(time, time) >= np.square(time)).astype(int).T

# new way, has to swapped in subsequent functions
def risk_sum_fast(log_hazard, time, event):
    if np.all(time[:-1] <= time[1:]) == False:
        order = np.argsort(time)
        time = time[order]
        event = event[order]
        log_hazard = log_hazard[order]
    h = np.exp(log_hazard)

    unique_times, ind, counts = np.unique(time, return_index=True, return_counts=True)
    n = event.shape[0]
    tie_rep = np.repeat(n-ind,counts)
    breaks = np.array([np.zeros(n), tie_rep]).T.flatten()
    breaks = breaks.astype('int64')
    # reverse order
    h = h[::-1]
    # add zero because closed brackets behaviour of reduceat
    to_sum = np.append(h,0)

    rh1 = np.add.reduceat(to_sum, breaks)[::2]
    return rh1

def cox_ph_loss(y, log_partial_hazard):
    # this order seems to be required, albeit not working with check_grad
    #print('log_partial_hazard',log_partial_hazard)
    #print('y',y)
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    hazard_risk = log_partial_hazard*risk_matrix
    inp = event*(log_partial_hazard - logsumexp(hazard_risk, b=risk_matrix, axis=1))
    # logsumexp numerically more stable than numpy
    loss = np.sum(inp)
    #print('loss', loss)
    # take negative loss to align with negative gradient
    return -loss

def cox_ph_denominator(log_partial_hazard, risk_matrix):
    return np.sum(
        risk_matrix * np.array([np.exp(log_partial_hazard)] * log_partial_hazard.shape[0]), 
        axis=1)

def cox_ph_gradient(y, log_partial_hazard):
    n = y.shape[0]
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    denominator = cox_ph_denominator(log_partial_hazard, risk_matrix)
    numerator = np.exp(log_partial_hazard)
    gradient = event - np.sum(
    (
        event.repeat(event.shape[0]).reshape((n, n)).T
        * get_risk_matrix(time).T
        * numerator.repeat(event.shape[0]).reshape((n, n))
    )
    / denominator.repeat(event.shape[0]).reshape((n, n)).T,
    axis=1,)
    return -gradient 

def cox_ph_denominator_hess(log_partial_hazard, time):
    risk_matrix = get_risk_matrix(time)
    denominator = np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1)
    return np.square(denominator)

def cox_ph_numerator_hess(log_partial_hazard, time):
    risk_matrix = get_risk_matrix(time)
    numerator = np.sum(risk_matrix * np.exp(log_partial_hazard),axis=1) - np.sum(risk_matrix * np.square(np.exp(log_partial_hazard)),axis=1)
    return np.exp(log_partial_hazard)*numerator

def cox_ph_hessian(y, log_partial_hazard):
    n = y.shape[0]
    time, event = transform_back(y)
    risk_matrix = get_risk_matrix(time)
    numerator = cox_ph_numerator_hess(log_partial_hazard, time)
    denominator = cox_ph_denominator_hess(log_partial_hazard, time)
    hess = -np.sum(event[:,None]*risk_matrix*(numerator / denominator), axis=0)
    return hess

def cox_ph_objective(y, log_partial_hazard):
    # input order for xgb sklearn swapped
    gradient = cox_ph_gradient(y, log_partial_hazard)
    hessian = cox_ph_hessian(y, log_partial_hazard)
    return gradient, hessian

In [215]:
import pandas as pd
import time

def function1():
    # code for function 1
    time.sleep(2)

def function2():
    # code for function 2
    time.sleep(3)

# Set the number of times to run each function for accurate timing
num_runs = 10

# Empty list to store the execution times
function1_times = []
function2_times = []

# Loop to run each function and record the execution times
for i in range(num_runs):
    start_time = time.time()
    function1()
    end_time = time.time()
    function1_times.append(end_time - start_time)

    start_time = time.time()
    function2()
    end_time = time.time()
    function2_times.append(end_time - start_time)

# Calculate the mean and standard deviation of the execution times for each function
function1_mean = sum(function1_times) / len(function1_times)
function1_std = pd.Series(function1_times).std()
function2_mean = sum(function2_times) / len(function2_times)
function2_std = pd.Series(function2_times).std()

# Create a Pandas dataframe to display the results
df = pd.DataFrame({
    'Function': ['Function 1', 'Function 2'],
    'Mean': [function1_mean, function2_mean],
    'Standard Deviation': [function1_std, function2_std]
})

print(df)


     Function      Mean  Standard Deviation
0  Function 1  2.002873            0.001719
1  Function 2  3.003943            0.001672
