Skip to content

Commit

Permalink
complete docstring coverage; closes #25
Browse files Browse the repository at this point in the history
  • Loading branch information
tyarkoni committed Aug 27, 2016
1 parent 1ee72e3 commit a6cb4d7
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 5 deletions.
7 changes: 6 additions & 1 deletion bambi/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@


class BackEnd(object):

'''
Base class for BackEnd hierarchy.
'''
__metaclass__ = ABCMeta

@abstractmethod
Expand All @@ -26,6 +28,9 @@ def run(self):


class PyMC3BackEnd(BackEnd):
'''
PyMC3 model-fitting back-end.
'''

# Available link functions
links = {
Expand Down
72 changes: 69 additions & 3 deletions bambi/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@


class Family(object):

'''
A specification of model family.
Args:
name (str): Family name
prior (Prior): A Prior instance specifying the model likelihood prior
link (str): The name of the link function transforming the linear
model prediction to a parameter of the likelihood
parent (str): The name of the prior parameter to set to the link-
transformed predicted outcome (e.g., mu, p, etc.).
'''
def __init__(self, name, prior, link, parent):
self.name = name
self.prior = prior
Expand All @@ -16,18 +25,63 @@ def __init__(self, name, prior, link, parent):


class Prior(object):

'''
Abstract specification of a term prior.
Args:
name (str): Name of prior distribution (e.g., Normal, Binomial, etc.)
kwargs (dict): Optional keywords specifying the parameters of the
named distribution.
'''
def __init__(self, name, **kwargs):
self.name = name
self.args = {}
self.update(**kwargs)

def update(self, **kwargs):
'''
Update the model arguments with additional arguments.
Args:
kwargs (dict): Optional keyword arguments to add to prior args.
'''
self.args.update(kwargs)


class PriorFactory(object):

'''
An object that supports specification and easy retrieval of default priors.
Args:
defaults (str, dict): Optional base configuration containing default
priors for distribution, families, and term types. If a string,
the name of a JSON file containing the config. If a dict, must
contain keys for 'dists', 'terms', and 'families'; see the built-in
JSON configuration for an example. If None, a built-in set of
priors will be used as defaults.
dists (dict): Optional specification of named distributions to use
as priors. Each key gives the name of a newly defined distribution;
values are two-element lists, where the first element is the name
of the built-in distribution to use ('Normal', 'Cauchy', etc.),
and the second element is a dictionary of parameters on that
distribution (e.g., {'mu': 0, 'sd': 10}). Priors can be nested
to arbitrary depths by replacing any parameter with another prior
specification.
terms (dict): Optional specification of default priors for different
model term types. Valid keys are 'intercept', 'fixed', or 'random'.
Values are either strings preprended by a #, in which case they
are interpreted as pointers to distributions named in the dists
dictionary, or key -> value specifications in the same format as
elements in the dists dictionary.
families (dict): Optional specification of default priors for named
family objects. Keys are family names, and values are dicts
containing mandatory keys for 'dist', 'link', and 'parent'.
Examples:
>>> dists = { 'my_dist': ['Normal', {'mu': 10, 'sd': 1000}]}
>>> pf = PriorFactory(dists=dists)
>>> families = { 'normalish': { 'dist': ['normal', {sd: '#my_dist'}],
>>> link:'identity', parent: 'mu'}}
>>> pf = PriorFactory(dists=dists, families=families)
'''
def __init__(self, defaults=None, dists=None, terms=None, families=None):

if defaults is None:
Expand Down Expand Up @@ -71,6 +125,18 @@ def _get_prior(self, spec):
return spec

def get(self, dist=None, term=None, family=None, **kwargs):
'''
Retrieve default prior for a named distribution, term type, or family.
Args:
dist (str): Name of desired distribution. Note that the name is
the key in the defaults dictionary, not the name of the
Distribution object used to construct the prior.
term (str): The type of term family to retrieve defaults for.
Must be one of 'intercept', 'fixed', or 'random'.
family (str): The name of the Family to retrieve. Must be a value
defined internally. In the default config, this is one of
'gaussian', 'binomial', 'poisson', or 't'.
'''
if dist is not None:
if dist not in self.dists:
raise ValueError(
Expand Down
20 changes: 19 additions & 1 deletion bambi/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@


class ModelResults(object):
'''
Base class for ModelResults hierarchy.
Args:
model (Model): a bambi Model instance specifying the model.
'''

__metaclass__ = ABCMeta

Expand All @@ -24,15 +29,28 @@ def summary(self):


class PyMC3ModelResults(ModelResults):

'''
Holds PyMC3 sampler results and provides plotting and summarization tools.
Args:
model (Model): a bambi Model instance specifying the model.
trace (MultiTrace): a PyMC3 MultiTrace object returned by the sampler.
'''
def __init__(self, model, trace):

self.trace = trace
self.n_samples = len(trace)

def plot(self, burn_in=0, names=None, **kwargs):
'''
Plots posterior distributions and sample traces. Currently just a
wrapper for pm.traceplot().
'''
return pm.traceplot(trace[burn_in:], varnames=names, **kwargs)

def summary(self, burn_in=0, fixed=True, random=True, names=None,
**kwargs):
'''
Summarizes all parameter estimates. Currently just a wrapper for
pm.summary().
'''
return pm.summary(trace[burn_in:], varnames=names, **kwargs)

0 comments on commit a6cb4d7

Please sign in to comment.