From 24c6d91a49e5630366121f0006d10e09740d3c39 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Fri, 6 Apr 2018 18:26:02 -0400 Subject: [PATCH] return model from plot_cohorts --- convoys/__init__.py | 25 +++++++++++++++++-------- convoys/multi.py | 24 ++++++++++++------------ test_convoys.py | 2 +- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/convoys/__init__.py b/convoys/__init__.py index 8bca7d5..3b2bf28 100644 --- a/convoys/__init__.py +++ b/convoys/__init__.py @@ -82,7 +82,7 @@ def get_groups(data, group_min_size, max_groups): } -def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier', extra_model=None): +def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier', ci=0.95, extra_model=None): # Set x scale if t_max is None: t_max = max(now - created_at for group, created_at, converted_at, now in data) @@ -109,15 +109,24 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, n = sum(1 for g in G if g == j) # TODO: slow k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow label = '%s (n=%.0f, k=%.0f)' % (group, n, k) - p_y, p_y_lo, p_y_hi = m.predict(j, t, ci=0.95).T - p_y_final, p_y_lo_final, p_y_hi_final = m.predict_final(j, ci=0.95) - label += ' projected: %.2f%% (%.2f%% - %.2f%%)' % (100.*p_y_final, 100.*p_y_lo_final, 100.*p_y_hi_final) - pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) - pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2) + + if ci is not None: + p_y, p_y_lo, p_y_hi = m.predict(j, t, ci=ci).T + p_y_final, p_y_lo_final, p_y_hi_final = m.predict_final(j, ci=0.95) + label += ' projected: %.2f%% (%.2f%% - %.2f%%)' % (100.*p_y_final, 100.*p_y_lo_final, 100.*p_y_hi_final) + result.append((group, p_y_final, p_y_lo_final, p_y_hi_final)) + pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) + pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2) + else: + p_y = m.predict(j, t).T + p_y_final = m.predict_final(j, ci=None) + label += ' projected: %.2f%%' % (100.*p_y_final,) + result.append((group, p_y_final)) + pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) + if extra_model is not None: extra_p_y = extra_m.predict(j, t) pyplot.plot(t, 100. * extra_p_y, color=color, linestyle='--', linewidth=1.5, alpha=0.7) - result.append((group, p_y_final, p_y_lo_final, p_y_hi_final)) y_max = max(y_max, 110. * max(p_y)) if title: @@ -128,4 +137,4 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, pyplot.ylabel('Conversion rate %') pyplot.legend() pyplot.gca().grid(True) - return result + return m, result diff --git a/convoys/multi.py b/convoys/multi.py index 9f19fb1..70386b3 100644 --- a/convoys/multi.py +++ b/convoys/multi.py @@ -9,14 +9,14 @@ class MultiModel: class RegressionToMulti(MultiModel): def __init__(self, *args, **kwargs): - self._base_model = self._base_model_cls(*args, **kwargs) + self.base_model = self.base_model_cls(*args, **kwargs) def fit(self, G, B, T): self._n_groups = max(G) + 1 X = numpy.zeros((len(G), self._n_groups)) for i, group in enumerate(G): X[i,group] = 1 - self._base_model.fit(X, B, T) + self.base_model.fit(X, B, T) def _get_x(self, group): x = numpy.zeros(self._n_groups) @@ -24,18 +24,18 @@ def _get_x(self, group): return x def predict(self, group, t, *args, **kwargs): - return self._base_model.predict(self._get_x(group), t, *args, **kwargs) + return self.base_model.predict(self._get_x(group), t, *args, **kwargs) def predict_final(self, group, *args, **kwargs): - return self._base_model.predict_final(self._get_x(group), *args, **kwargs) + return self.base_model.predict_final(self._get_x(group), *args, **kwargs) def predict_time(self, group, *args, **kwargs): - return self._base_model.predict_time(self._get_x(group), *args, **kwargs) + return self.base_model.predict_time(self._get_x(group), *args, **kwargs) class SingleToMulti(MultiModel): def __init__(self, *args, **kwargs): - self._base_model_init = lambda: self._base_model_cls(*args, **kwargs) + self.base_model_init = lambda: self.base_model_cls(*args, **kwargs) def fit(self, G, B, T): group2bt = {} @@ -43,7 +43,7 @@ def fit(self, G, B, T): group2bt.setdefault(g, []).append((b, t)) self._group2model = {} for g, BT in group2bt.items(): - self._group2model[g] = self._base_model_init() + self._group2model[g] = self.base_model_init() self._group2model[g].fit([b for b, t in BT], [t for b, t in BT]) def predict(self, group, t, *args, **kwargs): @@ -57,20 +57,20 @@ def predict_time(self, group, *args, **kwargs): class Exponential(RegressionToMulti): - _base_model_cls = regression.Exponential + base_model_cls = regression.Exponential class Weibull(RegressionToMulti): - _base_model_cls = regression.Weibull + base_model_cls = regression.Weibull class Gamma(RegressionToMulti): - _base_model_cls = regression.Gamma + base_model_cls = regression.Gamma class KaplanMeier(SingleToMulti): - _base_model_cls = single.KaplanMeier + base_model_cls = single.KaplanMeier class Nonparametric(SingleToMulti): - _base_model_cls = single.Nonparametric + base_model_cls = single.Nonparametric diff --git a/test_convoys.py b/test_convoys.py index 7297b55..ee0b723 100644 --- a/test_convoys.py +++ b/test_convoys.py @@ -147,7 +147,7 @@ def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=10000, model='wei x2t(n))) # now matplotlib.pyplot.clf() - result = convoys.plot_cohorts(data, model=model, extra_model=extra_model) + _, result = convoys.plot_cohorts(data, model=model, extra_model=extra_model) matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model) if extra_model is not None else '%s.png' % model) group, y, y_lo, y_hi = result[0] c = cs[0]