Split the predict method into predict and predict_ci
Erik Bernhardsson committed Mar 30, 2020
1 parent a7a33af commit 5e80124
Showing 5 changed files with 115 additions and 86 deletions.
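
In short, the ci keyword argument is dropped from predict and the interval logic moves into a dedicated predict_ci method, while the deprecated cdf keeps the old keyword-based behaviour by dispatching to one or the other. A minimal sketch of the new call pattern against convoys.regression.Exponential, with made-up toy data in the spirit of test_convoys.py (the data generation and variable names are illustrative, not part of the commit):

    import numpy
    import scipy.stats
    import convoys.regression

    # Toy data: one constant feature, ~30% of units convert with exponential
    # delays, everyone else is censored at a uniformly drawn observation time.
    n, c, lambd = 1000, 0.3, 0.1
    X = numpy.ones((n, 1))
    converts = scipy.stats.bernoulli.rvs(c, size=n).astype(bool)
    delay = scipy.stats.expon.rvs(scale=1. / lambd, size=n)
    censor = scipy.stats.uniform.rvs(scale=5. / lambd, size=n)
    B = converts & (delay < censor)       # conversion observed before censoring
    T = numpy.where(B, delay, censor)     # conversion time, else censoring time

    model = convoys.regression.Exponential(ci=True)  # ci=True enables the MCMC sampling predict_ci needs
    model.fit(X, B, T)

    # Old interface: model.predict(x, t, ci=0.8) returned the (mean, lo, hi) triple.
    # New interface: two methods with fixed return shapes.
    y = model.predict([1], 10.0)                          # max a posteriori point estimate
    y, y_lo, y_hi = model.predict_ci([1], 10.0, ci=0.8)   # posterior mean and 80% bounds

The deprecated cdf methods still accept the old ci keyword and dispatch to predict or predict_ci, which is what the updated tests at the bottom verify.
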
34 changes: 24 additions & 10 deletions convoys/multi.py
@@ -35,16 +35,23 @@ def _get_x(self, group):
x[group] = 1
return x

def predict(self, group, *args, **kwargs):
return self.base_model.predict(self._get_x(group), *args, **kwargs)
def predict(self, group, t):
return self.base_model.predict(self._get_x(group), t)

def predict_ci(self, group, t, ci):
return self.base_model.predict_ci(self._get_x(group), t, ci)

def rvs(self, group, *args, **kwargs):
return self.base_model.rvs(self._get_x(group), *args, **kwargs)

@deprecated(version='0.1.8', reason='Has been renamed to :meth:`predict`')
def cdf(self, *args, **kwargs):
@deprecated(version='0.1.8',
reason='Use :meth:`predict` or :meth:`predict_ci` instead.')
def cdf(self, group, t, ci=None):
'''Returns the predicted values.'''
return self.predict(*args, **kwargs)
if ci is not None:
return self.predict_ci(group, t, ci)
else:
return self.predict(group, t)


class SingleToMulti(MultiModel):
@@ -66,13 +73,20 @@ def fit(self, G, B, T):
self._group2model[g] = self.base_model_init()
self._group2model[g].fit([b for b, t in BT], [t for b, t in BT])

def predict(self, group, t, *args, **kwargs):
return self._group2model[group].predict(t, *args, **kwargs)
def predict(self, group, t):
return self._group2model[group].predict(t)

def predict_ci(self, group, t, ci):
return self._group2model[group].predict_ci(t, ci)

@deprecated(version='0.1.8', reason='Has been renamed to :meth:`predict`')
def cdf(self, *args, **kwargs):
@deprecated(version='0.1.8',
reason='Use :meth:`predict` or :meth:`predict_ci` instead')
def cdf(self, group, t, ci=None):
'''Returns the predicted values.'''
return self.predict(*args, **kwargs)
if ci is not None:
return self.predict_ci(group, t, ci)
else:
return self.predict(group, t)


class Exponential(RegressionToMulti):
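
The wrappers in convoys.multi get the same split, with the group as the first argument. A sketch continuing the toy data above, under two assumptions not shown in this diff: the multi-level constructor forwards ci=True to the underlying regression model, and groups are encoded as integer indices (as _get_x above and the plotting code below suggest):

    import numpy
    import convoys.multi

    G = numpy.random.randint(2, size=n)      # assign each unit to group 0 or 1
    m = convoys.multi.Exponential(ci=True)   # assumption: kwargs are passed through to the base model
    m.fit(G, B, T)

    t = numpy.linspace(1, 50, 100)
    y = m.predict(0, t)                            # point estimates for group 0 over time
    y, y_lo, y_hi = m.predict_ci(0, t, ci=0.8).T   # 80% band, unpacked the same way plotting.py does
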
2 changes: 1 addition & 1 deletion convoys/plotting.py
@@ -81,7 +81,7 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier',
label = label_fmt % dict(group=group, n=n, k=k)

if ci is not None:
p_y, p_y_lo, p_y_hi = m.predict(j, t, ci=ci).T
p_y, p_y_lo, p_y_hi = m.predict_ci(j, t, ci=ci).T
merged_plot_ci_kwargs = {'alpha': 0.2}
merged_plot_ci_kwargs.update(plot_ci_kwargs)
p = ax.fill_between(t, 100. * p_y_lo, 100. * p_y_hi,
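
The plotting entry point keeps its ci argument; the only change is that the shaded band is now produced by predict_ci instead of predict(..., ci=ci). A minimal call, reusing G, B, T from the sketches above and leaving the remaining keyword arguments at their defaults:

    from matplotlib import pyplot
    import convoys.plotting

    convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', ci=0.95)
    pyplot.legend()   # plot_cohorts labels each group via label_fmt
    pyplot.show()
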
88 changes: 44 additions & 44 deletions convoys/regression.py
@@ -279,25 +279,7 @@ def callback(LL, value_history=[]):
'beta': data[6+n_features:6+2*n_features].T,
} for k, data in result.items()}

def predict_posteriori(self, x, t, ci=None):
'''Returns the value of the cumulative distribution function
for a fitted model.
:param x: feature vector (or matrix)
:param t: time
:param ci: if this is provided, and the model was fit with
`ci = True`, then the return value will be the trace
samples generated via the MCMC steps. If this is not
provided, then the max a posteriori prediction will be used.
'''
x = numpy.array(x)
t = numpy.array(t)
if ci is None:
params = self.params['map']
else:
assert self._ci
params = self.params['samples']
t = numpy.expand_dims(t, -1)
def _predict(self, params, x, t):
lambd = exp(dot(x, params['alpha'].T) + params['a'])
if self._flavor == 'logistic':
c = expit(dot(x, params['beta'].T) + params['b'])
@@ -307,31 +289,45 @@ def predict_posteriori(self, x, t, ci=None):
params['k'],
(t*lambd)**params['p'])

return M
return M

def predict_posteriori(self, x, t):
''' Returns the trace samples generated via the MCMC steps.
Requires the model to be fit with `ci = True`.'''
x = numpy.array(x)
t = numpy.array(t)
assert self._ci
params = self.params['samples']
t = numpy.expand_dims(t, -1)
return self._predict(params, x, t)

def predict_ci(self, x, t, ci=0.8):
'''Works like :meth:`predict` but produces a confidence interval.
Requires the model to be fit with `ci = True`. The return value
will contain one more dimension than for :meth:`predict`, and
the last dimension will have size 3, containing the mean, the
lower bound of the confidence interval, and the upper bound of
the confidence interval.
'''
M = self.predict_posteriori(x, t)
y = numpy.mean(M, axis=-1)
y_lo = numpy.percentile(M, (1-ci)*50, axis=-1)
y_hi = numpy.percentile(M, (1+ci)*50, axis=-1)
return numpy.stack((y, y_lo, y_hi), axis=-1)

def predict(self, x, t, ci=None):
def predict(self, x, t):
'''Returns the value of the cumulative distribution function
for a fitted model.
:param x: feature vector (or matrix)
:param t: time
:param ci: if this is provided, and the model was fit with
`ci = True`, then the return value will contain one more
dimension, and the last dimension will have size 3,
containing the mean, the lower bound of the confidence
interval, and the upper bound of the confidence interval.
If this is not provided, then the max a posteriori
prediction will be used.
'''
M = self.predict_posteriori(x, t, ci)
if not ci:
return M
else:
# Replace the last axis with a 3-element vector
y = numpy.mean(M, axis=-1)
y_lo = numpy.percentile(M, (1-ci)*50, axis=-1)
y_hi = numpy.percentile(M, (1+ci)*50, axis=-1)
return numpy.stack((y, y_lo, y_hi), axis=-1)
params = self.params['map']
x = numpy.array(x)
t = numpy.array(t)
return self._predict(params, x, t)

def rvs(self, x, n_curves=1, n_samples=1, T=None):
''' Samples values from this distribution
@@ -366,15 +362,19 @@ def rvs(self, x, n_curves=1, n_samples=1, T=None):

return B, C

@deprecated(version='0.1.8', reason='Has been renamed to :meth:`predict`')
def cdf(self, *args, **kwargs):
@deprecated(version='0.1.8',
reason='Use :meth:`predict` or :meth:`predict_ci` instead.')
def cdf(self, x, t, ci=False):
'''Returns the predicted values.'''
return self.predict(*args, **kwargs)
if ci:
return self.predict_ci(x, t)
else:
return self.predict(x, t)

@deprecated(version='0.1.8', reason='Has been renamed to :meth:`predict`')
def cdf_posteriori(self, *args, **kwargs):
'''Returns the predicted values.'''
return self.predict_posteriori(*args, **kwargs)
@deprecated(version='0.1.8', reason='Use :meth:`predict_posteriori` instead.')
def cdf_posteriori(self, x, t):
'''Returns the posterior distribution of the predicted values.'''
return self.predict_posteriori(x, t)


class Exponential(GeneralizedGamma):
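
With the split, predict_ci does nothing more than summarize the posterior trace along its last axis. A standalone numpy illustration of that reduction, where M stands in for the output of predict_posteriori with the MCMC samples in the last dimension (the fake posterior here is just random numbers):

    import numpy

    ci = 0.8
    M = numpy.random.beta(2, 5, size=(4, 1000))          # stand-in posterior: 4 time points x 1000 samples

    y = numpy.mean(M, axis=-1)                           # posterior mean
    y_lo = numpy.percentile(M, (1 - ci) * 50, axis=-1)   # 10th percentile for ci=0.8
    y_hi = numpy.percentile(M, (1 + ci) * 50, axis=-1)   # 90th percentile
    out = numpy.stack((y, y_lo, y_hi), axis=-1)          # shape (4, 3), matching predict_ci
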
55 changes: 33 additions & 22 deletions convoys/single.py
@@ -52,35 +52,46 @@ def fit(self, B, T):
eps = 1e-9
self._ss_clipped = numpy.clip(self._ss, eps, 1.0-eps)

def _get_value_at(self, j, ci):
if ci:
z_lo, z_hi = scipy.stats.norm.ppf([(1-ci)/2, (1+ci)/2])
return (
1 - self._ss[j],
1 - numpy.exp(-numpy.exp(
numpy.log(-numpy.log(self._ss_clipped[j]))
+ z_hi * self._vs[j]**0.5)),
1 - numpy.exp(-numpy.exp(
numpy.log(-numpy.log(self._ss_clipped[j]))
+ z_lo * self._vs[j]**0.5))
)
else:
return 1 - self._ss[j]

def predict(self, t, ci=None):
def predict(self, t):
'''Returns the predicted values.'''
t = numpy.array(t)
res = numpy.zeros(t.shape + (3,) if ci else t.shape)
res = numpy.zeros(t.shape)
for indexes, value in numpy.ndenumerate(t):
j = numpy.searchsorted(self._ts, value, side='right') - 1
if j >= len(self._ts) - 1:
# Make the plotting stop at the last value of t
res[indexes] = [float('nan')]*3 if ci else float('nan')
res[indexes] = float('nan')
else:
res[indexes] = self._get_value_at(j, ci)
res[indexes] = 1 - self._ss[j]
return res

@deprecated(version='0.1.8', reason='Has been renamed to :meth:`predict`')
def cdf(self, *args, **kwargs):
def predict_ci(self, t, ci=0.8):
'''Returns the predicted values with a confidence interval.'''
t = numpy.array(t)
res = numpy.zeros(t.shape + (3,))
for indexes, value in numpy.ndenumerate(t):
j = numpy.searchsorted(self._ts, value, side='right') - 1
if j >= len(self._ts) - 1:
# Make the plotting stop at the last value of t
res[indexes] = [float('nan')]*3
else:
z_lo, z_hi = scipy.stats.norm.ppf([(1-ci)/2, (1+ci)/2])
res[indexes] = (
1 - self._ss[j],
1 - numpy.exp(-numpy.exp(
numpy.log(-numpy.log(self._ss_clipped[j]))
+ z_hi * self._vs[j]**0.5)),
1 - numpy.exp(-numpy.exp(
numpy.log(-numpy.log(self._ss_clipped[j]))
+ z_lo * self._vs[j]**0.5))
)
return res

@deprecated(version='0.1.8',
reason='Use :meth:`predict` or :meth:`predict_ci` instead.')
def cdf(self, t, ci=None):
'''Returns the predicted values.'''
return self.predict(*args, **kwargs)
if ci is not None:
return self.predict_ci(t)
else:
return self.predict(t)
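
For the Kaplan-Meier model, predict_ci builds its interval from a normal approximation on the log(-log) scale of the survival estimate and maps the result back to the conversion scale, exactly as in the loop above. A standalone sketch of that mapping for a single time point, where s and v stand in for self._ss_clipped[j] and self._vs[j] and the numbers are made up:

    import numpy
    import scipy.stats

    s, v, ci = 0.7, 0.01, 0.8    # survival estimate, its variance on the log(-log) scale, CI level
    z_lo, z_hi = scipy.stats.norm.ppf([(1 - ci) / 2, (1 + ci) / 2])

    point = 1 - s                # conversion point estimate
    # Shifting by z * sqrt(v) on the log(-log) scale and mapping back is the
    # same as raising the survival estimate to the power exp(z * sqrt(v)).
    end_hi = 1 - numpy.exp(-numpy.exp(numpy.log(-numpy.log(s)) + z_hi * v ** 0.5))
    end_lo = 1 - numpy.exp(-numpy.exp(numpy.log(-numpy.log(s)) + z_lo * v ** 0.5))
    assert abs(end_hi - (1 - s ** numpy.exp(z_hi * v ** 0.5))) < 1e-12
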
22 changes: 13 additions & 9 deletions test_convoys.py
@@ -76,16 +76,20 @@ def test_output_shapes(c=0.3, lambd=0.1, n=1000, k=5):
assert model.predict([[X[0]], [X[1]]], [[0, 1, 2]]).shape == (2, 3)

# Generate output with ci (same as above plus (3,))
assert model.predict(X[0], 0, ci=0.8).shape == (3,)
assert model.predict([X[0], X[1]], 0, ci=0.8).shape == (2, 3)
assert model.predict([X[0]], [0, 1, 2, 3], ci=0.8).shape == (4, 3)
assert model.predict([X[0], X[1], X[2]], [0, 1, 2], ci=0.8) \
assert model.predict_ci(X[0], 0, ci=0.8).shape == (3,)
assert model.predict_ci([X[0], X[1]], 0, ci=0.8).shape == (2, 3)
assert model.predict_ci([X[0]], [0, 1, 2, 3], ci=0.8).shape == (4, 3)
assert model.predict_ci([X[0], X[1], X[2]], [0, 1, 2], ci=0.8) \
.shape == (3, 3)
assert model.predict([[X[0], X[1]]], [[0], [1], [2]], ci=0.8) \
assert model.predict_ci([[X[0], X[1]]], [[0], [1], [2]], ci=0.8) \
.shape == (3, 2, 3)
assert model.predict([[X[0]], [X[1]]], [[0, 1, 2]], ci=0.8) \
assert model.predict_ci([[X[0]], [X[1]]], [[0, 1, 2]], ci=0.8) \
.shape == (2, 3, 3)

# Assert old interface still works
assert model.cdf(X[0], 0).shape == ()
assert model.cdf(X[0], 0, ci=0.8).shape == (3,)

# Fit model without ci (should be the same)
model = convoys.regression.Exponential(ci=False)
model.fit(X, B, T)
@@ -108,9 +112,9 @@ def test_exponential_regression_model(c=0.3, lambd=0.1, n=10000):
assert 0.80*c*d < model.predict([1], t) < 1.30*c*d

# Check the confidence intervals
assert model.predict([1], float('inf'), ci=0.95).shape == (3,)
assert model.predict([1], [0, 1, 2, 3], ci=0.95).shape == (4, 3)
y, y_lo, y_hi = model.predict([1], float('inf'), ci=0.95)
assert model.predict_ci([1], float('inf'), ci=0.95).shape == (3,)
assert model.predict_ci([1], [0, 1, 2, 3], ci=0.95).shape == (4, 3)
y, y_lo, y_hi = model.predict_ci([1], float('inf'), ci=0.95)
assert 0.80*c < y < 1.30*c

# Check the random variates
