Skip to content

Commit

Permalink
Merge pull request #51 from hammerlab/pp-surv-multiple-groups
Browse files Browse the repository at this point in the history
support multiple groups in prep-pp-survival-data - #49
  • Loading branch information
jburos committed Jan 17, 2017
2 parents 110265d + 14ebffa commit e976adf
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 10 deletions.
42 changes: 38 additions & 4 deletions survivalstan/survivalstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def fit_stan_survival_model(df=None,
raise AttributeError('Either model_code or file is required.')

if input_data is None:
input_data = SurvivalStanData(df=df, formula=formula, time_col=time_col,
input_data = SurvivalStanData(df=df,
formula=formula,
time_col=time_col,
event_col=event_col,
sample_id_col=sample_id_col,
sample_col=sample_col,
Expand Down Expand Up @@ -520,17 +522,49 @@ def extract_baseline_hazard(results, element='baseline', timepoint_id_col = 'tim
## convert wide survival data to long format
def prep_data_long_surv(df, time_col, event_col, sample_col=None,
event_name=None):
''' convert wide survival data to long format
''' Convert wide survival data df to long format, in preparation for modeling using PEM models.
If a sample_col is given, result will be de-duped so that
Returns a pandas DataFrame with original records duplicated for each unique failure time observed.
Each record will have two new columns: 'end_failure' and 'end_time', indicating
the event status (`end_failure`) for each unique timepoint (`end_time`).
Multiple events -- either per subject or multiple types of events per subject -- are
supported via optional parameters sample_col and/or event_name.
- If a sample_col is given, result will be de-duped so that
multiple events of the same type are handled correctly.
If an event_name column is given or if event_col is a list,
- If an event_name column is given or if event_col is a list,
then multiple events will be processed.
In this case, result will contain event status for each
event given. E.g. as for semi- or competing event data
with multiple event types.
**Parameters**:
:param df: Input data containing survival time & status for each subject
:type df: pandas.DataFrame
:param time_col: name of column containing time to censor/event
:type time_col: str
:param event_col: name of column containing status (1 or True: event, 0 or False: censor)
:type event_col: str
:param sample_col: (optional) column containing sample or subject identifier.
:type sample_col: str
:param event_name: (optional) column containing description of event type, if
more than one type of event is observed
:type event_name: str
**Returns**:
:return: Dataframe with original records duplicated for each unique failure time observed.
Each record will have two new columns: 'end_failure' and 'end_time', indicating
the timepoint-specific event status for each record.
:rtype: pandas.DataFrame
'''
## process multiple event_names, if given:
Expand Down
261 changes: 255 additions & 6 deletions survivalstan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,52 @@ def _summarize_survival(df, time_col, event_col, evaluate_at=None):


def extract_time_betas(models, element='beta_time', value_name='beta', **kwargs):
''' Extract posterior draws for values of time-varying `element` from each model given in the list of `models`.
Returns a pandas.DataFrame containing one record for each posterior draw of each parameter, where
the parameter varies over time.
Columns include:
- model_cohort: description of the model or cohort from which the draw was taken
- <value-column>: the value of the posterior draw, named according to given parameter `value_name`
- coef: description of the coefficient estimated, as per patsy formula provided
- iter: integer indicator of the draw from which that estimate was taken
- <timepoint-id-column>: integer identifier for each unique time at which betas are estimated
(default column name is set by `fit_stan_survival_model`, typically as "timepoint_id")
- <timepoint-end-column>: time at which this beta was estimated
(default column name is set by `fit_stan_survival_model`, typically as "end_time")
** Parameters **:
:param models: list of model-fit objects returned by `survivalstan.fit_stan_survival_model`.
:type models: list
:param element: name of parameter to extract. Defaults to "beta_time", the parameter name
used in the example time-varying stan model.
:type element: str
:param value_name: what you would like the "value" column called in the resulting dataframe
:type value_name: str
:param **kwargs: **kwargs are passed to `_extract_time_betas_single_model`, allowing
user to customize "default" values which would otherwise be read from each model object.
examples include: `coefs`, `timepoint_id_col`, and `timepoint_end_col`.
** Returns **:
:returns: pandas.DataFrame containing posterior draws of parameter values.
'''
data = [_extract_time_betas_single_model(model, element=element, value_name=value_name, **kwargs) for model in models]
return pd.concat(data)

def _extract_time_betas_single_model(stanmodel, element='beta_time', coefs=None,
value_name='beta', timepoint_id_col=None,
timepoint_end_col=None):
''' Helper/utility function used by `extract_time_betas`, for a single model
'''

if not timepoint_id_col:
timepoint_id_col = stanmodel['timepoint_id_col']
if not timepoint_end_col:
Expand Down Expand Up @@ -170,6 +210,69 @@ def plot_time_betas(models=None, df=None, element='beta_time',
subplot=None, ticks_at=None, ylabel=None, xlabel='time',
num_ticks=10, step_size=None, fill=True, alpha=0.5, pal=None,
value_name='beta', **kwargs):
''' Plot posterior draws of time-varying parameters (`element`) from each model given in the list of `models`.
.. seealso:: `extract_time_betas` to return the dataframe used by this function to plot data.
.. note:: this function can optionally take a `df` argument (the result of extract_time_betas) to
support data-extraction & plotting in a two-step operation.
** Parameters controlling data extraction **:
:param models: list of model-fit objects returned by `survivalstan.fit_stan_survival_model`.
:type models: list
:param element: name of parameter to extract. Defaults to "beta_time", the parameter name
used in the example time-varying stan model.
:type element: str
:param value_name: what you would like the "value" column called in the resulting dataframe
:type value_name: str
:param coefs: (optional) parameter passed to `extract_time_betas`, to override coefficient names
captured in `fit_stan_survival_model`.
:param timepoint_id_col: (optional) parameter passed to `extract_time_betas`, to
override timepoint_id_col captured in `fit_stan_survival_model`.
:param timepoint_end_col: (optional) parameter passed to `extract_time_betas` to
override timepoint_end_col captured in `fit_stan_survival_model`.
** Parameters controlling plot orientation/presentation **:
:param trans: (optional) function to transform y-values plotted. Example: np.log
:type trans: function
:param by: (optional) list of columns by which to aggregate & color boxplots
Defaults to: ['model_cohort', 'coef']
:type by: list
:param pal: (optional) palette to use for plotting.
:type pal: list of colors, matching length of `by` groups
:param y: (optional) column to put on the y-axis. Defaults to 'beta'
:type y: str
:param x: (optional) column to put in the x-axis. Defaults to 'timepoint_end_col'
:type x: str
:param num_ticks: (optional) how many ticks to show on the x-axis. See _plot_time_betas for details.
:param alpha: (optional) level of transparency for boxplots
:param fill: (optional) whether to fill in boxplots or just show outlines. Defaults to True
:param subplot: (optional) pyplot.subplots object to use, if provided. Useful if you want to overlay
multiple values on the same plot.
** Returns **:
:returns: Nothing. Plotted object is a side-effect.
'''
if df is None:
df = extract_time_betas(models=models, element=element, coefs=coefs,
value_name=value_name, timepoint_id_col=timepoint_id_col,
Expand Down Expand Up @@ -246,8 +349,45 @@ def _prep_pp_data_single_model(model, time_element='y_hat_time', event_element='
return pp_data


def prep_pp_data(models, time_element='y_hat_time', event_element='y_hat_event', event_col='event_status', time_col='event_time', **kwargs):
data = [_prep_pp_data_single_model(model=model, event_element=event_element, time_element=time_element, event_col=event_col, time_col=time_col, **kwargs)
def prep_pp_data(models, time_element='y_hat_time',
event_element='y_hat_event', event_col='event_status',
time_col='event_time', **kwargs):
''' Extract posterior-predicted values from each model included in the list of `models` given, optionally merged with
covariates & meta-data provided in the input `df`.
**Parameters**:
:param models: list of `fit_stan_survival_model` results from which to extract posterior-predicted values
:type models: list
:param time_element: (optional) name of parameter containing posterior-predicted event **time** for each subject
Defaults to standard used in survivalstan models: `y_hat_time`.
:type time_element: str
:param event_element: (optional) name of parameter containing posterior-predicted event **status** for each subject
Defaults to the standard used in survivalstan models: `y_hat_event`.
:type event_element: str
:param event_col: (optional) name to use for column containing posterior draw for event_status
:type event_col: str
:param time_col: (optional) name to use for column containing posterior draw for time to event
:type time_col: str
:param **kwargs: **kwargs are passed to `_prep_pp_data_single_model`, allowing user to override
or specify default values given in the original call to `fit_stan_survival_model`.
Parameters include: `sample_col`, `sample_id_col` to define names of sample description & id columns
as well as `join_with` giving name of dataframe to join with (options include df_nonmiss, x_df, or None).
Use `join_with` = None to disable merge with original dataframe.
**Returns**:
:returns: pandas.DataFrame with one record per posterior draw (iter) for each subject, from each model
optionally joined with original input data.
'''
data = [_prep_pp_data_single_model(model=model, event_element=event_element,
time_element=time_element, event_col=event_col, time_col=time_col, **kwargs)
for model in models]
data = pd.concat(data)
data.sort_values([time_col, 'iter'], inplace=True)
Expand All @@ -256,12 +396,54 @@ def prep_pp_data(models, time_element='y_hat_time', event_element='y_hat_event',

def prep_pp_survival_data(models, time_element='y_hat_time', event_element='y_hat_event',
time_col='event_time', event_col='event_status',
by=None,
**kwargs):
pp_data = prep_pp_data(models, time_element=time_element, event_element=event_element, time_col=time_col, event_col=event_col, **kwargs)
by=None, **kwargs):
''' Summarize posterior-predicted values into KM survival/censor rates
by group, for each model given in the list of `models`.
See `prep_pp_data` for details regarding process of extracting posterior-predicted values.
**Parameters**:
:param models: list of `fit_stan_survival_model` results from which to extract posterior-predicted values
:type models: list
:param by: additional column or columns by which to summarize posterior-predicted values.
Default is None, which results in draws summarized by [`iter` and `model_cohort`].
Values can include any covariates provided in the original df.
:type by: str or list of strings
:param time_element: (optional) name of parameter containing posterior-predicted event **time** for each subject
Defaults to standard used in survivalstan models: `y_hat_time`.
:type time_element: str
:param event_element: (optional) name of parameter containing posterior-predicted event **status** for each subject
Defaults to the standard used in survivalstan models: `y_hat_event`.
:type event_element: str
:param event_col: (optional) name to use for column containing posterior draw for event_status
:type event_col: str
:param time_col: (optional) name to use for column containing posterior draw for time to event
:type time_col: str
:param **kwargs: **kwargs are passed to `_prep_pp_data_single_model`, allowing user to override
or specify default values given in the original call to `fit_stan_survival_model`.
Parameters include: `sample_col`, `sample_id_col` to define names of sample description & id columns
as well as `join_with` giving name of dataframe to join with (options include df_nonmiss, x_df, or None).
Use `join_with` = None to disable merge with original dataframe.
**Returns**:
:returns: pandas.DataFrame with one record per posterior draw (iter), timepoint, model_cohort, and by-groups.
'''
pp_data = prep_pp_data(models, time_element=time_element,
event_element=event_element, time_col=time_col, event_col=event_col, **kwargs)
groups = ['iter', 'model_cohort']
if by:
if by and isinstance(by, str):
groups.append(by)
elif by and isinstance(by, list):
groups.extend(by)
pp_surv = pp_data.groupby(groups).apply(
lambda df: _summarize_survival(df, time_col=time_col, event_col=event_col))
pp_surv.reset_index(inplace=True)
Expand Down Expand Up @@ -322,6 +504,73 @@ def plot_pp_survival(models, time_element='y_hat_time', event_element='y_hat_eve
num_ticks=10, step_size=None, ticks_at=None, time_col='event_time',
event_col='event_status', fill=True, by=None, alpha=0.5, pal=None,
subplot=None, **kwargs):
''' Plot KM curve estimates from posterior-predicted values by group, for each model given in the list of `models`.
See `prep_pp_survival_data` for details regarding process of extracting posterior-predicted values.
**Parameters controlling data extraction **:
:param models: list of `fit_stan_survival_model` results from which to extract posterior-predicted values
:type models: list
:param by: additional column or columns by which to summarize posterior-predicted values.
Default is None, which results in draws summarized by [`iter` and `model_cohort`].
Values can include any covariates provided in the original df.
:type by: str or list of strings
:param time_element: (optional) name of parameter containing posterior-predicted event **time** for each subject
Defaults to standard used in survivalstan models: `y_hat_time`.
:type time_element: str
:param event_element: (optional) name of parameter containing posterior-predicted event **status** for each subject
Defaults to the standard used in survivalstan models: `y_hat_event`.
:type event_element: str
:param event_col: (optional) name to use for column containing posterior draw for event_status
:type event_col: str
:param time_col: (optional) name to use for column containing posterior draw for time to event
:type time_col: str
:param **kwargs: **kwargs are passed to `_prep_pp_data_single_model`, allowing user to override
or specify default values given in the original call to `fit_stan_survival_model`.
Parameters include: `sample_col`, `sample_id_col` to define names of sample description & id columns
as well as `join_with` giving name of dataframe to join with (options include df_nonmiss, x_df, or None).
Use `join_with` = None to disable merge with original dataframe.
** Parameters controlling plot orientation/presentation **:
:param pal: (optional) palette to use for plotting.
:type pal: list of colors, matching length of `by` groups
:param ticks_at: (optional) exact locations for placement of ticks
:param num_ticks: (optional) control number of ticks, if ticks_at not given.
:param step_size: (optional) control tick spacing, if ticks_at or num_ticks not given
:param alpha: (optional) level of transparency for boxplots
:param fill: (optional) whether to fill in boxplots or just show outlines. Defaults to True
:param subplot: (optional) pyplot.subplots object to use, if provided. Useful if you want to overlay
observed or true survival on the same plot.
:param xlabel: (optional) label for x-axis (defaults to "Days")
:param ylabel: (optional) label for y-axis (defaults to "Survival %")
:param label: (optional) legend-label for this plot group
(defaults to "posterior predictions", model-cohort, or by-group label depending options)
:param **kwargs: (optional) args passed to set properties of boxes, medians & whiskers (e.g. color)
** Returns **:
:returns: Nothing. Plotted object is a side-effect.
'''
pp_surv = prep_pp_survival_data(models, time_element=time_element,
event_element=event_element, time_col=time_col,
event_col=event_col, by=by)
Expand Down

0 comments on commit e976adf

Please sign in to comment.