In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from torchcox import TorchCox

import numpy as np
import torch
import pandas as pd

We use two datasets to validate our Cox model: one synthetic and one real.

# Validation against closed-form MLE fit

We use a small synthetic dataset for which the Maximum Likelihood Estimate (MLE) of the Cox partial likelihood can be derived by hand, and validate the numerical PyTorch implementation against it.

In [3]:
valdf = pd.DataFrame({'id':['Bob','Sally','James','Ann'], 'time':[1,3,6,10], 'status':[1,1,0,1], 'smoke':[1,0,0,1]})
valdf

Unnamed: 0,id,time,status,smoke
0,Bob,1,1,1
1,Sally,3,1,0
2,James,6,0,0
3,Ann,10,1,1


Specify the model Cox(t,d ~ smoke) for times $t$ and status indicator $d$. Its partial likelihood is  
$L(T, D, X) = \prod_{t_i^e} \frac{h(t_i^e,X_i)}{\sum_{j:\, t_j \geq t_i^e} h(t_i^e, X_j)}$  

where $t_i^e$ denotes the event times (i.e. times where $d=1$), and $t_j$ denotes all times (whether event or censored).  

Each factor $\frac{h(t_i^e,X_i)}{\sum_{j:\, t_j \geq t_i^e} h(t_i^e, X_j)}$ is the hazard of the subject whose event was observed at time $t_i^e$, divided by the sum of the hazards, at that same time, of all subjects still at risk. A subject is at risk if $t_j \geq t_i^e$, i.e. they have neither had an event nor been censored before $t_i^e$. The inequality is non-strict because the subject who had the event at $t_i^e$ must have been at risk at that time too.  

For each of the subjects in the data above, the hazard $h(t, X) = h_0(t)\,\exp(X \beta)$ is therefore:  
- Bob: $h(t, X) = h_0(t)\,\exp(\beta)$
- Sally: $h(t, X) = h_0(t)\,\exp(0)$
- James: $h(t, X) = h_0(t)\,\exp(0)$
- Ann: $h(t, X) = h_0(t)\,\exp(\beta)$

The likelihood for this dataset then has three factors, one for each event time. Each factor is the hazard of the person who had the event divided by the sum of the hazards of all the people who were still at risk at that time (including the person who had the event):  
  
$
\begin{align}
L(T, D, X) = &\left[\frac{h_0(t)\exp(\beta)}{h_0(t)\exp(\beta) \,+\, h_0(t)\exp(0) \,+\, h_0(t)\exp(0) \,+\, h_0(t)\exp(\beta)} \right] \\
&\times \left[ \frac{h_0(t)\exp(0)}{h_0(t)\exp(0) \,+\, h_0(t)\exp(0) \,+\, h_0(t)\exp(\beta)} \right] \\
&\times \left[ \frac{h_0(t)\exp(\beta)}{h_0(t)\exp(\beta)} \right]
\end{align}$

Note that the baseline hazard $h_0(t)$ cancels in every factor of the likelihood. We then have:  
  
$\begin{align}
L(T, D, X) &= \left[\frac{\exp(\beta)}{2(1 + \exp(\beta))} \right] \times \left[ \frac{1}{2+ \exp(\beta)} \right] \times 1
\end{align}
$
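
The simplification above can be checked numerically. The following is a minimal sketch of a generic negative log partial likelihood in PyTorch, assuming no tied event times; the function name `neg_log_partial_likelihood` and the test value $\beta = 0.5$ are chosen for illustration, and this is not necessarily how `TorchCox` implements it. It compares the generic computation with the hand-simplified expression for the four subjects.

```python
import math
import torch

def neg_log_partial_likelihood(beta, times, events, X):
    """Negative log Cox partial likelihood (sketch, assumes no tied event times)."""
    eta = X @ beta  # linear predictor X beta, shape (n,)
    # at_risk[i, j] is True when subject j is still at risk at time t_i
    at_risk = times.unsqueeze(0) >= times.unsqueeze(1)
    # log of the risk-set sum of exp(eta_j) for every subject i
    masked = eta.unsqueeze(0).expand(len(times), -1).masked_fill(~at_risk, float("-inf"))
    log_denom = torch.logsumexp(masked, dim=1)
    # only subjects with an observed event contribute a factor
    return -((eta - log_denom) * events).sum()

# the four-subject dataset: Bob, Sally, James, Ann
times = torch.tensor([1.0, 3.0, 6.0, 10.0])
events = torch.tensor([1.0, 1.0, 0.0, 1.0])
X = torch.tensor([[1.0], [0.0], [0.0], [1.0]])
beta = torch.tensor([0.5])  # arbitrary test value

generic = neg_log_partial_likelihood(beta, times, events, X).item()
b = beta.item()
closed_form = -b + math.log(2 * (1 + math.exp(b))) + math.log(2 + math.exp(b))
print(generic, closed_form)
```

The two printed values should agree up to floating-point error, which confirms that the baseline hazard plays no role in the partial likelihood.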

We take the natural logarithm of this: the logarithm is a monotone transformation, so it does not change the position of the maximum, but it simplifies the derivative and makes the computation more numerically stable. We also multiply by $-1$, so that maximising the likelihood becomes minimising the quantity below.

$-\ln[L(T,D,X)] = -\beta +\ln[2(1+\exp(\beta))] - \ln[1] + \ln[2+\exp(\beta)]$

Taking the derivative of this with respect to $\beta$ (the constant factor $2$ inside the first logarithm drops out) gives  
$\begin{align}
-\frac{d}{d\beta} \ln[L(T,D,X)] &= -1 + \frac{d}{d\beta}\ln[1+\exp(\beta)] + \frac{d}{d\beta}\ln[2+\exp(\beta)] \\
&= -1 + \frac{1}{1+\exp(\beta)}\exp(\beta) + \frac{1}{2+\exp(\beta)}\exp(\beta)
\end{align}$

Setting this equal to zero, $-\frac{d}{d\beta} \ln[L(T,D,X)] = 0$, gives  
$1 = \frac{\exp(\beta)}{1+\exp(\beta)} + \frac{\exp(\beta)}{2+\exp(\beta)}$

This is the equation we need to solve for $\beta$ to find the stationary point. We have not shown it here, but the negative log partial likelihood is convex in $\beta$, so this stationary point is indeed a minimum, and it is the Maximum Likelihood Estimate (MLE) for $\beta$.

Multiplying through by $(1+\exp(\beta))(2+\exp(\beta))$ and simplifying leaves $\exp(\beta)^2 = 2$, which gives $\beta = \ln(2)/2 \approx 0.34657$.

(Note that we chose the real solution here: $\exp(\beta)^2 = 2$ has the two roots $\exp(\beta) = \pm\sqrt{2}$, and only the positive one gives a real value, $\beta = \ln(\sqrt{2})$.)
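
As an independent check of the algebra, the stationarity equation can also be solved numerically. The sketch below uses plain bisection on the derivative of the negative log-likelihood; the helper names `score` and `bisect` and the bracket $[0, 2]$ are chosen for illustration.

```python
import math

def score(beta):
    """Derivative of the negative log partial likelihood for the toy data."""
    e = math.exp(beta)
    return -1.0 + e / (1.0 + e) + e / (2.0 + e)

def bisect(f, lo, hi, tol=1e-12):
    """Find a root of f in [lo, hi], assuming f(lo) and f(hi) have opposite signs."""
    assert f(lo) * f(hi) < 0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if f(lo) * f(mid) <= 0:
            hi = mid  # sign change is in the lower half
        else:
            lo = mid  # sign change is in the upper half
    return 0.5 * (lo + hi)

beta_numeric = bisect(score, 0.0, 2.0)
print(beta_numeric, math.log(2) / 2)
```

The derivative is negative at $\beta = 0$ and positive at $\beta = 2$, so the bracket contains the root, and the numerical root should coincide with $\ln(2)/2$.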

Now we fit our Cox model code to verify that we indeed obtain the closed-form answer, as expected.

In [4]:
tname = 'time'
Xnames = ['smoke']
dname = 'status'

coxmod1 = TorchCox.TorchCox(lr=1)

coxmod1.fit(valdf, Xnames=Xnames, tname=tname, dname=dname)

coxmod1.beta.detach().numpy()[0]

[0.34657338]


0.34657338

And we do! :)

In [5]:
np.log(2)/2

0.34657359027997264

# Validation against R's `survival` package

Now we will compare our result against the R package 'survival' on another dataset.

In [6]:
df_tied = pd.read_csv('../data/ovarian_deduplicated.csv')
df_tied

Unnamed: 0,tyears,d,Karn,Broders,FIGO,Ascites,Diam,id
0,0.024657,1,8,2.0,IV,1,1-2cm,281
1,0.027398,1,6,,IV,1,>5cm,298
2,0.035617,1,8,2.0,IV,1,>5cm,342
3,0.041094,1,7,2.0,III,0,<1cm,228
4,0.082192,1,7,3.0,IV,1,<1cm,52
...,...,...,...,...,...,...,...,...
299,7.060274,0,10,3.0,III,0,>5cm,101
300,7.112330,0,10,1.0,III,1,>5cm,93
301,7.120548,0,9,2.0,III,1,<1cm,40
302,7.290410,0,9,2.0,IV,1,1-2cm,81


In [7]:
coxmod = TorchCox.TorchCox()

tname = 'tyears'
Xnames = ['Karn', 'Ascites']
dname = 'd'

In [8]:
# pip install rpy2

In [9]:
%load_ext rpy2.ipython

In [10]:
%%R
library(readr)
library(survival)
library(dplyr)
library(tidyr)

df2 = read_csv("/home/ilan/Desktop/TorchCox/data/ovarian_deduplicated.csv")

starttime = Sys.time()

rmod = coxph(Surv(tyears, d) ~ Karn + Ascites, df2, ties="breslow")
print(coef(rmod))

endtime = Sys.time()

print(endtime-starttime)

R[write to console]: 
Attaching package: ‘dplyr’


R[write to console]: The following objects are masked from ‘package:stats’:

    filter, lag


R[write to console]: The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union


R[write to console]: Parsed with column specification:
cols(
  tyears = col_double(),
  d = col_double(),
  Karn = col_double(),
  Broders = col_double(),
  FIGO = col_character(),
  Ascites = col_double(),
  Diam = col_character(),
  id = col_double()
)



      Karn    Ascites 
-0.2364006  0.4170308 
Time difference of 0.02595592 secs


In [11]:
%%time

coxmod.fit(df_tied, Xnames=Xnames, tname=tname, dname=dname, basehaz=False)

[-0.23638311  0.4170416 ]
CPU times: user 944 ms, sys: 9.92 ms, total: 954 ms
Wall time: 276 ms


We indeed match that result as well: the coefficients agree with R's to about four decimal places.  

We are about 10x slower than the R package (276 ms wall time versus 26 ms; R runs C code in the background), but in my view the simplicity and extensibility of our code compensate for that.  

(The timings here are also most likely dominated by the overhead of loading libraries, so a comparison on a larger dataset is required.)

To be more specific, R's survival package is blazingly fast and mature. The cost of this, however, is that the code itself is non-trivial to edit or modify, see e.g.
https://github.com/therneau/survival/blob/master/src/coxfit6.c

Contrast this with the code for this implementation, which is comparatively readable and can be found in `torchcox/TorchCox.py`. You can easily see how you could modify the likelihood there to insert regularisation, change the optimiser, modify the log-linear dependence on the covariates, etc.
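
To illustrate the kind of modification meant here, the following is a minimal sketch that adds an L2 (ridge) penalty to the negative log partial likelihood of the four-subject example from the first section. It does not use the actual `TorchCox` code: the functions `toy_nll` and `fit`, the penalty weight of 0.5, and the choice of plain SGD are all assumptions made for illustration.

```python
import math
import torch

def toy_nll(beta):
    # negative log partial likelihood of the four-subject example
    return -beta + torch.log(2 * (1 + torch.exp(beta))) + torch.log(2 + torch.exp(beta))

def fit(l2_penalty, steps=500, lr=0.5):
    """Minimise toy_nll plus an L2 penalty with plain gradient descent."""
    beta = torch.zeros(1, requires_grad=True)
    optimiser = torch.optim.SGD([beta], lr=lr)
    for _ in range(steps):
        optimiser.zero_grad()
        # regularisation is a one-line change to the loss
        loss = (toy_nll(beta) + l2_penalty * beta ** 2).sum()
        loss.backward()
        optimiser.step()
    return beta.item()

beta_mle = fit(l2_penalty=0.0)    # unpenalised fit, should be close to ln(2)/2
beta_ridge = fit(l2_penalty=0.5)  # penalised fit, shrunk towards zero
print(beta_mle, beta_ridge)
```

Swapping `torch.optim.SGD` for another optimiser in `torch.optim` is similarly a one-line change, since autograd supplies the gradient of whatever loss is written down.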

Fit the model again, but this time computing the baseline hazard, to ensure that works as well.

In [12]:
%%time

coxmod.fit(df_tied, Xnames=Xnames, tname=tname, dname=dname, basehaz=True)

[-0.23638311  0.4170416 ]
CPU times: user 2.2 s, sys: 95 ms, total: 2.29 s
Wall time: 1.59 s


In [13]:
coxmod.basehaz

Unnamed: 0,time,h0,H0
0,0.024657240259862743,0.018764267,0.018764267
1,0.02739761676265909,0.018845245,0.0376095
2,0.03561695837826295,0.018976642,0.0565862
3,0.04109387045962334,0.019059468,0.0756456
4,0.08219248033739741,0.019129159,0.0947748
...,...,...,...
299,7.0602742018236055,1.4441799,28.2556
300,7.112329818483311,1.671192,29.9268
301,7.120547803885543,2.1946971,32.1215
302,7.290409766965798,3.6383288,35.7598


Predict on the training set to ensure that the `predict_proba()` method works as well.

In [14]:
df_tied['pred'] = coxmod.predict_proba(df_tied, Xnames=Xnames, tname=tname)
df_tied

Unnamed: 0,tyears,d,Karn,Broders,FIGO,Ascites,Diam,id,pred
0,0.024657,1,8,2.0,IV,1,1-2cm,281,0.995712
1,0.027398,1,6,,IV,1,>5cm,298,0.986277
2,0.035617,1,8,2.0,IV,1,>5cm,342,0.987125
3,0.041094,1,7,2.0,III,0,<1cm,228,0.985644
4,0.082192,1,7,3.0,IV,1,<1cm,52,0.972883
...,...,...,...,...,...,...,...,...,...
299,7.060274,0,10,3.0,III,0,>5cm,101,0.070109
300,7.112330,0,10,1.0,III,1,>5cm,93,0.013961
301,7.120548,0,9,2.0,III,1,<1cm,40,0.003005
302,7.290410,0,9,2.0,IV,1,1-2cm,81,0.001557
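
The `pred` column appears to be consistent with the standard Cox survival probability $S(t \mid X) = \exp\left(-H_0(t)\,\exp(X\beta)\right)$, evaluated at each subject's own time. This is an assumption about what `predict_proba()` computes, inferred from the numbers above rather than from the source code. A minimal sketch checking it for row 0, using only values copied from the outputs above:

```python
import math

# values copied from the outputs above
beta_karn, beta_ascites = -0.23638311, 0.4170416  # fitted coefficients
H0_t = 0.018764267                                # cumulative baseline hazard H0 in row 0
karn, ascites = 8, 1                              # covariates of row 0 (id 281)

linear_predictor = karn * beta_karn + ascites * beta_ascites
# assumed formula: S(t | X) = exp(-H0(t) * exp(X beta))
survival = math.exp(-H0_t * math.exp(linear_predictor))
print(survival)
```

The printed value should be close to the 0.995712 reported for row 0.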


In summary, what are some of the reasons why one might consider implementing a well-known statistical model in a differentiable programming language like PyTorch?
- Extensibility: changes to loss function or optimisation algorithm are often one-line changes
- Scalability: functionality to deploy across multiple CPUs or GPUs is often built in or easy to include
- Mobile deployment: if relevant, models can be deployed on mobile devices (Android or iOS)
- Automatic differentiation: computing confidence intervals originally involved deriving second derivatives by hand and implementing the result in code. With differentiable programming it is sufficient to change the loss, and the second derivatives are then computed automatically (provided the loss is twice-differentiable)
- Ecosystem: integration with existing PyTorch libraries (see https://pytorch.org/ecosystem/) to add all sorts of functionality should be straightforward 
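
The automatic differentiation point can be made concrete with the four-subject example from the first section. The sketch below obtains the second derivative of the negative log partial likelihood at the MLE with `torch.autograd.grad` and turns it into a Wald-type standard error and 95% interval. The function `toy_nll` is written out here for illustration and is not part of `TorchCox`; with only four subjects the interval is of course purely illustrative.

```python
import math
import torch

def toy_nll(beta):
    # negative log partial likelihood of the four-subject example
    return -beta + torch.log(2 * (1 + torch.exp(beta))) + torch.log(2 + torch.exp(beta))

# evaluate at the closed-form MLE, in double precision
beta_hat = torch.tensor(math.log(2) / 2, dtype=torch.float64, requires_grad=True)

# first derivative, keeping the graph so it can be differentiated again
(grad,) = torch.autograd.grad(toy_nll(beta_hat), beta_hat, create_graph=True)
# second derivative: the observed information for this single parameter
(hess,) = torch.autograd.grad(grad, beta_hat)

std_err = 1.0 / math.sqrt(hess.item())
lower = beta_hat.item() - 1.96 * std_err
upper = beta_hat.item() + 1.96 * std_err
print(grad.item(), hess.item(), std_err, (lower, upper))
```

The first derivative should be numerically zero at the MLE, and no second derivative had to be derived by hand.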