In [1]:
import jax.scipy as jsp
import jax
import jax.numpy as jnp
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi
import numpy as np
import pandas as pd
from jax import grad, hessian
from scipy.optimize import minimize
import warnings
warnings.filterwarnings('ignore')

In [2]:
# import data
rossi = load_rossi()
# prepare for no ties
rossi = rossi.drop_duplicates('week', keep='first')
# sort data by time
rossi = rossi.sort_values(by='week',ascending=True)
# visually verify
rossi.tail(5)

Unnamed: 0,week,arrest,fin,age,race,wexp,mar,paro,prio
299,47,1,0,22,1,0,0,1,3
106,48,1,0,19,1,0,0,0,6
68,49,1,0,35,1,1,0,1,3
35,50,1,1,20,1,1,0,1,2
3,52,0,1,23,1,1,1,1,1


In [3]:
# calculate at risk helper matrix (i.e. riskset)
# essential to capture the decreasing number of individuals at risk when taking the exp sums below.
at_risk = np.triu(np.ones(rossi.shape[0]))
at_risk.shape

(49, 49)

In [4]:


#code negative log likelihood function that we would like to minimize:
@jax.jit
def neglogp(betas, x = rossi[['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']].to_numpy(), riskset=at_risk, observed=rossi.arrest.to_numpy()):
    betas_x = betas @ x.T
    print('betas_x:',betas_x.shape)
    # now we want to sum in decreasing order of elements
    # first all are in the risk set and then every time step one less
    # this is well achieved with the np.triu function
    riskset_beta = betas_x * riskset
    print('riskset_beta:',riskset_beta.shape)
    # Compute the log of the sum of exponentials of input elements.
    # b is the weighting factor for each input element
    # i.e. we sum only over the elements of the riskset, 
    # in other words it allows us to get rid of the values exp(0) = 1 that are created without the weight.
    res_vec = (betas_x - jsp.special.logsumexp(riskset_beta, b = riskset, axis=1))
    print('ll_matrix:',res_vec.shape)
    # we sum the result on ly for those individuals where the event occurred.
    return -(observed * res_vec).sum()


In [5]:
dlike = grad(neglogp)
dlike2 = hessian(neglogp)
res = minimize(neglogp, np.ones(7), method='Newton-CG', jac=dlike, hess=dlike2)
print('Results:', res.x)

betas_x: (49,)
riskset_beta: (49, 49)
ll_matrix: (49,)
betas_x: (49,)
riskset_beta: (49, 49)
ll_matrix: (49,)
betas_x: (49,)
riskset_beta: (49, 49)
ll_matrix: (49,)
Results: [-0.06191754  0.05717472 -0.31334335 -0.66674318 -0.70001034 -0.25130795
 -0.0220832 ]


In [7]:
# compare to lifelines result
cph = CoxPHFitter()
cph.fit(rossi, duration_col='week', event_col='arrest')
cph.print_summary()  

0,1
model,lifelines.CoxPHFitter
duration col,'week'
event col,'arrest'
baseline estimation,breslow
number of observations,49
number of events observed,48
partial log-likelihood,-139.09
time fit was run,2022-08-28 10:25:58 UTC

Unnamed: 0,coef,exp(coef),se(coef),coef lower 95%,coef upper 95%,exp(coef) lower 95%,exp(coef) upper 95%,cmp to,z,p,-log2(p)
fin,-0.06,0.94,0.33,-0.71,0.59,0.49,1.81,0.0,-0.19,0.85,0.23
age,0.06,1.06,0.03,-0.0,0.12,1.0,1.12,0.0,1.93,0.05,4.21
race,-0.31,0.73,0.56,-1.4,0.78,0.25,2.17,0.0,-0.56,0.57,0.8
wexp,-0.67,0.51,0.41,-1.47,0.14,0.23,1.15,0.0,-1.62,0.1,3.26
mar,-0.7,0.5,0.66,-2.0,0.6,0.14,1.81,0.0,-1.06,0.29,1.79
paro,-0.25,0.78,0.36,-0.96,0.46,0.38,1.58,0.0,-0.69,0.49,1.03
prio,-0.02,0.98,0.06,-0.14,0.09,0.87,1.1,0.0,-0.37,0.71,0.49

0,1
Concordance,0.65
Partial AIC,292.18
log-likelihood ratio test,10.95 on 7 df
-log2(p) of ll-ratio test,2.83
