In [50]:
import numpy as np
import pandas as pd
import lifelines
from lifelines.datasets import load_kidney_transplant
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
import plotly.express as px
import plotly.graph_objs as go
import os

path = os.path.realpath('..')

df = lifelines.datasets.df = lifelines.datasets.load_kidney_transplant().rename_axis('patient').reset_index()

df['start'] = 0
df['end'] = df.time

new = df[0:11]
fig = px.line(new, x='time', y='patient',color='patient', title='Survival Time of Patients')
for r,col in new.iterrows():  
    fig.add_shape(type="line", x0=new.start[r], x1=new.end[r],y0=new.index[r],y1=new.patient[r], line_width=2, line_dash="dash", line_color="blue")
fig

## KM Estimators

In [33]:
def KM(df):
   """Calculate KM estimator manually.
   Following https://allendowney.github.io/SurvivalAnalysisPython/02_kaplan_meier.html"""
   ts = df.time.unique()
   ts.sort()
   at_risk = pd.Series(0, index=ts)
   for t in ts:
      k = (t <= df['end']) # true false if time t is below or equal end time
      at_risk[t] = k.sum() # sum trues to get all cancer patients by duration
    
   d = pd.Series(0, index=ts)

   for t in ts:
      # create condition that patient died and that death was observed
      k = (df['death'] == 1) & (t == df['end'])
      # sum the true values for the above condition
      d[t] = k.sum()
      dff = pd.DataFrame(dict(death=d, 
            at_risk=at_risk), index=ts)

   dff['hazard'] = dff.death/dff.at_risk

   # calculate values for survivor curve
   dff['surv'] = (1-dff.hazard).cumprod()
   return dff

In [34]:
dff = KM(df)
dff = dff.reset_index()
fig1 = px.line(dff, x='index',  y='surv', title='Kaplan Meier Survival Function', labels={'surv':'KM_estimate', 'index':'timeline'})
fig1


In [35]:
kmf = KaplanMeierFitter()
kmf.fit(df.time, df.death)
est = kmf.survival_function_
est = est.reset_index()
fig2 = px.line(est, x='timeline',  y='KM_estimate', title='Kaplan Meier Survival Function (lifelines)')
fig2

## Confidence Intervals (Greenwood's Formula)

In [44]:

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=kmf.survival_function_.index, y=kmf.survival_function_['KM_estimate'],
    line=dict(shape='hv', width=3, color='rgb(31, 119, 180)'),
    showlegend=False
))

fig.add_trace(go.Scatter(
    x=kmf.confidence_interval_.index, 
    y=kmf.confidence_interval_['KM_estimate_upper_0.95'],
    line=dict(shape='hv', width=0),
    showlegend=False,
))

fig.add_trace(go.Scatter(
    x=kmf.confidence_interval_.index,
    y=kmf.confidence_interval_['KM_estimate_lower_0.95'],
    line=dict(shape='hv', width=0),
    fill='tonexty',
    fillcolor='rgba(31, 119, 180, 0.4)',
    showlegend=False
))

fig.update_layout(
    title= "Kaplan Meier Survival Function",
    xaxis_title="Duration",
    yaxis_title="Survival Probability"
)
fig.write_image(path +"/assets/images/CI_surv.png", scale = 5)
fig.show()

## Log Rank Test

In [49]:
# create two groups artificially
df['status'] = np.random.randint(0,2, df.shape[0])

df0 = df[df['status']==0]
df1 = df[df['status']==1]

kmf = KaplanMeierFitter()
kmf.fit(df0.time, df0.death)

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=kmf.survival_function_.index, y=kmf.survival_function_['KM_estimate'],
    line=dict(shape='hv', width=3, color='rgb(31, 119, 180)'),
    showlegend=False
))

kmf.fit(df1.time, event_observed=df1.death)

fig.add_trace(go.Scatter(
    x=kmf.survival_function_.index, y=kmf.survival_function_['KM_estimate'],
    line=dict(shape='hv', width=3, color='rgb(255,0,0)'),
    showlegend=False
))

fig.update_layout(
    title= "Comparison: Kaplan Meier Survival Functions",
    xaxis_title="Duration",
    yaxis_title="Survival Probability"
)

fig.write_image(path +"/assets/images/comp_surv.png", scale = 5)
fig.show()

In [54]:
results = logrank_test(df0['time'], df1['time'], event_observed_A=df0['status'], event_observed_B=df1['status'])
results.print_summary()

0,1
t_0,-1
null_distribution,chi squared
degrees_of_freedom,1
test_name,logrank_test

Unnamed: 0,test_statistic,p,-log2(p)
0,444.99,<0.005,325.72


In [53]:
print(results.p_value)

8.858643743309588e-99
