Skip to content

Commit

Permalink
Merge pull request #64 from better/misc
Browse files Browse the repository at this point in the history
various fixes to units/plotting
  • Loading branch information
erikbern committed Jun 13, 2018
2 parents f630e3b + e0f4001 commit b3648b0
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 34 deletions.
15 changes: 7 additions & 8 deletions convoys/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@


def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier',
ci=None, plot_args={}, plot_ci_args={}, groups=None):
ci=None, plot_kwargs={}, plot_ci_kwargs={}, groups=None):
# Set x scale
if t_max is None:
t_max = max(T)
_, t_max = pyplot.gca().get_xlim()
t_max = max(t_max, max(T))

if groups is None:
groups = set(G)
Expand All @@ -31,22 +32,20 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier',
colors = pyplot.get_cmap('tab10').colors
colors = [colors[i % len(colors)] for i in range(len(groups))]
t = numpy.linspace(0, t_max, 1000)
y_max = 0
_, y_max = pyplot.gca().get_ylim()
for j, (group, color) in enumerate(zip(groups, colors)):
n = sum(1 for g in G if g == j) # TODO: slow
k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow
label = '%s (n=%.0f, k=%.0f)' % (group, n, k)

if ci is not None:
p_y, p_y_lo, p_y_hi = m.cdf(j, t, ci=ci).T
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5,
alpha=0.7, label=label, **plot_args)
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi,
color=color, alpha=0.2, **plot_ci_args)
color=color, alpha=0.2, **plot_ci_kwargs)
else:
p_y = m.cdf(j, t).T
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5,
alpha=0.7, label=label, **plot_args)
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5,
alpha=0.7, label=label, **plot_kwargs)

y_max = max(y_max, 110. * max(p_y))

Expand Down
29 changes: 16 additions & 13 deletions convoys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,10 @@ def get_timedelta_converter(t_factor):
if not isinstance(t, datetime.timedelta):
# Assume numeric type
return '', lambda x: x
elif t >= datetime.timedelta(days=1) or unit == 'Years':
return 'Years', get_timedelta_converter(1./(365.25*24*60*60))
elif t >= datetime.timedelta(days=1) or unit == 'Days':
return 'Days', get_timedelta_converter(1./(24*60*60))
elif t >= datetime.timedelta(hours=1) or unit == 'Hours':
return 'Hours', get_timedelta_converter(1./(60*60))
elif t >= datetime.timedelta(minutes=1) or unit == 'Minutes':
return 'Minutes', get_timedelta_converter(1./60)
else:
return 'Seconds', get_timedelta_converter(1)
for u, f in [('years', 365.25*24*60*60), ('days', 24*60*60),
('hours', 60*60), ('minutes', 60), ('seconds', 1)]:
if t >= datetime.timedelta(seconds=f):
return u, get_timedelta_converter(1./f)


def get_groups(data, group_min_size, max_groups):
Expand All @@ -45,6 +39,15 @@ def get_groups(data, group_min_size, max_groups):
return sorted(groups)


def _sub(a, b):
# Computes a - b for a bunch of different cases
if isinstance(a, datetime.datetime) and a.tzinfo is not None:
return a.astimezone(b.tzinfo) - b
else:
# Either naive timestamps or numerical type
return a - b


