diff --git a/datascience/tables.py b/datascience/tables.py index 3c3001ac7..828e9794b 100644 --- a/datascience/tables.py +++ b/datascience/tables.py @@ -1041,10 +1041,10 @@ def pivot(self, columns, rows, values=None, collect=None, zero=None): the values that match both row and column based on ``collect``. Args: - ``columns`` -- a single column label, (``str``), in table, used to - create new columns, based on its unique values. - ``rows`` -- row labels, as (``str``) or array of strings, used to - create new rows based on it's unique values. + ``columns`` -- a single column label or index, (``str`` or ``int``), + used to create new columns, based on its unique values. + ``rows`` -- row labels or indices, (``str`` or ``int`` or list), + used to create new rows based on it's unique values. ``values`` -- column label in table for use in aggregation. Default None. ``collect`` -- aggregation function, used to group ``values`` @@ -1106,6 +1106,7 @@ def pivot(self, columns, rows, values=None, collect=None, zero=None): raise TypeError('collect requires values to be specified') if values is not None and collect is None: raise TypeError('values requires collect to be specified') + columns = self._as_label(columns) rows = self._as_labels(rows) if values is None: selected = self.select([columns] + rows) diff --git a/tests/test_tables.py b/tests/test_tables.py index 9d36d5dbe..15756619b 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -423,6 +423,16 @@ def test_pivot_counts(t): True | 1 | 2 | 0 """) +def test_pivot_counts_with_indices(t): + t = t.copy() + t.append(('e', 12, 1, 12)) + t['early'] = t['letter'] < 'd' + test = t.pivot(2, 4) + assert_equal(test, """ + early | 1 | 2 | 10 + False | 1 | 0 | 1 + True | 1 | 2 | 0 + """) def test_pivot_values(t): t = t.copy()