Skip to content

Commit

Permalink
remove deprecation warning related to pm.sample returning idata (#295)
Browse files Browse the repository at this point in the history
* remove deprecation warning related to pm.sample returning idata

* pylint
  • Loading branch information
aloctavodia committed Jan 26, 2021
1 parent 74c51bf commit 03ef2a8
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions bambi/backends/pymc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

import numpy as np
from arviz import from_pymc3
import theano
import pymc3 as pm

Expand Down Expand Up @@ -123,7 +122,7 @@ def run(
The method to use for fitting the model. By default, 'mcmc', in which case the
PyMC3 sampler will be used. Alternatively, 'advi', in which case the model will be
fitted using automatic differentiation variational inference as implemented in PyMC3.
Finally, 'laplace', in wich case a laplace approximation is used, 'laplace' is not
Finally, 'laplace', in which case a laplace approximation is used, 'laplace' is not
recommended other than for pedagogical use.
init: str
Initialization method (see PyMC3 sampler documentation). Currently, this is
Expand All @@ -145,9 +144,14 @@ def run(
if method.lower() == "mcmc":
draws = kwargs.pop("draws", 1000)
with model:
self.trace = pm.sample(draws, start=start, init=init, n_init=n_init, **kwargs)

idata = from_pymc3(self.trace, model=model)
idata = pm.sample(
draws,
start=start,
init=init,
n_init=n_init,
return_inferencedata=True,
**kwargs,
)

if omit_offsets:
offset_dims = [vn for vn in idata.posterior.dims if "offset" in vn]
Expand Down

0 comments on commit 03ef2a8

Please sign in to comment.