Skip to content

Commit

Permalink
adding new util - extract_params_long
Browse files Browse the repository at this point in the history
  • Loading branch information
jburos committed Jul 24, 2016
1 parent 99f2181 commit 171c1eb
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions survivalstan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,58 @@ def plot_coefs(models, element='coefs', force_direction=None):
if hue=='model_cohort':
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)


def extract_params_long(models, element, rename_vars=None, varnames=None):
"""
Helper function to extract & reformat params
Parameters
----------
models (list):
List of model objects
element (string, optional):
Which element to plot. defaults to 'coefs'.
Other options (depending on model type) include:
- 'grp_coefs'
- 'baseline_hazard'
rename_vars (dict, optional):
- dictionary mapping from integer positions (0, 1, 2) to variable names
varnames (list of strings, optional):
- list of variable names to apply to columns from the extracted object
Returns
-------
Pandas dataframe containing posterior draws per iteration
"""
df_list = list()
for model in models:
df_list.append(_extract_params_from_single_model(
model,
element = element,
rename_vars=rename_vars,
varnames=varnames
))
df_list = pd.concat(df_list)
return(df_list)


def _extract_params_from_single_model(model, element, rename_vars=None, varnames=None):
if not varnames:
df = pd.DataFrame(
model['fit'].extract()[element]
)
else:
df = pd.DataFrame(
model['fit'].extract()[element]
, columns=varnames
)
if rename_vars:
df.rename(columns = rename_vars, inplace=True)
df.reset_index(0, inplace = True)
df = df.rename(columns = {'index':'iter'})
df = pd.melt(df, id_vars = ['iter'])
df['model_cohort'] = model['model_cohort']
return(df)

0 comments on commit 171c1eb

Please sign in to comment.