Skip to content

Commit

Permalink
change apply signature to avoid lists
Browse files Browse the repository at this point in the history
  • Loading branch information
papajohn committed Feb 6, 2017
1 parent aa33296 commit 6217f8d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 37 deletions.
68 changes: 35 additions & 33 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,16 @@ def column_index(self, label):
"""Return the index of a column by looking up its label."""
return self.labels.index(label)

def apply(self, fn, column_label=None):
"""Apply ``fn`` to each element of ``column_label``.
If no ``column_label`` provided, `fn`` applied to each row of table.
def apply(self, fn, *column_or_columns):
"""Apply ``fn`` to each element or elements of ``column_or_columns``.
If no ``column_or_columns`` is provided, ``fn`` is applied to each row.
Args:
``fn`` (function) -- The function to be applied to elements of
``column_label``.
``column_label`` (single str or array of str) -- Names of
columns to be passed into ``fn``. Length must match
number of arguments in ``fn`` signature.
``fn`` (function) -- The function to apply.
``column_or_columns``: Columns containing the arguments to ``fn``
as either column labels (``str``) or column indices (``int``).
The number of columns must match the number of arguments
that ``fn`` expects.
Raises:
``ValueError`` -- if ``column_or_columns`` is not an existing
Expand All @@ -337,9 +337,9 @@ def apply(self, fn, column_label=None):
z | 1 | 10
>>> t.apply(lambda x: x - 1, 'points')
array([0, 1, 1, 9])
>>> t.apply(lambda x, y: x * y, make_array('count', 'points'))
>>> t.apply(lambda x, y: x * y, 'count', 'points')
array([ 9, 6, 6, 10])
>>> t.apply(lambda x: x - 1, make_array('count', 'points'))
>>> t.apply(lambda x: x - 1, 'count', 'points')
Traceback (most recent call last):
...
TypeError: <lambda>() takes 1 positional argument but 2 were given
Expand All @@ -353,21 +353,26 @@ def apply(self, fn, column_label=None):
>>> t.apply(lambda row: row[1] * 2)
array([18, 6, 6, 2])
"""
if column_label is None:
if not column_or_columns:
return np.array([fn(row) for row in self.rows])
else:
rows = zip(*self.select(column_label).columns)
if len(column_or_columns) == 1 and \
_is_non_string_iterable(column_or_columns[0]):
warnings.warn(
"column lists are deprecated; pass each as an argument", FutureWarning)
column_or_columns = column_or_columns[0]
rows = zip(*self.select(*column_or_columns).columns)
return np.array([fn(*row) for row in rows])

############
# Mutation #
############

def set_format(self, column_label_or_labels, formatter):
def set_format(self, column_or_columns, formatter):
"""Set the format of a column."""
if inspect.isclass(formatter) and issubclass(formatter, _formats.Formatter):
formatter = formatter()
for label in self._as_labels(column_label_or_labels):
for label in self._as_labels(column_or_columns):
if callable(formatter):
self._formats[label] = lambda v, label: v if label else str(formatter(v))
elif isinstance(formatter, _formats.Formatter):
Expand Down Expand Up @@ -569,7 +574,7 @@ def remove(self, row_or_row_indices):
##################

def copy(self, *, shallow=False):
"""Return a copy of a Table."""
"""Return a copy of a table."""
table = type(self)()
for label in self.labels:
if shallow:
Expand All @@ -579,23 +584,20 @@ def copy(self, *, shallow=False):
self._add_column_and_format(table, label, column)
return table

def select(self, *column_label_or_labels):
"""
Returns a new ``Table`` with only the columns in
``column_label_or_labels``.
def select(self, *column_or_columns):
"""Return a table with only the columns in ``column_or_columns``.
Args:
``column_label_or_labels``: Columns to select from the ``Table`` as
``column_or_columns``: Columns to select from the ``Table`` as
either column labels (``str``) or column indices (``int``).
Returns:
An new instance of ``Table`` containing only selected columns.
A new instance of ``Table`` containing only selected columns.
The columns of the new ``Table`` are in the order given in
``column_label_or_labels``.
``column_or_columns``.
Raises:
``KeyError`` if any of ``column_label_or_labels`` are not in the
table.
``KeyError`` if any of ``column_or_columns`` are not in the table.
>>> flowers = Table().with_columns(
... 'Number of petals', make_array(8, 34, 5),
Expand Down Expand Up @@ -627,7 +629,7 @@ def select(self, *column_label_or_labels):
34 | 5
5 | 6
"""
labels = self._varargs_as_labels(column_label_or_labels)
labels = self._varargs_as_labels(column_or_columns)
table = type(self)()
for label in labels:
self._add_column_and_format(table, label, np.copy(self[label]))
Expand All @@ -642,15 +644,15 @@ def take(self):
def exclude(self):
raise NotImplementedError()

def drop(self, *column_label_or_labels):
def drop(self, *column_or_columns):
"""Return a Table with only columns other than selected label or
labels.
Args:
``column_label_or_labels`` (string or list of strings): The header
``column_or_columns`` (string, int, or list of strings/ints): The header
names or indices of the columns to be dropped.
``column_label_or_labels`` must be an existing header name, or a
``column_or_columns`` must be an existing header name, or a
valid column index.
Returns:
Expand Down Expand Up @@ -696,7 +698,7 @@ def drop(self, *column_label_or_labels):
hamburger | 651
veggie burger | 582
"""
exclude = _varargs_labels_as_list(column_label_or_labels)
exclude = _varargs_labels_as_list(column_or_columns)
return self.select([c for (i, c) in enumerate(self.labels)
if i not in exclude and c not in exclude])

Expand Down Expand Up @@ -2515,12 +2517,12 @@ def _fill_with_zeros(partials, rows, zero=None):
return np.array([mapping.get(partial, zero) for partial in partials])


def _as_labels(column_label_or_labels):
def _as_labels(column_or_columns):
"""Return a list of labels for a label or labels."""
if not _is_non_string_iterable(column_label_or_labels):
return [column_label_or_labels]
if not _is_non_string_iterable(column_or_columns):
return [column_or_columns]
else:
return column_label_or_labels
return column_or_columns

def _varargs_labels_as_list(label_list):
"""Return a list of labels for a list of labels or singleton list of list
Expand Down
15 changes: 11 additions & 4 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,11 +460,18 @@ def test_pivot_sum(t):

def test_apply(t):
t = t.copy()
assert_array_equal(t.apply(lambda x, y: x * y, ['count', 'points']), np.array([9, 6, 6, 10]))
assert_array_equal(t.apply(lambda x: x * x, 'points'), np.array([1, 4, 4, 100]))
assert_array_equal(t.apply(lambda row: row.item('count') * 2), np.array([18, 6, 6, 2]))
assert_array_equal(t.apply(lambda x, y: x * y, 'count', 'points'),
np.array([9, 6, 6, 10]))
assert_array_equal(t.apply(lambda x: x * x, 'points'),
np.array([1, 4, 4, 100]))
assert_array_equal(t.apply(lambda row: row.item('count') * 2),
np.array([18, 6, 6, 2]))
with(pytest.raises(ValueError)):
t.apply(lambda x, y: x + y, ['count', 'score'])
t.apply(lambda x, y: x + y, 'count', 'score')

# Deprecated behavior
assert_array_equal(t.apply(lambda x, y: x * y, ['count', 'points']),
np.array([9, 6, 6, 10]))


########
Expand Down

0 comments on commit 6217f8d

Please sign in to comment.