diff --git a/convoys/plotting.py b/convoys/plotting.py index 067035a..5cb6722 100644 --- a/convoys/plotting.py +++ b/convoys/plotting.py @@ -15,10 +15,11 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', - ci=None, plot_args={}, plot_ci_args={}, groups=None): + ci=None, plot_kwargs={}, plot_ci_kwargs={}, groups=None): # Set x scale if t_max is None: - t_max = max(T) + _, t_max = pyplot.gca().get_xlim() + t_max = max(t_max, max(T)) if groups is None: groups = set(G) @@ -31,7 +32,7 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', colors = pyplot.get_cmap('tab10').colors colors = [colors[i % len(colors)] for i in range(len(groups))] t = numpy.linspace(0, t_max, 1000) - y_max = 0 + _, y_max = pyplot.gca().get_ylim() for j, (group, color) in enumerate(zip(groups, colors)): n = sum(1 for g in G if g == j) # TODO: slow k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow @@ -39,14 +40,12 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', if ci is not None: p_y, p_y_lo, p_y_hi = m.cdf(j, t, ci=ci).T - pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, - alpha=0.7, label=label, **plot_args) pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, - color=color, alpha=0.2, **plot_ci_args) + color=color, alpha=0.2, **plot_ci_kwargs) else: p_y = m.cdf(j, t).T - pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, - alpha=0.7, label=label, **plot_args) + pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, + alpha=0.7, label=label, **plot_kwargs) y_max = max(y_max, 110. * max(p_y)) diff --git a/convoys/utils.py b/convoys/utils.py index 1b758fc..af35685 100644 --- a/convoys/utils.py +++ b/convoys/utils.py @@ -16,16 +16,10 @@ def get_timedelta_converter(t_factor): if not isinstance(t, datetime.timedelta): # Assume numeric type return '', lambda x: x - elif t >= datetime.timedelta(days=1) or unit == 'Years': - return 'Years', get_timedelta_converter(1./(365.25*24*60*60)) - elif t >= datetime.timedelta(days=1) or unit == 'Days': - return 'Days', get_timedelta_converter(1./(24*60*60)) - elif t >= datetime.timedelta(hours=1) or unit == 'Hours': - return 'Hours', get_timedelta_converter(1./(60*60)) - elif t >= datetime.timedelta(minutes=1) or unit == 'Minutes': - return 'Minutes', get_timedelta_converter(1./60) - else: - return 'Seconds', get_timedelta_converter(1) + for u, f in [('years', 365.25*24*60*60), ('days', 24*60*60), + ('hours', 60*60), ('minutes', 60), ('seconds', 1)]: + if t >= datetime.timedelta(seconds=f): + return u, get_timedelta_converter(1./f) def get_groups(data, group_min_size, max_groups): @@ -45,6 +39,15 @@ def get_groups(data, group_min_size, max_groups): return sorted(groups) +def _sub(a, b): + # Computes a - b for a bunch of different cases + if isinstance(a, datetime.datetime) and a.tzinfo is not None: + return a.astimezone(b.tzinfo) - b + else: + # Either naive timestamps or numerical type + return a - b + + def get_arrays(data, features=None, groups=None, created=None, converted=None, now=None, unit=None, group_min_size=0, max_groups=-1): @@ -96,15 +99,15 @@ def get_arrays(data, features=None, groups=None, created=None, # TODO: this stuff should be vectorized, kind of ugly if not pandas.isnull(row[converted]): if created is not None: - T_raw.append(row[converted] - row[created]) + T_raw.append(_sub(row[converted], row[created])) else: T_raw.append(row[converted]) else: if created is not None: if now is not None: - T_raw.append(row[now] - row[created]) + T_raw.append(_sub(row[now], row[created])) else: - T_raw.append(datetime.datetime.now(tzinfo=row[created].tzinfo) - row[created_at]) + T_raw.append(_sub(datetime.datetime.now(), row[created])) else: T_raw.append(row[now]) unit, converter = get_timescale(max(T_raw), unit) diff --git a/docs/examples.rst b/docs/examples.rst index bbe2842..135dcce 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -25,7 +25,7 @@ Several of the arguments are references to columns in the dataframe, in this cas unit, groups, (G, B, T) = convoys.utils.get_arrays( df, groups='type', created='issue_date', converted='disposition_date', - unit='Years', group_min_size=100) + unit='years', group_min_size=100) This will create three numpy arrays that we can use to plot. @@ -70,9 +70,9 @@ Let's also plot the Kaplan-Meier and the Weibull model on top of each other so t df['bucket'] = df['issue_date'].apply(lambda d: '%d-%d' % (5*(d.year//5), 5*(d.year//5)+4)) unit, groups, (G, B, T) = convoys.utils.get_arrays( df, groups='bucket', created='issue_date', converted='disposition_date', - unit='Years', group_min_size=500) + unit='years', group_min_size=500) convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, t_max=30) - convoys.plotting.plot_cohorts(G, B, T, model='weibull', groups=groups, t_max=30, plot_args={'linestyle': '--'}, ci=0.95) + convoys.plotting.plot_cohorts(G, B, T, model='weibull', groups=groups, t_max=30, plot_kwargs={'linestyle': '--'}, ci=0.95) pyplot.legend() pyplot.show() @@ -114,7 +114,7 @@ This just gives one more degree of freedom to fit the model. pyplot.figure(figsize=(12, 9)) convoys.plotting.plot_cohorts(G, B, T, model='generalized-gamma', groups=groups) pyplot.legend() - convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, plot_args={'linestyle': '--'}) + convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, plot_kwargs={'linestyle': '--'}) pyplot.savefig('marriage-combined.png') This will generate something like this: diff --git a/docs/introduction.rst b/docs/introduction.rst index 7c511bb..86c420f 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -37,9 +37,9 @@ Luckily, there is a somewhat similar field called `survival analysis `_ is a great Python package with excellent documentation that implements many classic models for survival analysis. -Unfortunately, survival analysis assumes that *everyone dies* in the end. -This is not a realistic assumption when you model conversion rates since not everyone will convert, even given infinite amount of time. +Unfortunately, fitting a distribution such as Weibull is not enough in the case of conversion rates, since not everyone converts in the end. Typically conversion rates stabilize at some fraction eventually. +For that reason, we have to make the model a bit more complex and introduce the possibility that some items may never convert. Predicting lagged conversions ----------------------------- diff --git a/examples/dob_violations.py b/examples/dob_violations.py index 7312aa6..68a9e39 100644 --- a/examples/dob_violations.py +++ b/examples/dob_violations.py @@ -17,7 +17,7 @@ def run(): unit, groups, (G, B, T) = convoys.utils.get_arrays( df, groups='type', created='issue_date', converted='disposition_date', - unit='Years', group_min_size=100) + unit='years', group_min_size=100) for model in ['kaplan-meier', 'weibull']: print('plotting', model) @@ -25,7 +25,7 @@ def run(): convoys.plotting.plot_cohorts(G, B, T, model=model, ci=0.95, groups=groups, t_max=30) pyplot.legend() - pyplot.xlabel('Years') + pyplot.xlabel(unit) pyplot.savefig('dob-violations-%s.png' % model) pyplot.figure(figsize=(9, 6)) @@ -35,14 +35,14 @@ def run(): unit, groups, (G, B, T) = convoys.utils.get_arrays( df, groups='bucket', created='issue_date', converted='disposition_date', - unit='Years', group_min_size=500) + unit='years', group_min_size=500) convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, t_max=30) convoys.plotting.plot_cohorts(G, B, T, model='weibull', groups=groups, t_max=30, ci=0.95, - plot_args={'linestyle': '--'}) + plot_kwargs={'linestyle': '--'}) pyplot.legend() - pyplot.xlabel('Years') + pyplot.xlabel(unit) pyplot.savefig('dob-violations-combined.png') diff --git a/examples/marriage.py b/examples/marriage.py index 9b1d3ca..f0676b4 100644 --- a/examples/marriage.py +++ b/examples/marriage.py @@ -20,7 +20,7 @@ def run(): pyplot.xlabel('Age of marriage') convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, - plot_args={'linestyle': '--'}) + plot_kwargs={'linestyle': '--'}) pyplot.savefig('marriage-combined.png') diff --git a/test_convoys.py b/test_convoys.py index f27572e..c52e8ba 100644 --- a/test_convoys.py +++ b/test_convoys.py @@ -180,7 +180,7 @@ def _test_plot_cohorts(model='weibull', extra_model=None): matplotlib.pyplot.legend() if extra_model: convoys.plotting.plot_cohorts(G, B, T, model=extra_model, - plot_args=dict(linestyle='--')) + plot_kwargs=dict(linestyle='--')) matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model) if extra_model is not None else '%s.png' % model)