Skip to content

Commit

Permalink
return model from plot_cohorts
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Apr 6, 2018
1 parent 12f1134 commit 24c6d91
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
25 changes: 17 additions & 8 deletions convoys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_groups(data, group_min_size, max_groups):
}


def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier', extra_model=None):
def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier', ci=0.95, extra_model=None):
# Set x scale
if t_max is None:
t_max = max(now - created_at for group, created_at, converted_at, now in data)
Expand All @@ -109,15 +109,24 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
n = sum(1 for g in G if g == j) # TODO: slow
k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow
label = '%s (n=%.0f, k=%.0f)' % (group, n, k)
p_y, p_y_lo, p_y_hi = m.predict(j, t, ci=0.95).T
p_y_final, p_y_lo_final, p_y_hi_final = m.predict_final(j, ci=0.95)
label += ' projected: %.2f%% (%.2f%% - %.2f%%)' % (100.*p_y_final, 100.*p_y_lo_final, 100.*p_y_hi_final)
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label)
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2)

if ci is not None:
p_y, p_y_lo, p_y_hi = m.predict(j, t, ci=ci).T
p_y_final, p_y_lo_final, p_y_hi_final = m.predict_final(j, ci=0.95)
label += ' projected: %.2f%% (%.2f%% - %.2f%%)' % (100.*p_y_final, 100.*p_y_lo_final, 100.*p_y_hi_final)
result.append((group, p_y_final, p_y_lo_final, p_y_hi_final))
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label)
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2)
else:
p_y = m.predict(j, t).T
p_y_final = m.predict_final(j, ci=None)
label += ' projected: %.2f%%' % (100.*p_y_final,)
result.append((group, p_y_final))
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label)

if extra_model is not None:
extra_p_y = extra_m.predict(j, t)
pyplot.plot(t, 100. * extra_p_y, color=color, linestyle='--', linewidth=1.5, alpha=0.7)
result.append((group, p_y_final, p_y_lo_final, p_y_hi_final))
y_max = max(y_max, 110. * max(p_y))

if title:
Expand All @@ -128,4 +137,4 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
pyplot.ylabel('Conversion rate %')
pyplot.legend()
pyplot.gca().grid(True)
return result
return m, result
24 changes: 12 additions & 12 deletions convoys/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,41 @@ class MultiModel:

class RegressionToMulti(MultiModel):
def __init__(self, *args, **kwargs):
self._base_model = self._base_model_cls(*args, **kwargs)
self.base_model = self.base_model_cls(*args, **kwargs)

def fit(self, G, B, T):
self._n_groups = max(G) + 1
X = numpy.zeros((len(G), self._n_groups))
for i, group in enumerate(G):
X[i,group] = 1
self._base_model.fit(X, B, T)
self.base_model.fit(X, B, T)

def _get_x(self, group):
x = numpy.zeros(self._n_groups)
x[group] = 1
return x

def predict(self, group, t, *args, **kwargs):
return self._base_model.predict(self._get_x(group), t, *args, **kwargs)
return self.base_model.predict(self._get_x(group), t, *args, **kwargs)

def predict_final(self, group, *args, **kwargs):
return self._base_model.predict_final(self._get_x(group), *args, **kwargs)
return self.base_model.predict_final(self._get_x(group), *args, **kwargs)

def predict_time(self, group, *args, **kwargs):
return self._base_model.predict_time(self._get_x(group), *args, **kwargs)
return self.base_model.predict_time(self._get_x(group), *args, **kwargs)


class SingleToMulti(MultiModel):
def __init__(self, *args, **kwargs):
self._base_model_init = lambda: self._base_model_cls(*args, **kwargs)
self.base_model_init = lambda: self.base_model_cls(*args, **kwargs)

def fit(self, G, B, T):
group2bt = {}
for g, b, t in zip(G, B, T):
group2bt.setdefault(g, []).append((b, t))
self._group2model = {}
for g, BT in group2bt.items():
self._group2model[g] = self._base_model_init()
self._group2model[g] = self.base_model_init()
self._group2model[g].fit([b for b, t in BT], [t for b, t in BT])

def predict(self, group, t, *args, **kwargs):
Expand All @@ -57,20 +57,20 @@ def predict_time(self, group, *args, **kwargs):


class Exponential(RegressionToMulti):
_base_model_cls = regression.Exponential
base_model_cls = regression.Exponential


class Weibull(RegressionToMulti):
_base_model_cls = regression.Weibull
base_model_cls = regression.Weibull


class Gamma(RegressionToMulti):
_base_model_cls = regression.Gamma
base_model_cls = regression.Gamma


class KaplanMeier(SingleToMulti):
_base_model_cls = single.KaplanMeier
base_model_cls = single.KaplanMeier


class Nonparametric(SingleToMulti):
_base_model_cls = single.Nonparametric
base_model_cls = single.Nonparametric
2 changes: 1 addition & 1 deletion test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=10000, model='wei
x2t(n))) # now

matplotlib.pyplot.clf()
result = convoys.plot_cohorts(data, model=model, extra_model=extra_model)
_, result = convoys.plot_cohorts(data, model=model, extra_model=extra_model)
matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model) if extra_model is not None else '%s.png' % model)
group, y, y_lo, y_hi = result[0]
c = cs[0]
Expand Down

0 comments on commit 24c6d91

Please sign in to comment.