Skip to content

Commit

Permalink
Merge pull request #18 from better/plot-cohorts-test
Browse files (browse the repository at this point in the history)
plot cohorts test
  • Loading branch information
erikbern committed Mar 16, 2018
2 parents 39c7d11 + 128774b commit d64d982
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
3 changes: 3 additions & 0 deletions convoys/__init__.py
Original file line number Diff line number Diff line change
@@ -140,6 +140,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
# PLOT
colors = seaborn.color_palette('hls', len(groups))
y_max = 0
result = []
for group, color in zip(sorted(groups), colors):
X, B, T = get_arrays(js[group], t_converter)
t = numpy.linspace(0, t_max, 1000)
@@ -157,6 +158,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
label += ' projected: %.2f%% (%.2f%% - %.2f%%)' % (100.*p_y_final, 100.*p_y_lo_final, 100.*p_y_hi_final)
pyplot.plot(t, 100. * p_y, color=color, linestyle=':', alpha=0.7)
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2)
result.append((group, p_y_final, p_y_lo_final, p_y_hi_final))

m_t, m_y = m.predict([1], t)
pyplot.plot(m_t, 100. * m_y, color=color, label=label)
@@ -170,6 +172,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
pyplot.ylabel('Conversion rate %')
pyplot.legend()
pyplot.gca().grid(True)
return result


def plot_timeseries(data, window, model='kaplan-meier', group_min_size=0, max_groups=100, window_min_size=1, stride=None, title=None, time=False):
Expand Down
44 changes: 21 additions & 23 deletions test_convoys.py
Original file line number Diff line number Diff line change
@@ -86,27 +86,25 @@ def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=100000):
assert 0.90*lambd < numpy.exp(model.params['alpha']) < 1.10*lambd


def _get_data(c=0.3, k=10, lambd=0.1, n=1000):
def test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=2.0, lambd=0.1, n=100000):
C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,))
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
B, T = generate_censored_data(N, E, C)
data = []
now = datetime.datetime(2000, 7, 1)
for x in range(n):
date_a = datetime.datetime(2000, 1, 1) + datetime.timedelta(days=random.random()*100)
if random.random() < c:
delay = scipy.stats.gamma.rvs(a=k, scale=1.0/lambd)
date_b = date_a + datetime.timedelta(days=delay)
if date_b < now:
data.append(('foo', date_a, date_b, now))
else:
data.append(('foo', date_a, None, now))
else:
data.append(('foo', date_a, None, now))
return data


def test_plot_cohorts():
convoys.plot_cohorts(_get_data(), projection='gamma')


@pytest.mark.skip
def test_plot_conversion():
convoys.plot_timeseries(_get_data(), window=datetime.timedelta(days=7), model='gamma')
x2t = lambda x: datetime.datetime(2000, 1, 1) + datetime.timedelta(days=x)
for i, (b, t, n) in enumerate(zip(B, T, N)):
data.append(('Group %d' % (i % len(cs)), # group name
x2t(0), # created at
x2t(t) if b else None, # converted at
x2t(n))) # now

result = convoys.plot_cohorts(data, projection='weibull')
group, y, y_lo, y_hi = result[0]
c = cs[0]
k = len(data)/len(cs)
c_lo = scipy.stats.beta.ppf(0.025, k*c, k*(1-c))
c_hi = scipy.stats.beta.ppf(0.975, k*c, k*(1-c))
assert group == 'Group 0'
assert 0.95*c < y < 1.05 * c
assert 0.70*(c_hi-c_lo) < (y_hi-y_lo) < 1.30*(c_hi-c_lo)

0 comments on commit d64d982

Please sign in to comment.