Skip to content

Commit

Permalink
hound
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed May 28, 2018
1 parent 99f5227 commit 0d6027c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
14 changes: 8 additions & 6 deletions convoys/plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import numpy
from matplotlib import pyplot
import convoys.multi
Expand All @@ -13,7 +12,8 @@
}


def plot_cohorts(G, B, T, t_max=None, title=None, model='kaplan-meier', ci=0.95, extra_model=None):
def plot_cohorts(G, B, T, t_max=None, title=None, model='kaplan-meier',
ci=0.95, extra_model=None):
# Set x scale
if t_max is None:
t_max = max(T)
Expand All @@ -32,19 +32,21 @@ def plot_cohorts(G, B, T, t_max=None, title=None, model='kaplan-meier', ci=0.95,
colors = [colors[i % len(colors)] for i in range(len(groups))]
t = numpy.linspace(0, t_max, 1000)
y_max = 0
result = []
for j, (group, color) in enumerate(zip(groups, colors)):
n = sum(1 for g in G if g == j) # TODO: slow
k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow
label = '%s (n=%.0f, k=%.0f)' % (group, n, k)

if ci is not None:
p_y, p_y_lo, p_y_hi = m.cdf(j, t, ci=ci).T
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label)
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2)
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5,
alpha=0.7, label=label)
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi,
color=color, alpha=0.2)
else:
p_y = m.cdf(j, t).T
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label)
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5,
alpha=0.7, label=label)

if extra_model is not None:
extra_p_y = extra_m.cdf(j, t)
Expand Down
1 change: 1 addition & 0 deletions convoys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_arrays(data, features=None, groups=None, created=None, converted=None, n
TODO: more doc
'''
print(data.dtypes)
if groups is not None:
group2j = dict((group, j) for j, group in enumerate(get_groups(data[groups], group_min_size, max_groups)))

Expand Down
6 changes: 4 additions & 2 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,15 @@ def _generate_dataframe(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=1000):

def test_convert_dataframe():
df = _generate_dataframe()
unit, (G, B, T) = convoys.utils.get_arrays(df, groups='groups', created='created', converted='converted', now='now')
unit, (G, B, T) = convoys.utils.get_arrays(
df, groups='groups', created='created', converted='converted', now='now')
# TODO: assert things


def _test_plot_cohorts(model='weibull', extra_model=None):
df = _generate_dataframe()
unit, (G, B, T) = convoys.utils.get_arrays(df, groups='groups', created='created', converted='converted', now='now')
unit, (G, B, T) = convoys.utils.get_arrays(
df, groups='groups', created='created', converted='converted', now='now')
matplotlib.pyplot.clf()
convoys.plotting.plot_cohorts(G, B, T, model=model, extra_model=extra_model)
matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model)
Expand Down

0 comments on commit 0d6027c

Please sign in to comment.