Return the projected final estimates from plot_cohorts and test them.

NOTE(review): this patch was rebuilt from a copy whose newlines and indentation
had been collapsed. Hunk line counts were checked against the @@ headers, but
the indentation of context lines and the position of the blank context line in
the second hunk are inferred -- confirm against convoys/__init__.py before
applying. The post-image blob id on the test_convoys.py index line predates the
review edits to the added test (tuple default, no parameter shadowing, docstring).

diff --git a/convoys/__init__.py b/convoys/__init__.py
index 3d183e4..1e66876 100644
--- a/convoys/__init__.py
+++ b/convoys/__init__.py
@@ -140,6 +140,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
     # PLOT
     colors = seaborn.color_palette('hls', len(groups))
     y_max = 0
+    result = []
     for group, color in zip(sorted(groups), colors):
         X, B, T = get_arrays(js[group], t_converter)
         t = numpy.linspace(0, t_max, 1000)
@@ -157,6 +158,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
             label += ' projected: %.2f%% (%.2f%% - %.2f%%)' % (100.*p_y_final, 100.*p_y_lo_final, 100.*p_y_hi_final)
             pyplot.plot(t, 100. * p_y, color=color, linestyle=':', alpha=0.7)
             pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2)
+            result.append((group, p_y_final, p_y_lo_final, p_y_hi_final))
 
         m_t, m_y = m.predict([1], t)
         pyplot.plot(m_t, 100. * m_y, color=color, label=label)
@@ -170,6 +172,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
     pyplot.ylabel('Conversion rate %')
     pyplot.legend()
     pyplot.gca().grid(True)
+    return result
 
 
 def plot_timeseries(data, window, model='kaplan-meier', group_min_size=0, max_groups=100, window_min_size=1, stride=None, title=None, time=False):
diff --git a/test_convoys.py b/test_convoys.py
index 940afc8..50ba4d8 100644
--- a/test_convoys.py
+++ b/test_convoys.py
@@ -86,27 +86,31 @@ def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=100000):
     assert 0.90*lambd < numpy.exp(model.params['alpha']) < 1.10*lambd
 
 
-def _get_data(c=0.3, k=10, lambd=0.1, n=1000):
+def test_plot_cohorts(cs=(0.3, 0.5, 0.7), k=2.0, lambd=0.1, n=100000):
+    """Check the first tuple returned by plot_cohorts with a Weibull projection.
+
+    Builds ``n`` rows from ``generate_censored_data``, assigned round-robin to
+    ``len(cs)`` groups, then compares the first group's returned estimate and
+    interval width with ``cs[0]`` and a Beta 95% interval for that group size.
+    """
+    C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
+    N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,))
+    E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
+    B, T = generate_censored_data(N, E, C)
     data = []
-    now = datetime.datetime(2000, 7, 1)
-    for x in range(n):
-        date_a = datetime.datetime(2000, 1, 1) + datetime.timedelta(days=random.random()*100)
-        if random.random() < c:
-            delay = scipy.stats.gamma.rvs(a=k, scale=1.0/lambd)
-            date_b = date_a + datetime.timedelta(days=delay)
-            if date_b < now:
-                data.append(('foo', date_a, date_b, now))
-            else:
-                data.append(('foo', date_a, None, now))
-        else:
-            data.append(('foo', date_a, None, now))
-    return data
-
-
-def test_plot_cohorts():
-    convoys.plot_cohorts(_get_data(), projection='gamma')
-
-
-@pytest.mark.skip
-def test_plot_conversion():
-    convoys.plot_timeseries(_get_data(), window=datetime.timedelta(days=7), model='gamma')
+    x2t = lambda x: datetime.datetime(2000, 1, 1) + datetime.timedelta(days=x)
+    for i, (b, t, now) in enumerate(zip(B, T, N)):
+        data.append(('Group %d' % (i % len(cs)),  # group name
+                     x2t(0),  # created at
+                     x2t(t) if b else None,  # converted at
+                     x2t(now)))  # now
+
+    result = convoys.plot_cohorts(data, projection='weibull')
+    group, y, y_lo, y_hi = result[0]
+    c = cs[0]
+    group_size = len(data)/len(cs)
+    c_lo = scipy.stats.beta.ppf(0.025, group_size*c, group_size*(1-c))
+    c_hi = scipy.stats.beta.ppf(0.975, group_size*c, group_size*(1-c))
+    assert group == 'Group 0'
+    assert 0.95*c < y < 1.05 * c
+    assert 0.70*(c_hi-c_lo) < (y_hi-y_lo) < 1.30*(c_hi-c_lo)