Skip to content

Commit

Permalink
refactored ModelResults into separate classes for different BackEnds
Browse files Browse the repository at this point in the history
  • Loading branch information
tyarkoni committed Aug 27, 2016
1 parent 1cd4d45 commit 100e3ad
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
6 changes: 3 additions & 3 deletions bambi/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from bambi.external.six import string_types
import numpy as np
import warnings
from bambi.results import ModelResults
from bambi.results import PyMC3ModelResults
from bambi.priors import Prior
import theano
try:
Expand Down Expand Up @@ -122,11 +122,11 @@ def run(self, start=None, find_map=False, **kwargs):
find_map (bool): whether or not to use the maximum a posteriori
estimate as a starting point; passed directly to PyMC3.
kwargs (dict): Optional keyword arguments passed onto the sampler.
Returns: A PyMC3ModelResults instance.
'''
samples = kwargs.pop('samples', 1000)
with self.model:
if start is None and find_map:
start = pm.find_MAP()
self.trace = pm.sample(samples, start=start, **kwargs)
return ModelResults(self.spec, self.trace)
return PyMC3ModelResults(self.spec, self.trace)
48 changes: 25 additions & 23 deletions bambi/results.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,38 @@
import pandas as pd
import pymc3 as pm
from abc import abstractmethod, ABCMeta


class ModelResults(object):

def __init__(self, model, trace):
__metaclass__ = ABCMeta

def __init__(self, model):

self.model = model
self.terms = list(model.terms.values())
self.trace = trace
self.diagnostics = model._diagnostics
self.n_terms = len(model.terms)

@abstractmethod
def plot(self):
pass

@abstractmethod
def summary(self):
pass


class PyMC3ModelResults(ModelResults):

def __init__(self, model, trace):

self.trace = trace
self.n_samples = len(trace)
self._fixed_terms = [t.name for t in self.terms if t.type_=='fixed']
self._random_terms = [t.name for t in self.terms if t.type_=='random']

def _select_samples(self, fixed, random, names, burn_in):
trace = self.trace[burn_in:]
if names is not None:
names = []
if fixed:
names.extend(self._fixed_terms)
if random:
names.extend(self._random_terms)
return trace, names

def plot_trace(self, burn_in=0, fixed=True, random=True, names=None,
**kwargs):
trace, names = self._select_samples(fixed, random, names, burn_in)
return pm.traceplot(trace, varnames=names, **kwargs)

def summary(self, burn_in=0, fixed=True, random=True, names=None, **kwargs):
trace, names = self._select_samples(fixed, random, names, burn_in)
return pm.summary(trace, varnames=names, **kwargs)

def plot(self, burn_in=0, names=None, **kwargs):
return pm.traceplot(trace[burn_in:], varnames=names, **kwargs)

def summary(self, burn_in=0, fixed=True, random=True, names=None,
**kwargs):
return pm.summary(trace[burn_in:], varnames=names, **kwargs)

0 comments on commit 100e3ad

Please sign in to comment.