Skip to content

Commit

Permalink
Merge 3da869b into 0eef63a
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern authored Mar 31, 2020
2 parents 0eef63a + 3da869b commit 97381cf
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
8 changes: 4 additions & 4 deletions convoys/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

_models = {
'kaplan-meier': lambda ci: convoys.multi.KaplanMeier(),
'exponential': lambda ci: convoys.multi.Exponential(ci=ci),
'weibull': lambda ci: convoys.multi.Weibull(ci=ci),
'gamma': lambda ci: convoys.multi.Gamma(ci=ci),
'generalized-gamma': lambda ci: convoys.multi.GeneralizedGamma(ci=ci),
'exponential': lambda ci: convoys.multi.Exponential(mcmc=ci),
'weibull': lambda ci: convoys.multi.Weibull(mcmc=ci),
'gamma': lambda ci: convoys.multi.Gamma(mcmc=ci),
'generalized-gamma': lambda ci: convoys.multi.GeneralizedGamma(mcmc=ci),
}


Expand Down
23 changes: 14 additions & 9 deletions convoys/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class RegressionModel(object):
class GeneralizedGamma(RegressionModel):
''' Generalization of Gamma, Weibull, and Exponential
:param ci: boolean, defaults to False. Whether to use MCMC to
:param mcmc: boolean, defaults to False. Whether to use MCMC to
sample from the posterior so that a confidence interval can be
estimated later (see :meth:`predict`).
:param hierarchical: boolean denoting whether we have a (Normal) prior
Expand All @@ -86,6 +86,7 @@ class GeneralizedGamma(RegressionModel):
linear model is fit, where the beta params will be completely
additive. This creates a much more interpretable model, with some
minor loss of accuracy.
:param ci: boolean, deprecated alias for `mcmc`.
This mostly follows the `Wikipedia article
<https://en.wikipedia.org/wiki/Generalized_gamma_distribution>`_, although
Expand Down Expand Up @@ -161,17 +162,21 @@ class GeneralizedGamma(RegressionModel):
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html#scipy.optimize.minimize>`_
with the SLSQP method.
If `ci == True`, then `emcee <http://dfm.io/emcee/current/>`_ is used
If `mcmc == True`, then `emcee <http://dfm.io/emcee/current/>`_ is used
to sample from the full posterior in order to generate uncertainty
estimates for all parameters.
'''
def __init__(self, ci=False, fix_k=None, fix_p=None, hierarchical=True,
flavor='logistic'):
self._ci = ci
def __init__(self, mcmc=False, fix_k=None, fix_p=None, hierarchical=True,
flavor='logistic', ci=None):
self._mcmc = mcmc
self._fix_k = fix_k
self._fix_p = fix_p
self._hierarchical = hierarchical
self._flavor = flavor
if ci is not None:
warnings.warn('The `ci` argument is deprecated in 0.2.1 in favor of'
' `mcmc`.', DeprecationWarning)
self._mcmc = ci

def fit(self, X, B, T, W=None):
'''Fits the model.
Expand Down Expand Up @@ -241,7 +246,7 @@ def callback(LL, value_history=[]):
'Norm of gradient is %f' % gradient_norm)

# Let's sample from the posterior to compute uncertainties
if self._ci:
if self._mcmc:
dim, = res.x.shape
n_walkers = 5*dim
sampler = emcee.EnsembleSampler(
Expand Down Expand Up @@ -294,10 +299,10 @@ def _predict(self, params, x, t):
def predict_posteriori(self, x, t):
''' Returns the trace samples generated via the MCMC steps.
Requires the model to be fit with `ci = True`.'''
Requires the model to be fit with `mcmc == True`.'''
x = numpy.array(x)
t = numpy.array(t)
assert self._ci
assert self._mcmc
params = self.params['samples']
t = numpy.expand_dims(t, -1)
return self._predict(params, x, t)
Expand Down Expand Up @@ -334,7 +339,7 @@ def rvs(self, x, n_curves=1, n_samples=1, T=None):
T is optional and means we already observed non-conversion until T
'''
assert self._ci # Need to be fit with MCMC
assert self._mcmc # Need to be fit with MCMC
if T is None:
T = numpy.zeros((n_curves, n_samples))
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'''

setup(name='convoys',
version='0.2.0',
version='0.2.1',
description='Fit machine learning models to predict conversion using Weibull and Gamma distributions',
long_description=long_description,
url='https://better.engineering/convoys',
Expand Down

0 comments on commit 97381cf

Please sign in to comment.