/
survivalstan.py
446 lines (385 loc) · 18.6 KB
/
survivalstan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
import patsy
import stanity
import pandas as pd
import numpy as np
def fit_stan_survival_model(df, formula, event_col, model_code=None, file=None,
                            model_cohort='survival model',
                            time_col=None,
                            sample_id_col=None, sample_col=None,
                            group_id_col=None, group_col=None,
                            timepoint_id_col=None, timepoint_end_col=None,
                            make_inits=None, stan_data=None,
                            grp_coef_type=None, FIT_FUN=stanity.fit,
                            drop_intercept=True,
                            *args, **kwargs):
    """Prepare inputs for a Stan survival model and fit it.

    Generic helper for fitting a variety of survival models using Stan. It builds
    a covariate matrix from a patsy formula, restricts the data to observations
    with no missing values in any column used, derives 1-indexed ID columns
    where only descriptive columns were given, and passes the assembled data
    dictionary to ``FIT_FUN``.

    Args:
        df (pandas DataFrame): The data frame containing input data to Survival model.
        formula (chr): Patsy formula to use for covariates. E.g '~ met_status + pd_l1'
        event_col (chr): name of column containing event status. Will be coerced to int
        model_code (chr): stan model code to use.
        file (chr): path to stan file (if model_code not given).
    Kwargs:
        model_cohort (chr): description of this model fit, to be used when plotting or summarizing output
        time_col (chr): name of column containing event time -- used for parameteric (Weibull) model
        sample_id_col (chr): name of column containing numeric sample ids (1-indexed & sequential)
        sample_col (chr): name of column containing sample descriptions - will be converted to an ID
        group_id_col (chr): name of column containing numeric group ids (1-indexed & sequential)
        group_col (chr): name of column containing group descriptions - will be converted to an ID
        timepoint_id_col (chr): name of column containing timepoint ids (1-indexed & sequential)
        timepoint_end_col (chr): name of column containing end times for each timepoint (will be converted to an ID)
        make_inits (callable): called with the assembled Stan data dict; its return
            value is passed to ``FIT_FUN`` as the ``init`` keyword argument.
        stan_data (dict): extra params passed to stan data object (override computed entries)
        grp_coef_type (chr): type of group coef specified, if using a varying-coef model
            Can be one of:
                - 'None' (default): guess group coef orientation from data.
                   Works except in case where M (num covariates) == G (num groups)
                - 'matrix': grp_beta defined as `matrix[M, G] grp_beta;`
                - 'vector-of-vectors': grp_beta defined as `vector[M] grp_beta[G];`
        FIT_FUN (callable): function used to fit the model (default ``stanity.fit``).
            Called with ``model_code``, ``file``, ``data`` plus any extra ``*args``/``**kwargs``.
        drop_intercept (bool): whether to drop the intercept term from the model matrix (default: True)
    Returns:
        dictionary of results objects. Contents::
            df: Pandas data frame containing input data, filtered to non-missing obs & with ID variables created
            x_df: Covariate matrix passed to Stan (same rows as `df`)
            x_names: Column names for the covariate matrix passed to Stan
            data: dict passed to Stan - contains dimensions, etc.
            fit: fit object returned by FIT_FUN
            coefs: posterior draws for coefficient values (None if no `beta` could be extracted)
            grp_coefs: by-group coefficient draws when a group column is given (None if extraction
                fails); otherwise the same object as `coefs`, labelled with group 'Overall'
            loo: psis-loo object returned for fit model (None if no `log_lik`). Used for model comparison & summary
            model_cohort: description of this model and/or cohort on which the model was fit
            df_all: the unfiltered input data frame
            sample_col, sample_id_col, timepoint_id_col, timepoint_end_col: column names used (or None)
    Raises:
        AttributeError: if neither `model_code` nor `file` is given.
        KeyError: if a named column is missing from `df`.
    Example:
        >>> testfit = fit_stan_survival_model(
                model_code = pem_model_code,   # Stan program text
                formula = '~ met_status + pd_l1',
                df = dflong,
                sample_col = 'patient_id',
                timepoint_end_col = 'end_time',
                event_col = 'end_failure',
                model_cohort = 'PEM survival model',
                iter = 30000,
                chains = 4,
            )
        >>> print(testfit['fit'])
        >>> seaborn.boxplot(x = 'value', y = 'variable', data = testfit['coefs'])
    """
    if model_code is None and file is None:
        raise AttributeError('Either model_code or file is required.')

    ## input covariates given formula
    x_df = patsy.dmatrix(formula, df, return_type='dataframe')

    ## construct data frame with all necessary columns, limited to non-missing data.
    ## Unset (None) column arguments are dropped; dict.fromkeys de-duplicates while
    ## keeping order (list.remove(None) would raise when every column is supplied).
    candidate_cols = [event_col, time_col,
                      group_id_col, group_col,
                      timepoint_id_col, timepoint_end_col,
                      sample_id_col, sample_col]
    other_cols = [col for col in dict.fromkeys(candidate_cols) if col is not None]
    if other_cols:
        ## filter other inputs to non-missing observations on input covariates
        df_nonmiss = x_df.join(df[other_cols]).dropna()
    else:
        df_nonmiss = x_df
    ## keep the covariate matrix row-aligned with the filtered observations,
    ## so that x has exactly N rows
    x_df = df_nonmiss.loc[:, x_df.columns]
    if len(x_df.columns) > 1 and drop_intercept:
        x_df = x_df.loc[:, x_df.columns != 'Intercept']

    ## prep input dictionary to pass to stan.fit
    survival_model_input_data = {
        'N': len(df_nonmiss.index),
        'x': x_df.values,
        'event': df_nonmiss[event_col].values.astype(int),
        'M': len(x_df.columns),
    }
    if time_col:
        survival_model_input_data['y'] = df_nonmiss[time_col].values

    ## construct timepoint ID vars & add to input data
    if timepoint_end_col and not timepoint_id_col:
        timepoint_id_col = 'timepoint_id'
        df_nonmiss[timepoint_id_col] = df_nonmiss[timepoint_end_col].astype('category').cat.codes + 1
    if timepoint_id_col:
        unique_timepoints = _prep_timepoint_dataframe(df_nonmiss,
                                                      timepoint_id_col=timepoint_id_col,
                                                      timepoint_end_col=timepoint_end_col
                                                      )
        timepoint_input_data = {
            't_dur': unique_timepoints['t_dur'],
            't_obs': unique_timepoints[timepoint_end_col],
            't': df_nonmiss[timepoint_id_col].values.astype(int),
            'T': len(df_nonmiss[timepoint_id_col].unique())
        }
        survival_model_input_data = dict(survival_model_input_data, **timepoint_input_data)
    if timepoint_end_col:
        # not required for all models, leave in for legacy
        survival_model_input_data['obs_t'] = df_nonmiss[timepoint_end_col].values.astype(int)

    ## construct sample ID var & add to input data
    if sample_col and not sample_id_col:
        sample_id_col = 'sample_id'
        df_nonmiss[sample_id_col] = df_nonmiss[sample_col].astype('category').cat.codes + 1
    if sample_id_col:
        sample_input_data = {
            's': df_nonmiss[sample_id_col].values.astype(int),
            'S': len(df_nonmiss[sample_id_col].unique())
        }
        survival_model_input_data = dict(survival_model_input_data, **sample_input_data)

    ## construct group ID var & add to input data
    if group_col and not group_id_col:
        group_id_col = 'group_id'
        df_nonmiss[group_id_col] = df_nonmiss[group_col].astype('category').cat.codes + 1
    if group_id_col:
        survival_model_input_data['g'] = df_nonmiss[group_id_col].values.astype(int)
        survival_model_input_data['G'] = len(df_nonmiss[group_id_col].unique())

    if stan_data:
        survival_model_input_data = dict(survival_model_input_data, **stan_data)
    if make_inits:
        kwargs = dict(kwargs, init=make_inits(survival_model_input_data))
    survival_fit = FIT_FUN(
        *args,
        model_code=model_code,
        file=file,
        data=survival_model_input_data,
        **kwargs
    )

    ## population-level coefficients; best-effort, since not every model has `beta`
    try:
        beta_coefs = pd.DataFrame(
            survival_fit.extract()['beta'],
            columns=x_df.columns
        )
        beta_coefs.reset_index(0, inplace=True)
        beta_coefs = beta_coefs.rename(columns={'index': 'iter'})
        beta_coefs = pd.melt(beta_coefs, id_vars=['iter'])
        beta_coefs['exp(beta)'] = np.exp(beta_coefs['value'])
        beta_coefs['model_cohort'] = model_cohort
    except Exception:
        beta_coefs = None

    ## prep by-group coefs if group specified
    if group_id_col:
        try:
            ## group labels ordered by group id, matching the group axis of grp_beta
            label_col = group_col if group_col else group_id_col
            grp_names = df_nonmiss.loc[
                ~df_nonmiss[[group_id_col]].duplicated()].sort_values(group_id_col)[label_col].values
            grp_coefs = _extract_grp_coefs(survival_fit=survival_fit,
                                           element='grp_beta',
                                           grp_coef_type=grp_coef_type,
                                           grp_names=grp_names,
                                           columns=x_df.columns,
                                           input_data=survival_model_input_data,
                                           model_cohort=model_cohort
                                           )
        except Exception:
            grp_coefs = None
    else:
        ## no groups: grp_coefs is the same object as beta_coefs, so `coefs`
        ## also receives the 'group' column set here
        grp_coefs = beta_coefs
        if grp_coefs is not None:
            grp_coefs['group'] = 'Overall'

    ## leave-one-out summary; best-effort, since not every model emits `log_lik`
    try:
        loo = stanity.psisloo(survival_fit.extract()['log_lik'])
    except Exception:
        loo = None

    ## normalize unset (falsy) column names to None in the returned dict
    sample_id_col = sample_id_col or None
    sample_col = sample_col or None
    timepoint_id_col = timepoint_id_col or None
    timepoint_end_col = timepoint_end_col or None
    return {
        'df': df_nonmiss,
        'x_df': x_df,
        'x_names': x_df.columns,
        'data': survival_model_input_data,
        'fit': survival_fit,
        'coefs': beta_coefs,
        'grp_coefs': grp_coefs,
        'loo': loo,
        'model_cohort': model_cohort,
        'df_all': df,
        'sample_col': sample_col,
        'sample_id_col': sample_id_col,
        'timepoint_id_col': timepoint_id_col,
        'timepoint_end_col': timepoint_end_col,
    }
