class _Taker:
    """Row selector that backs ``Table.take``.

    An instance wraps one table and can be either called
    (``t.take([0, 2])``) or indexed (``t.take[0:2]``); the call form simply
    forwards to ``__getitem__``.
    """

    def __init__(self, table):
        """Remember the table whose rows will be selected.

        Args:
            ``table``: The object rows are taken from. Only its ``_columns``
                mapping and its ``_with_columns`` method are used here.
        """
        self._table = table

    def __call__(self, row_numbers_or_slice):
        """Return a Table of a sequence of rows taken by number.

        Args:
            ``row_numbers_or_slice`` (slice or integer or list of integers):
            The list of row numbers or a slice to be selected.

        Returns:
            A ``Table`` containing only the selected rows.

        >>> grade = ['A+', 'A', 'A-', 'B+', 'B', 'B-']
        >>> gpa = [4, 4, 3.7, 3.3, 3, 2.7]
        >>> t = Table([grade, gpa], ['letter grade', 'gpa'])

        >>> t
        letter grade | gpa
        A+           | 4
        A            | 4
        A-           | 3.7
        B+           | 3.3
        B            | 3
        B-           | 2.7
        >>> t.take(0)
        letter grade | gpa
        A+           | 4
        >>> t.take(5)
        letter grade | gpa
        B-           | 2.7
        >>> t.take(-1)
        letter grade | gpa
        B-           | 2.7
        >>> t.take([2, 1, 0])
        letter grade | gpa
        A-           | 3.7
        A            | 4
        A+           | 4
        >>> print(t.take([1, 5]))
        letter grade | gpa
        A            | 4
        B-           | 2.7
        >>> t.take(range(3))
        letter grade | gpa
        A+           | 4
        A            | 4
        A-           | 3.7

        Note that ``take`` also supports NumPy-like indexing and slicing:

        >>> t.take[:3]
        letter grade | gpa
        A+           | 4
        A            | 4
        A-           | 3.7

        >>> t.take[2, 1, 0]
        letter grade | gpa
        A-           | 3.7
        A            | 4
        A+           | 4

        """
        return self[row_numbers_or_slice]

    def __getitem__(self, i):
        """Select rows by slice, by a single row number, or by an iterable.

        Args:
            ``i`` (slice or integer or iterable of integers): Which rows to
            keep. A tuple such as the one produced by ``take[2, 1, 0]``
            counts as an iterable of row numbers.

        Returns:
            Whatever the wrapped table's ``_with_columns`` builds from the
            selected entries of every column.

        Raises:
            ``IndexError``: Raised by ``np.take`` when a row number is out of
            range.
        """
        source_columns = self._table._columns.values()
        if isinstance(i, slice):
            columns = [column[i] for column in source_columns]
        else:
            # ``collections.Iterable`` was removed in Python 3.10; the ABC
            # lives in ``collections.abc``.
            if not isinstance(i, collections.abc.Iterable):
                # Wrap a lone row number so ``np.take`` yields a one-row
                # column rather than a scalar, making ``take[i]`` behave
                # exactly like ``take[[i]]``.
                i = [i]
            columns = [np.take(column, i, axis=0) for column in source_columns]
        return self._table._with_columns(columns)
columns.""" @@ -82,6 +167,12 @@ def __init__(self, columns=None, labels=None, for column, label in zip(columns, labels): self[label] = column + self.take = _Taker(self) + + # This, along with a snippet below, is necessary for Sphinx to + # correctly load the `take` docstring + take = _Taker(None) + @classmethod def empty(cls, column_labels=None): """Create an empty table. Column labels are optional @@ -506,55 +597,6 @@ def drop(self, column_label_or_labels): """Return a Table with only columns other than selected label or labels.""" exclude = _as_labels(column_label_or_labels) return self.select([c for c in self.column_labels if c not in exclude]) - - def take(self, row_numbers): - """Return a Table of a sequence of rows taken by number. - - Args: - ``row_numbers`` (integer or list of integers): The list of row numbers to - be selected. - - Returns: - A ``Table`` containing only the selected rows. - - >>> grade = ['A+', 'A', 'A-', 'B+', 'B', 'B-'] - >>> gpa = [4, 4, 3.7, 3.3, 3, 2.7] - >>> t = Table([grade, gpa], ['letter grade', 'gpa']) - >>> t - letter grade | gpa - A+ | 4 - A | 4 - A- | 3.7 - B+ | 3.3 - B | 3 - B- | 2.7 - >>> t.take(0) - letter grade | gpa - A+ | 4 - >>> t.take(5) - letter grade | gpa - B- | 2.7 - >>> t.take(-1) - letter grade | gpa - B- | 2.7 - >>> t.take([2,1,0]) - letter grade | gpa - A- | 3.7 - A | 4 - A+ | 4 - >>> t.take([1,5]) - letter grade | gpa - A | 4 - B- | 2.7 - >>> t.take(range(3)) - letter grade | gpa - A+ | 4 - A | 4 - A- | 3.7 - """ - columns = [np.take(column, row_numbers, axis=0) for column in self.columns] - return self._with_columns(columns) - def where(self, column_or_label, value=None): """Return a Table of rows for which the column is value or a non-zero value.""" column = self._get_column(column_or_label) @@ -1601,6 +1643,10 @@ def __repr__(self): return '{0}({1})'.format(type(self).__name__, repr(self._table)) +# For Sphinx: grab the docstring from `Taker.__call__` +Table.take.__doc__ = _Taker.__call__.__doc__ + + class 
Q: """Query manager for Tables.""" array = None diff --git a/tests/test_tables.py b/tests/test_tables.py index 650bfb617..c94385992 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -118,6 +118,32 @@ def test_take(t): """) +def test_take_slice(t): + test = t.take[1:3] + assert_equal(test, """ + letter | count | points + b | 3 | 2 + c | 3 | 2 + """) + + +def test_take_slice_single(t): + test = t.take[1] + assert_equal(test, """ + letter | count | points + b | 3 | 2 + """) + + +def test_take_iterable(t): + test = t.take[0, 2] + assert_equal(test, """ + letter | count | points + a | 9 | 1 + c | 3 | 2 + """) + + def test_stats(t): test = t.stats() assert_equal(test, """