ModelResults.summary output is mo concise, mo informative, mo betta
It is now based on pm.df_summary() rather than pm.summary(). Terms with
multiple levels have the actual level names printed in the output.
Internally transformed variables are suppressed by default. Random
effects are not summarized by default (but their SDs are). Together with
e02a276, this effectively closes #36 and definitely closes #41.

Also adds tests both for this and for the traceplot and priors plots.
The tests live in test_model because we have to fit a model before we
can test the fitted-model summary functions.
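
A rough usage sketch of the revised interface (mirroring the test added below; the DataFrame 'crossed_data', the formula, and the sample count are stand-ins, while the summary()/plot() arguments come from this commit):

    from bambi import Model

    # 'crossed_data' stands in for any pandas DataFrame with Y, threecats, subj.
    model = Model(crossed_data)
    fitted = model.fit('Y ~ 0 + threecats', random=['subj'], samples=1000)

    # Concise summary: individual random effects collapsed to their SDs and
    # internally transformed variables hidden (both are the new defaults).
    fitted.summary()

    # Verbose summary, keeping every random effect and transformed variable.
    fitted.summary(exclude_ranefs=False, hide_transformed=False)

    # Annotated trace plots, or the prior plots via Model.plot().
    fitted.plot()
    fitted.plot(kind='priors')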
jake-westfall committed Sep 9, 2016
1 parent 181634f commit f7f0cd1
Showing 2 changed files with 58 additions and 7 deletions.
55 changes: 49 additions & 6 deletions bambi/results.py
@@ -63,8 +63,21 @@ def __init__(self, model, trace):
def plot(self, burn_in=0, names=None, annotate=True, hide_transformed=True,
kind='trace', **kwargs):
'''
Plots posterior distributions and sample traces. Code slightly modified from:
Plots posterior distributions and sample traces. Basically a wrapper
for pm.traceplot() plus some niceties, based partly on code from:
https://pymc-devs.github.io/pymc3/notebooks/GLM-model-selection.html
Args:
burn_in (int): Number of initial samples to exclude before
summary statistics are computed.
names (list): Optional list of variable names to summarize.
annotate (bool): If True (default), add lines marking the
posterior means, write the posterior means next to the
lines, and add factor level names for fixed factors with
more than one distribution on the traceplot.
hide_transformed (bool): If True (default), do not print
summary statistics for internally transformed variables.
kind (str): Either 'trace' (default) or 'priors'. If 'priors',
this just internally calls Model.plot()
'''
if kind == 'priors':
return self.model.plot()
@@ -117,16 +130,46 @@ def plot(self, burn_in=0, names=None, annotate=True, hide_transformed=True,

return ax

def summary(self, burn_in=0, fixed=True, random=True, names=None,
def summary(self, burn_in=0, exclude_ranefs=True, names=None,
hide_transformed=True, **kwargs):
'''
Summarizes all parameter estimates. Currently just a wrapper for
pm.summary().
Summarizes all parameter estimates. Basically a wrapper for
pm.df_summary() plus some niceties.
Args:
burn_in (int): Number of initial samples to exclude before
summary statistics are computed.
exclude_ranefs (bool): If True (default), do not print
summary statistics for individual random effects.
names (list): Optional list of variable names to summarize.
hide_transformed (bool): If True (default), do not print
summary statistics for internally transformed variables.
'''

# if no 'names' specified, filter out unwanted variables
if names is None:
names = self.untransformed_vars if hide_transformed else self.trace.varnames

return pm.summary(self.trace[burn_in:], varnames=names, **kwargs)
if exclude_ranefs:
names = [x for x in names
if x[2:] not in list(self.model.random_terms.keys())]

# get the basic DataFrame
df = pm.df_summary(self.trace[burn_in:], varnames=names, **kwargs)

# replace the "__\d" suffixes with an informative factor level name
match = [re.match('^(.*)(?:__)(\d+)?$', x) for x in df.index]
def replace_with_name(match):
term = self.model.terms[match.group(1)[2:]]
# handle fixed effects
if term in self.model.fixed_terms.values():
return term.levels[int(match.group(2))]
# handle random effects
else:
return '{}[{}]'.format(term.name, term.levels[int(match.group(2))])
new = [replace_with_name(x) if x is not None else df.index[i]
for i,x in enumerate(match)]
df.set_index([new], inplace=True)

return df


class PyMC3ADVIResults(ModelResults):
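For reference, a minimal standalone re-creation of the suffix-renaming idea from the hunk above (the term name, level names, and the 'b_' prefix are hypothetical; the committed code only assumes a two-character prefix before the term name and an '__<index>' suffix):

    import re

    # Hypothetical term -> level-name mapping, standing in for model.terms.
    levels = {'threecats': ['a', 'b', 'c']}

    def pretty_name(varname):
        # PyMC3 appends '__<index>' to multi-level terms, e.g. 'b_threecats__0'.
        m = re.match(r'^(.*)__(\d+)$', varname)
        if m is None:
            return varname                      # nothing to rename
        term = m.group(1)[2:]                   # drop the two-character prefix
        return levels[term][int(m.group(2))]    # integer index -> level name

    print(pretty_name('b_threecats__2'))   # -> 'c'
    print(pretty_name('b_Intercept'))      # -> 'b_Intercept' (unchanged)
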
10 changes: 9 additions & 1 deletion bambi/tests/test_model.py
@@ -351,7 +351,7 @@ def test_cell_means_with_random_intercepts(crossed_data):
model0 = Model(crossed_data)
model0.fit('Y ~ 0 + threecats', random=['subj'], run=False)
model0.build()
model0.fit(samples=1)
fitted = model0.fit(samples=10)

# using add_term
model1 = Model(crossed_data, intercept=False)
@@ -380,6 +380,14 @@ def test_cell_means_with_random_intercepts(crossed_data):
priors1 = {x.name:x.prior.args['sd'].args for x in model1.terms.values() if x.random}
assert set(priors0) == set(priors1)

# test summary
fitted.summary()
fitted.summary(exclude_ranefs=False)

# test plots
fitted.plot(kind='priors')
fitted.plot()


def test_random_intercepts(crossed_data):
# using formula and '1|' syntax
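
To exercise the new assertions locally, one option (among several equivalent ones) is to drive py.test from Python; the test name is taken from the diff above:

    import pytest

    # Run only the test that now covers summary() and the plot methods.
    pytest.main(['bambi/tests/test_model.py',
                 '-k', 'test_cell_means_with_random_intercepts'])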
