Skip to content

Commit

Permalink
apply to rows
Browse files: browse the repository at this point in the history
  • Loading branch information
papajohn committed Mar 23, 2016
1 parent 520172c commit 12dead0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
21 changes: 15 additions & 6 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,10 @@ def column_index(self, column_label):
"""Return the index of a column."""
return self.labels.index(column_label)

def apply(self, fn, column_label):
def apply(self, fn, column_label=None):
"""Returns an array where ``fn`` is applied to each set of elements
by row from the specified columns in ``column_label``.
by row from the specified columns in ``column_label``. If no
column_label is specified, then each row is passed to fn.
Args:
``fn`` (function): The function to be applied to elements specified
Expand All @@ -302,13 +303,21 @@ def apply(self, fn, column_label):
b | 3 | 2
c | 3 | 2
z | 1 | 10
>>> t.apply(lambda x, y: x * y, ['count', 'points'])
array([ 9, 6, 6, 10])
>>> t.apply(lambda x: x - 1, 'points')
array([0, 1, 1, 9])
>>> t.apply(lambda x, y: x * y, ['count', 'points'])
array([ 9, 6, 6, 10])
Whole rows can be passed to a function as well.
>>> t.apply(lambda row: row.item('count') * 2)
array([18, 6, 6, 2])
"""
rows = zip(*self.select(column_label).columns)
return np.array([fn(*row) for row in rows])
if column_label is None:
return np.array([fn(row) for row in self.rows])
else:
rows = zip(*self.select(column_label).columns)
return np.array([fn(*row) for row in rows])

############
# Mutation #
Expand Down
1 change: 1 addition & 0 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def test_apply(t):
t = t.copy()
assert_array_equal(t.apply(lambda x, y: x * y, ['count', 'points']), np.array([9, 6, 6, 10]))
assert_array_equal(t.apply(lambda x: x * x, 'points'), np.array([1, 4, 4, 100]))
assert_array_equal(t.apply(lambda row: row.item('count') * 2), np.array([18, 6, 6, 2]))
with(pytest.raises(KeyError)):
t.apply(lambda x, y: x + y, ['count', 'score'])

Expand Down

0 comments on commit 12dead0

Please sign in to comment.