Skip to content

Commit

Permalink
add ploting arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 16, 2020
1 parent 3b53b1a commit 062ba21
Showing 1 changed file with 71 additions and 4 deletions.
75 changes: 71 additions & 4 deletions bambi/models.py
@@ -1,4 +1,5 @@
# pylint: disable=no-name-in-module
# pylint: disable=too-many-lines
import re
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -726,7 +727,62 @@ def plot(self, draws=5000, var_names=None):
warnings.warn("plot will be deprecated, please use plot_priors", FutureWarning)
return self.plot_priors(draws, var_names)

def plot_priors(self, draws=5000, var_names=None):
def plot_priors(
self,
draws=5000,
var_names=None,
random_seed=None,
figsize=None,
textsize=None,
hdi_prob=None,
round_to=2,
point_estimate="mean",
kind="kde",
bins=None,
ax=None,
):
"""
Samples from the prior distribution and plot its marginals.
Parameters
----------
draws : int
Number of draws to sample from the prior predictive distribution. Defaults to 5000.
var_names : str or list
A list of names of variables for which to compute the posterior predictive
distribution. Defaults to both observed and unobserved RVs.
random_seed : int
Seed for the random number generator.
figsize: tuple
Figure size. If None it will be defined automatically.
textsize: float
Text size scaling factor for labels, titles and lines. If None it will be autoscaled
based on figsize.
hdi_prob: float, optional
Plots highest density interval for chosen percentage of density.
Use 'hide' to hide the highest density interval. Defaults to 0.94.
round_to: int, optional
Controls formatting of floats. Defaults to 2 or the integer part, whichever is bigger.
point_estimate: Optional[str]
Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
Defaults to 'auto' i.e. it falls back to default set in rcParams.
kind: str
Type of plot to display (kde or hist) For discrete variables this argument is ignored
and a histogram is always used.
bins: integer or sequence or 'auto', optional
Controls the number of bins, accepts the same keywords `matplotlib.hist()` does.
Only works if `kind == hist`. If None (default) it will use `auto` for continuous
variables and `range(xmin, xmax + 1)` for discrete variables.
ax: numpy array-like of matplotlib axes or bokeh figures, optional
A 2D array of locations into which to plot the densities. If not supplied, ArviZ will
create its own array of plot areas (and return it).
**kwargs
Passed as-is to plt.hist() or plt.plot() function depending on the value of `kind`.
Returns
-------
axes: matplotlib axes or bokeh figures
"""
if not self.built:
raise ValueError("Cannot plot priors until model is built!")

Expand All @@ -736,9 +792,20 @@ def plot_priors(self, draws=5000, var_names=None):
unobserved_rvs_names, include_transformed=False
)

pps = self.prior_predictive(draws=draws, var_names=var_names)

axes = plot_posterior(pps, group="prior", credible_interval=None, point_estimate=None)
pps = self.prior_predictive(draws=draws, var_names=var_names, random_seed=random_seed)

axes = plot_posterior(
pps,
group="prior",
figsize=figsize,
textsize=textsize,
hdi_prob=hdi_prob,
round_to=round_to,
point_estimate=point_estimate,
kind=kind,
bins=bins,
ax=ax,
)

return axes

Expand Down

0 comments on commit 062ba21

Please sign in to comment.