diff --git a/test_convoys.py b/test_convoys.py index 530d382..8c7fbe8 100644 --- a/test_convoys.py +++ b/test_convoys.py @@ -288,6 +288,14 @@ def test_convert_dataframe_timedeltas(): assert 0 <= t2 - t1 < 3.0 / (24*60*60) +def test_convert_dataframe_more_args(): + df = _generate_dataframe() + unit, groups, (G, B, T) = convoys.utils.get_arrays(df, max_groups=2) + assert len(groups) <= 2 + unit, groups, (G, B, T) = convoys.utils.get_arrays(df, group_min_size=9999) + assert G.shape == (0,) + + def _test_plot_cohorts(model='weibull', extra_model=None): df = _generate_dataframe() unit, groups, (G, B, T) = convoys.utils.get_arrays(df)