From 8cbf1e67881cc1b9bf44d19d7279489d4b27b3cd Mon Sep 17 00:00:00 2001 From: John DeNero Date: Sat, 11 Feb 2017 06:56:27 -0800 Subject: [PATCH 1/2] pivot accepts column index as first arg --- datascience/tables.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datascience/tables.py b/datascience/tables.py index 3c3001ac7..828e9794b 100644 --- a/datascience/tables.py +++ b/datascience/tables.py @@ -1041,10 +1041,10 @@ def pivot(self, columns, rows, values=None, collect=None, zero=None): the values that match both row and column based on ``collect``. Args: - ``columns`` -- a single column label, (``str``), in table, used to - create new columns, based on its unique values. - ``rows`` -- row labels, as (``str``) or array of strings, used to - create new rows based on it's unique values. + ``columns`` -- a single column label or index, (``str`` or ``int``), + used to create new columns, based on its unique values. + ``rows`` -- row labels or indices, (``str`` or ``int`` or list), + used to create new rows based on it's unique values. ``values`` -- column label in table for use in aggregation. Default None. ``collect`` -- aggregation function, used to group ``values`` @@ -1106,6 +1106,7 @@ def pivot(self, columns, rows, values=None, collect=None, zero=None): raise TypeError('collect requires values to be specified') if values is not None and collect is None: raise TypeError('values requires collect to be specified') + columns = self._as_label(columns) rows = self._as_labels(rows) if values is None: selected = self.select([columns] + rows) From 064a1662dc1fa6c78b89eed9647c5bbd2bfc7f77 Mon Sep 17 00:00:00 2001 From: John DeNero Date: Sat, 11 Feb 2017 06:58:23 -0800 Subject: [PATCH 2/2] add pivot by index test --- tests/test_tables.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_tables.py b/tests/test_tables.py index 9d36d5dbe..15756619b 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -423,6 +423,16 @@ def test_pivot_counts(t): True | 1 | 2 | 0 """) +def test_pivot_counts_with_indices(t): + t = t.copy() + t.append(('e', 12, 1, 12)) + t['early'] = t['letter'] < 'd' + test = t.pivot(2, 4) + assert_equal(test, """ + early | 1 | 2 | 10 + False | 1 | 0 | 1 + True | 1 | 2 | 0 + """) def test_pivot_values(t): t = t.copy()