def _extract_grp_coefs(survival_fit, element, grp_coef_type, grp_names, columns, input_data, model_cohort):
""" Helper function to extract grp coefs summary data
"""
grp_coefs_extract = survival_fit.extract()[element]
## try to guess shape of group-betas
if not(grp_coef_type):
grp_coef_type = _guess_grp_coef_type(extract=grp_coefs_extract,
input_data=input_data)
## process group_coefs according to type
if grp_coef_type == 'matrix':
try:
grp_coefs_data = _format_grp_coefs_matrix(extract=grp_coefs_extract,
columns=columns,
grp_names=grp_names
)
except:
raise Exception('unable to format grp coefs as matrix')
elif grp_coef_type == 'vector-of-vectors':
try:
grp_coefs_data = _format_grp_coefs_vectors(extract=grp_coefs_extract,
columns=columns,
grp_names=grp_names
)
except:
raise Exception('unable to format grp coefs as vector-of-vectors')
elif grp_coef_type == 'unknown':
print("warning: unable to determine group-coef orientation. Try using arg `grp_coef_type`")
return(None)
else:
print("Invalid `grp_coef_type` -- must be one of 'vector-of-vectors' or 'matrix'")
print("Skipping grp coef extraction for now.")
return(None)
# process/format grp_coefs data
grp_coefs = pd.melt(grp_coefs_data, id_vars=['group','iter'])
grp_coefs['exp(beta)'] = np.exp(grp_coefs['value'])
grp_coefs['group'] = grp_coefs.group.astype('category')
grp_coefs['model_cohort'] = model_cohort
return(grp_coefs)
def _format_grp_coefs_matrix(extract, columns, grp_names):
""" Helper function for format grp_coefs data if in `matrix[M, G]` form
"""
grp_coefs_data = list()
i = 0
for grp in grp_names:
grp_data = pd.DataFrame(extract[:,:,i], columns = columns)
grp_data.reset_index(inplace=True)
grp_data.rename(columns={'index':'iter'}, inplace=True)
grp_data['group'] = grp
grp_coefs_data.append(grp_data)
i = i+1
return(pd.concat(grp_coefs_data))
def _format_grp_coefs_vectors(extract, columns, grp_names):
""" Helper function for format grp_coefs data if in `vector[M] grp_beta[G]` form
"""
grp_coefs_data = list()
i = 0
for grp in grp_names:
grp_data = pd.DataFrame(extract[:,i,:], columns = columns)
grp_data.reset_index(inplace=True)
grp_data.rename(columns={'index':'iter'}, inplace=True)
grp_data['group'] = grp
grp_coefs_data.append(grp_data)
i = i+1
return(pd.concat(grp_coefs_data))
def _guess_grp_coef_type(extract, input_data):
""" helper function to determine grp_coefs type from shape of returned object
"""
if input_data['M'] == input_data['G']:
# unable to determine shape if M == G
grp_coef_type = 'unknown'
elif extract.shape[1] == input_data['G']:
grp_coef_type = 'vector-of-vectors'
elif extract.shape[2] == input_data['G']:
grp_coef_type = 'matrix'
return grp_coef_type
def _prep_timepoint_dataframe(df,
timepoint_end_col,
timepoint_id_col = None
):
""" Helper function to take a set of timepoints
in observation-level dataframe & return
formatted timepoint_id, end_time, duration
Returns
---------
pandas dataframe with one record per timepoint_id
where timepoint_id is the index
sorted on the index, increasing
"""
time_df = df.copy()
time_df.sort_values(timepoint_end_col, inplace=True)
if not(timepoint_id_col):
timepoint_id_col = 'timepoint_id'
time_df[timepoint_id_col] = time_df[timepoint_end_col].astype('category').cat.codes + 1
time_df.dropna(how='any', subset=[timepoint_id_col, timepoint_end_col], inplace=True)
time_df = time_df.loc[:,[timepoint_id_col, timepoint_end_col]].drop_duplicates()
time_df[timepoint_end_col] = time_df[timepoint_end_col].astype(np.float32)
time_df.set_index(timepoint_id_col, inplace=True, drop=True)
time_df.sort_index(inplace=True)
t_durs = time_df.diff(periods=1)
t_durs.rename(columns = {timepoint_end_col: 't_dur'}, inplace=True)
time_df = time_df.join(t_durs)
time_df.fillna(inplace=True, value=time_df.loc[1, timepoint_end_col])
return(time_df)
def extract_grp_baseline_hazard(results, timepoint_id_col='timepoint_id', timepoint_end_col='end_time'):
    """If model results contain a grp_baseline object, extract & summarize it.

    Args:
        results (dict): output of ``fit_stan_survival_model``; reads 'fit',
            'grp_coefs' (for group labels, in group order) and 'df'.
        timepoint_id_col (str): name of the timepoint id column (1-indexed ids).
        timepoint_end_col (str): name of the timepoint end-time column.

    Returns:
        DataFrame with one row per (draw, timepoint, group): timepoint id,
        'baseline_hazard', 'group', and the timepoint's end time.
    """
    ## TODO check if results are by-group
    ## TODO check if baseline hazard is computable
    grp_baseline_extract = results['fit'].extract()['grp_baseline']
    coef_group_names = results['grp_coefs']['group'].unique()
    grp_baseline_data = []
    for grp_idx, grp in enumerate(coef_group_names):
        grp_base = pd.DataFrame(grp_baseline_extract[:, :, grp_idx])
        grp_base_coefs = pd.melt(grp_base, var_name=timepoint_id_col, value_name='baseline_hazard')
        ## melted column labels are 0-based positions, but timepoint ids are
        ## 1-indexed; shift so each draw merges with its own timepoint's end time
        grp_base_coefs[timepoint_id_col] = grp_base_coefs[timepoint_id_col] + 1
        grp_base_coefs['group'] = grp
        grp_baseline_data.append(grp_base_coefs)
    grp_baseline_coefs = pd.concat(grp_baseline_data)
    end_times = _extract_timepoint_end_times(results, timepoint_id_col=timepoint_id_col, timepoint_end_col=timepoint_end_col)
    bs_data = pd.merge(grp_baseline_coefs, end_times, on=timepoint_id_col)
    return bs_data
