Skip to content

Commit

Permalink
Merge 4db3a21 into 8b99997
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Dec 9, 2020
2 parents 8b99997 + 4db3a21 commit f2f8bb5
Show file tree
Hide file tree
Showing 7 changed files with 455 additions and 152 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Categorical `Term` within `Model` now have `Term.categorical` equal to `True`(#269)
* Use logging instead of warnings (#270)
* Omits ploting group-level effects and offset variables (#276)
* Logistic regression works with no explicit index (#277)

### Documentation
* Update example notebooks (#232)
Expand Down
18 changes: 5 additions & 13 deletions bambi/backends/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,16 @@ class PyMC3BackEnd(BackEnd):
dists = {"HalfFlat": pm.Bound(pm.Flat, lower=0)}

def __init__(self):
self.reset()

self.name = pm.__name__
self.version = pm.__version__

# Attributes defined elsewhere
self.model = None
self.mu = None # build()
self.spec = None # build()
self.trace = None # build()
self.advi_params = None # build()

def reset(self):
"""Reset PyMC3 model and all tracked distributions and parameters."""
self.model = pm.Model()
self.mu = None
self.par_groups = {}

def _build_dist(self, spec, label, dist, **kwargs):
"""Build and return a PyMC3 Distribution."""
if isinstance(dist, str):
Expand Down Expand Up @@ -77,18 +70,17 @@ def _expand_args(key, value, label):

return dist(label, **kwargs)

def build(self, spec, reset=True): # pylint: disable=arguments-differ
def build(self, spec): # pylint: disable=arguments-differ
"""Compile the PyMC3 model from an abstract model specification.
Parameters
----------
spec : Bambi model
A Bambi Model instance containing the abstract specification of the model to compile.
reset : Bool
If True (default), resets the PyMC3BackEnd instance before compiling.
"""
if reset:
self.reset()

coords = spec._get_pymc_coords() # pylint: disable=protected-access
self.model = pm.Model(coords=coords)

with self.model:
self.mu = 0.0
Expand Down

0 comments on commit f2f8bb5

Please sign in to comment.