Skip to content

Commit

Permalink
Sidestep plotting flat priors (#258)
Browse files Browse the repository at this point in the history
* sidestep ploting flat priors

* update changelog
  • Loading branch information
aloctavodia authored Oct 30, 2020
1 parent c21da6d commit e250295
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* Remove future warning when converting the trace to InferenceData (#213)
* Include missing files for sdist (#204)
* Fixed if-else comparison that prevented HalfTStudent prior from being used (#205)
* Sidestep plotting flat priors in `plot_priors()` (#258)

### Documentation
* Update example notebooks (#232)
Expand Down
45 changes: 30 additions & 15 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,27 +786,42 @@ def plot_priors(
if not self.built:
raise ValueError("Cannot plot priors until model is built!")

unobserved_rvs_names = []
flat_rvs = []
for unobserved in self.backend.model.unobserved_RVs:
if "Flat" in unobserved.__str__():
flat_rvs.append(unobserved.name)
else:
unobserved_rvs_names.append(unobserved.name)
if var_names is None:
unobserved_rvs_names = [v.name for v in self.backend.model.unobserved_RVs]
var_names = pm.util.get_default_varnames(
unobserved_rvs_names, include_transformed=False
)
else:
flat_rvs = [fv for fv in flat_rvs if fv in var_names]
var_names = [vn for vn in var_names if vn not in flat_rvs]

pps = self.prior_predictive(draws=draws, var_names=var_names, random_seed=random_seed)

axes = plot_posterior(
pps,
group="prior",
figsize=figsize,
textsize=textsize,
hdi_prob=hdi_prob,
round_to=round_to,
point_estimate=point_estimate,
kind=kind,
bins=bins,
ax=ax,
)
if flat_rvs:
warnings.warn(
f"Variables {', '.join(flat_rvs)} have flat priors, and hence they are not plotted",
)

axes = None
if var_names:
pps = self.prior_predictive(draws=draws, var_names=var_names, random_seed=random_seed)

axes = plot_posterior(
pps,
group="prior",
figsize=figsize,
textsize=textsize,
hdi_prob=hdi_prob,
round_to=round_to,
point_estimate=point_estimate,
kind=kind,
bins=bins,
ax=ax,
)
return axes

def prior_predictive(self, draws=500, var_names=None, random_seed=None):
Expand Down

0 comments on commit e250295

Please sign in to comment.