def _extract_timepoint_end_times(results, timepoint_end_col = 'end_time', timepoint_id_col = 'timepoint_id'):
df_nonmiss = results['df']
end_times = df_nonmiss.loc[~df_nonmiss[[timepoint_id_col]].duplicated()].sort_values(timepoint_id_col)[[timepoint_end_col, timepoint_id_col]]
return(end_times)
def extract_baseline_hazard(results, element='baseline', timepoint_id_col='timepoint_id', timepoint_end_col='end_time'):
    """If model results contain a baseline object, extract & summarize it.

    Args:
        results (dict): output of ``fit_stan_survival_model``; reads 'fit',
            'df' and 'model_cohort'.
        element (str): key of the baseline draws in ``results['fit'].extract()``.
        timepoint_id_col (str): name of the timepoint id column (1-indexed ids).
        timepoint_end_col (str): name of the timepoint end-time column.

    Returns:
        DataFrame with one row per (draw, timepoint): timepoint id,
        'baseline_hazard', the timepoint's end time and 'model_cohort'.
    """
    ## TODO check if baseline hazard is computable
    baseline_extract = results['fit'].extract()[element]
    baseline_coefs = pd.DataFrame(baseline_extract)
    bs_coefs = pd.melt(baseline_coefs, var_name=timepoint_id_col, value_name='baseline_hazard')
    ## melted column labels are 0-based positions, but timepoint ids are
    ## 1-indexed; shift so each draw merges with its own timepoint's end time
    bs_coefs[timepoint_id_col] = bs_coefs[timepoint_id_col] + 1
    end_times = _extract_timepoint_end_times(results, timepoint_id_col=timepoint_id_col, timepoint_end_col=timepoint_end_col)
    bs_data = pd.merge(bs_coefs, end_times, on=timepoint_id_col)
    bs_data['model_cohort'] = results['model_cohort']
    return bs_data
