Skip to content

Commit

Permalink
Merge pull request #87 from stphnma/plot_specific_groups
Browse files Browse the repository at this point in the history
add argument for plotting just specific groups
  • Loading branch information
erikbern committed Dec 12, 2018
2 parents abb6d2a + f5d114f commit 6cf0b98
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
33 changes: 25 additions & 8 deletions convoys/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,42 @@


def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier',
ci=None, plot_kwargs={}, plot_ci_kwargs={}, groups=None):
ci=None, plot_kwargs={}, plot_ci_kwargs={},
groups=None, specific_groups=None):

if model not in _models.keys():
if not isinstance(model, convoys.multi.MultiModel):
raise Exception('model incorrectly specified')

if groups is None:
groups = list(set(G))

# Set x scale
if t_max is None:
_, t_max = pyplot.gca().get_xlim()
t_max = max(t_max, max(T))
if not isinstance(model, convoys.multi.MultiModel):
# Fit model
m = _models[model](ci=bool(ci))
m.fit(G, B, T)
else:
m = model

if groups is None:
groups = set(G)
if specific_groups is None:
specific_groups = groups

# Fit model
m = _models[model](ci=bool(ci))
m.fit(G, B, T)
if len(set(specific_groups).intersection(groups)) != len(specific_groups):
raise Exception('specific_groups not a subset of groups!')

# Plot
colors = pyplot.get_cmap('tab10').colors
colors = [colors[i % len(colors)] for i in range(len(groups))]
colors = [colors[i % len(colors)] for i in range(len(specific_groups))]
t = numpy.linspace(0, t_max, 1000)
_, y_max = pyplot.gca().get_ylim()
for j, (group, color) in enumerate(zip(groups, colors)):
for i, (group, color) in enumerate(zip(specific_groups, colors)):

j = groups.index(group) # matching index of group

n = sum(1 for g in G if g == j) # TODO: slow
k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow
label = '%s (n=%.0f, k=%.0f)' % (group, n, k)
Expand Down
17 changes: 17 additions & 0 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,23 @@ def _test_plot_cohorts(model='weibull', extra_model=None):
if extra_model is not None else '%s.png' % model)


def test_plot_cohorts_model():
df = _generate_dataframe()
unit, groups, (G, B, T) = convoys.utils.get_arrays(df)
model = convoys.multi.Exponential(ci=None)
model.fit(G, B, T)
matplotlib.pyplot.clf()
convoys.plotting.plot_cohorts(G, B, T, model=model, groups=groups)
matplotlib.pyplot.legend()

with pytest.raises(Exception):
convoys.plotting.plot_cohorts(G, B, T, model='bad', groups=groups)

with pytest.raises(Exception):
convoys.plotting.plot_cohorts(G, B, T, model=model, groups=groups,
specific_groups=['Nonsense'])


@flaky.flaky
def test_plot_cohorts_kaplan_meier():
_test_plot_cohorts(model='kaplan-meier')
Expand Down

0 comments on commit 6cf0b98

Please sign in to comment.