diff --git a/convoys/utils.py b/convoys/utils.py index e903623..898a5b2 100644 --- a/convoys/utils.py +++ b/convoys/utils.py @@ -1,4 +1,5 @@ import datetime +import numpy import pandas __all__ = ['get_arrays'] @@ -78,8 +79,8 @@ def get_arrays(data, features=None, groups=None, created=None, G = data[groups].apply(lambda g: group2j.get(g, -1)).values res.append(G) else: - groups_list = [] - X = data[features].values + groups_list = None + X = numpy.array([numpy.array(z) for z in data[features].values]) res.append(X) # Next, construct the `B` and `T` arrays diff --git a/test_convoys.py b/test_convoys.py index d90be51..bba8d2f 100644 --- a/test_convoys.py +++ b/test_convoys.py @@ -241,6 +241,14 @@ def test_convert_dataframe(): # TODO: assert things +def test_convert_dataframe_features(n=1000): + df = _generate_dataframe(n=n) + df['features'] = [tuple(numpy.random.randn() for z in range(3)) for g in df['group']] + df = df.drop('group', axis=1) + unit, groups, (X, B, T) = convoys.utils.get_arrays(df) + assert X.shape == (n, 3) + + def _test_plot_cohorts(model='weibull', extra_model=None): df = _generate_dataframe() unit, groups, (G, B, T) = convoys.utils.get_arrays(df)