def get_arrays(data, features=None, groups=None, created=None,
converted=None, now=None, unit=None,
group_min_size=0, max_groups=-1):
Expand Down Expand Up @@ -96,15 +99,15 @@ def get_arrays(data, features=None, groups=None, created=None,
# TODO: this stuff should be vectorized, kind of ugly
if not pandas.isnull(row[converted]):
if created is not None:
T_raw.append(row[converted] - row[created])
T_raw.append(_sub(row[converted], row[created]))
else:
T_raw.append(row[converted])
else:
if created is not None:
if now is not None:
T_raw.append(row[now] - row[created])
T_raw.append(_sub(row[now], row[created]))
else:
T_raw.append(datetime.datetime.now(tzinfo=row[created].tzinfo) - row[created_at])
T_raw.append(_sub(datetime.datetime.now(), row[created]))
else:
T_raw.append(row[now])
unit, converter = get_timescale(max(T_raw), unit)
Expand Down
8 changes: 4 additions & 4 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Several of the arguments are references to columns in the dataframe, in this cas
unit, groups, (G, B, T) = convoys.utils.get_arrays(
df, groups='type', created='issue_date', converted='disposition_date',
unit='Years', group_min_size=100)
unit='years', group_min_size=100)
This will create three numpy arrays that we can use to plot.
Expand Down Expand Up @@ -70,9 +70,9 @@ Let's also plot the Kaplan-Meier and the Weibull model on top of each other so t
df['bucket'] = df['issue_date'].apply(lambda d: '%d-%d' % (5*(d.year//5), 5*(d.year//5)+4))
unit, groups, (G, B, T) = convoys.utils.get_arrays(
df, groups='bucket', created='issue_date', converted='disposition_date',
unit='Years', group_min_size=500)
unit='years', group_min_size=500)
convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, t_max=30)
convoys.plotting.plot_cohorts(G, B, T, model='weibull', groups=groups, t_max=30, plot_args={'linestyle': '--'}, ci=0.95)
convoys.plotting.plot_cohorts(G, B, T, model='weibull', groups=groups, t_max=30, plot_kwargs={'linestyle': '--'}, ci=0.95)
pyplot.legend()
pyplot.show()
Expand Down Expand Up @@ -114,7 +114,7 @@ This just gives one more degree of freedom to fit the model.
pyplot.figure(figsize=(12, 9))
convoys.plotting.plot_cohorts(G, B, T, model='generalized-gamma', groups=groups)
pyplot.legend()
convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, plot_args={'linestyle': '--'})
convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier', groups=groups, plot_kwargs={'linestyle': '--'})
pyplot.savefig('marriage-combined.png')
This will generate something like this:
Expand Down
4 changes: 2 additions & 2 deletions docs/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ Luckily, there is a somewhat similar field called `survival analysis <https://en
It introduces the concept of *censored data*, which is data that we have not observed yet.
`Lifelines <http://lifelines.readthedocs.io/en/latest/>`_ is a great Python package with excellent documentation that implements many classic models for survival analysis.

Unfortunately, survival analysis assumes that *everyone dies* in the end.
This is not a realistic assumption when you model conversion rates since not everyone will convert, even given infinite amount of time.
Unfortunately, fitting a distribution such as Weibull is not enough in the case of conversion rates, since not everyone converts in the end.
Typically conversion rates stabilize at some fraction eventually.
For that reason, we have to make the model a bit more complex and introduce the possibility that some items may never convert.

Predicting lagged conversions
-----------------------------
Expand Down
10 changes: 5 additions & 5 deletions examples/dob_violations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def run():
unit, groups, (G, B, T) = convoys.utils.get_arrays(
df, groups='type', created='issue_date',
converted='disposition_date',
unit='Years', group_min_size=100)
unit='years', group_min_size=100)

for model in ['kaplan-meier', 'weibull']:
print('plotting', model)
pyplot.figure(figsize=(9, 6))
convoys.plotting.plot_cohorts(G, B, T, model=model, ci=0.95,
groups=groups, t_max=30)
pyplot.legend()
pyplot.xlabel('Years')
pyplot.xlabel(unit)
pyplot.savefig('dob-violations-%s.png' % model)

pyplot.figure(figsize=(9, 6))
Expand All @@ -35,14 +35,14 @@ def run():
unit, groups, (G, B, T) = convoys.utils.get_arrays(
df, groups='bucket', created='issue_date',
converted='disposition_date',
unit='Years', group_min_size=500)
unit='years', group_min_size=500)
convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier',
groups=groups, t_max=30)
convoys.plotting.plot_cohorts(G, B, T, model='weibull',
groups=groups, t_max=30, ci=0.95,
plot_args={'linestyle': '--'})
plot_kwargs={'linestyle': '--'})
pyplot.legend()
pyplot.xlabel('Years')
pyplot.xlabel(unit)
pyplot.savefig('dob-violations-combined.png')


Expand Down
2 changes: 1 addition & 1 deletion examples/marriage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run():
pyplot.xlabel('Age of marriage')
convoys.plotting.plot_cohorts(G, B, T, model='kaplan-meier',
groups=groups,
plot_args={'linestyle': '--'})
plot_kwargs={'linestyle': '--'})
pyplot.savefig('marriage-combined.png')


Expand Down
2 changes: 1 addition & 1 deletion test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _test_plot_cohorts(model='weibull', extra_model=None):
matplotlib.pyplot.legend()
if extra_model:
convoys.plotting.plot_cohorts(G, B, T, model=extra_model,
plot_args=dict(linestyle='--'))
plot_kwargs=dict(linestyle='--'))
matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model)
if extra_model is not None else '%s.png' % model)

Expand Down

0 comments on commit b3648b0

Please sign in to comment.