From 062ba214d7121f01789f8c886d98c401f6d4687e Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 16 Oct 2020 12:02:01 -0300 Subject: [PATCH] add ploting arguments --- bambi/models.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/bambi/models.py b/bambi/models.py index 580d75077..2db5ae57e 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -1,4 +1,5 @@ # pylint: disable=no-name-in-module +# pylint: disable=too-many-lines import re import warnings from collections import OrderedDict @@ -726,7 +727,62 @@ def plot(self, draws=5000, var_names=None): warnings.warn("plot will be deprecated, please use plot_priors", FutureWarning) return self.plot_priors(draws, var_names) - def plot_priors(self, draws=5000, var_names=None): + def plot_priors( + self, + draws=5000, + var_names=None, + random_seed=None, + figsize=None, + textsize=None, + hdi_prob=None, + round_to=2, + point_estimate="mean", + kind="kde", + bins=None, + ax=None, + ): + """ + Samples from the prior distribution and plot its marginals. + + Parameters + ---------- + draws : int + Number of draws to sample from the prior predictive distribution. Defaults to 5000. + var_names : str or list + A list of names of variables for which to compute the posterior predictive + distribution. Defaults to both observed and unobserved RVs. + random_seed : int + Seed for the random number generator. + figsize: tuple + Figure size. If None it will be defined automatically. + textsize: float + Text size scaling factor for labels, titles and lines. If None it will be autoscaled + based on figsize. + hdi_prob: float, optional + Plots highest density interval for chosen percentage of density. + Use 'hide' to hide the highest density interval. Defaults to 0.94. + round_to: int, optional + Controls formatting of floats. Defaults to 2 or the integer part, whichever is bigger. + point_estimate: Optional[str] + Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None. + Defaults to 'auto' i.e. it falls back to default set in rcParams. + kind: str + Type of plot to display (kde or hist) For discrete variables this argument is ignored + and a histogram is always used. + bins: integer or sequence or 'auto', optional + Controls the number of bins, accepts the same keywords `matplotlib.hist()` does. + Only works if `kind == hist`. If None (default) it will use `auto` for continuous + variables and `range(xmin, xmax + 1)` for discrete variables. + ax: numpy array-like of matplotlib axes or bokeh figures, optional + A 2D array of locations into which to plot the densities. If not supplied, ArviZ will + create its own array of plot areas (and return it). + **kwargs + Passed as-is to plt.hist() or plt.plot() function depending on the value of `kind`. + + Returns + ------- + axes: matplotlib axes or bokeh figures + """ if not self.built: raise ValueError("Cannot plot priors until model is built!") @@ -736,9 +792,20 @@ def plot_priors(self, draws=5000, var_names=None): unobserved_rvs_names, include_transformed=False ) - pps = self.prior_predictive(draws=draws, var_names=var_names) - - axes = plot_posterior(pps, group="prior", credible_interval=None, point_estimate=None) + pps = self.prior_predictive(draws=draws, var_names=var_names, random_seed=random_seed) + + axes = plot_posterior( + pps, + group="prior", + figsize=figsize, + textsize=textsize, + hdi_prob=hdi_prob, + round_to=round_to, + point_estimate=point_estimate, + kind=kind, + bins=bins, + ax=ax, + ) return axes