Skip to content

Commit

Permalink
make get_arrays return a numpy matrix for features
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Sep 1, 2019
1 parent d6187b6 commit 6689179
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
5 changes: 3 additions & 2 deletions convoys/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import numpy
import pandas

__all__ = ['get_arrays']
Expand Down Expand Up @@ -78,8 +79,8 @@ def get_arrays(data, features=None, groups=None, created=None,
G = data[groups].apply(lambda g: group2j.get(g, -1)).values
res.append(G)
else:
groups_list = []
X = data[features].values
groups_list = None
X = numpy.array([numpy.array(z) for z in data[features].values])
res.append(X)

# Next, construct the `B` and `T` arrays
Expand Down
8 changes: 8 additions & 0 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,14 @@ def test_convert_dataframe():
# TODO: assert things


def test_convert_dataframe_features(n=1000):
df = _generate_dataframe(n=n)
df['features'] = [tuple(numpy.random.randn() for z in range(3)) for g in df['group']]
df = df.drop('group', axis=1)
unit, groups, (X, B, T) = convoys.utils.get_arrays(df)
assert X.shape == (n, 3)


def _test_plot_cohorts(model='weibull', extra_model=None):
df = _generate_dataframe()
unit, groups, (G, B, T) = convoys.utils.get_arrays(df)
Expand Down

0 comments on commit 6689179

Please sign in to comment.