## convert wide survival data to long format
def prep_data_long_surv(df, time_col, event_col):
    ''' convert wide survival data to long format

    Cross-joins every observation with every distinct failure/censor time in
    `time_col`, then keeps only the timepoints at or before each observation's
    own time.

    Args:
        df (pandas DataFrame): one row per observation; not modified.
        time_col (str): column with each observation's failure/censor time.
        event_col (str): column with event status.

    Returns:
        DataFrame with the original columns plus 'key' (cross-join helper),
        'end_time' and 'end_failure' (False before the observation's time, the
        event status at its time).
    '''
    ## identify distinct failure/censor times
    failure_times = df[time_col].unique()
    ftimes = pd.DataFrame({'end_time': failure_times, 'key': 1})
    ## cross join failure times with each observation; assign() works on a copy
    ## so the caller's frame does not gain a 'key' column as a side effect
    dflong = pd.merge(df.assign(key=1), ftimes, on='key')

    ## identify end-time & end-status for each sample*failure time
    def gen_end_failure(row):
        if row[time_col] > row['end_time']:
            ## event not yet occurred (time_col is after this timepoint)
            return False
        if row[time_col] == row['end_time']:
            ## event during (==) this timepoint
            return row[event_col]
        if row[time_col] < row['end_time']:
            ## event already occurred (time_col is before this timepoint)
            return np.nan

    dflong['end_failure'] = dflong.apply(gen_end_failure, axis=1)
    ## confirm total number of non-censor events hasn't changed
    n_events_wide = sum(df[event_col].dropna())
    n_events_long = sum(dflong.end_failure.dropna())
    if not (n_events_long == n_events_wide):
        print('Warning: total number of events has changed from {0} to {1}'.format(n_events_wide, n_events_long))
    ## remove timepoints after failure/censor event (boolean mask instead of a
    ## query string, so arbitrary column names work)
    dflong = dflong.loc[dflong['end_time'] <= dflong[time_col]]
    return dflong
def make_weibull_survival_model_inits(stan_input_dict):
    """Build an init-generating function for the Weibull survival model.

    Args:
        stan_input_dict (dict): Stan data dict; only 'M' (number of covariates)
            is read, each time the returned function is called.

    Returns:
        A zero-argument callable returning a fresh dict of random initial values
        for 'tau_s_raw', 'tau_raw', 'alpha_raw', 'beta_raw' and 'mu' (draws are
        taken from numpy's global RNG in that order).
    """
    def draw_inits():
        inits = dict()
        inits['tau_s_raw'] = abs(np.random.normal(0, 1))
        inits['tau_raw'] = abs(np.random.normal(0, 1, stan_input_dict['M']))
        inits['alpha_raw'] = np.random.normal(0, 0.1)
        inits['beta_raw'] = np.random.normal(0, 1, stan_input_dict['M'])
        inits['mu'] = np.random.normal(0, 1)
        return inits
    return draw_inits