Skip to content

Commit

Permalink
change default group function
Browse files Browse the repository at this point in the history
  • Loading branch information
papajohn committed Jan 29, 2016
1 parent 25ce478 commit b3107e5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
19 changes: 11 additions & 8 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def _with_columns(self, columns):

def _add_column_and_format(self, table, label, column):
"""Add a column to table, copying the formatter from self."""
label = self._as_label(label)
table[label] = column
if label in self._formats:
table._formats[label] = self._formats[label]
Expand Down Expand Up @@ -490,12 +491,11 @@ def apply(self, fn, column_label):
array([0, 1, 1, 9])
"""
#return np.array([fn(v) for v in self[column_label]])
if isinstance(column_label, str):
column_label = [column_label]
for c in column_label:
labels = [self._as_label(s) for s in _as_labels(column_label)]
for c in labels:
if not (c in self.labels):
raise ValueError("{} is not an existing column in the table".format(c))
return np.array([fn(*[self.take(i)[col][0] for col in column_label]) for i in range(self.num_rows)])
return np.array([fn(*[self.take(i)[col][0] for col in labels]) for i in range(self.num_rows)])

############
# Mutation #
Expand Down Expand Up @@ -681,6 +681,7 @@ def relabel(self, column_label, new_label):

def copy(self):
"""Return a copy of a Table."""
# TODO(denero) Shallow copy by default with an option for deep copy
table = Table()
for label in self.labels:
self._add_column_and_format(table, label, np.copy(self[label]))
Expand Down Expand Up @@ -756,7 +757,7 @@ def sort(self, column_or_label, descending=False, distinct=False):
row_numbers = np.array(row_numbers[::-1])
return self.take(row_numbers)

def group(self, column_or_label, collect=lambda s: s):
def group(self, column_or_label, collect=len):
"""Group rows by unique values in column_label, aggregating values.
collect -- an optional function applied to the values for each group.
Expand Down Expand Up @@ -789,14 +790,15 @@ def group(self, column_or_label, collect=lambda s: s):
grouped.move_to_start(column_label)
return grouped

def groups(self, labels, collect=lambda s: s):
def groups(self, labels, collect=len):
"""Group rows by multiple columns, aggregating values."""
collect = _zero_on_type_error(collect)
columns = []
labels = [self._as_label(label) for label in labels]
for label in labels:
assert label in self.labels
columns.append(self._get_column(label))
grouped = self.group(list(zip(*columns)))
grouped = self.group(list(zip(*columns)), lambda s: s)
grouped._columns.popitem(last=False) # Discard the column of tuples

# Flatten grouping values and move them to front
Expand Down Expand Up @@ -957,6 +959,8 @@ def _get_column(self, column_or_label):
c = column_or_label
if isinstance(c, collections.Hashable) and c in self.labels:
return self[c]
elif isinstance(c, numbers.Integral):
return self[c]
elif isinstance(c, str):
assert c in self.labels, 'label "{}" not in labels {}'.format(c, self.labels)
else:
Expand Down Expand Up @@ -2007,7 +2011,6 @@ def _as_labels(column_label_or_labels):
else:
return column_label_or_labels


def _assert_same(values):
"""Assert that all values are identical and return the unique value."""
assert len(values) > 0
Expand Down
28 changes: 21 additions & 7 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_basic_rows(t):


def test_select(t):
test = t.select(['points', 'count']).cumsum()
test = t.select(['points', 1]).cumsum()
assert_equal(test, """
points | count
1 | 9
Expand Down Expand Up @@ -220,6 +220,12 @@ def test_where(t):
b | 3 | 2
c | 3 | 2
""")
test = t.where(2, 2)
assert_equal(test, """
letter | count | points
b | 3 | 2
c | 3 | 2
""")


def test_where_conditions(t):
Expand All @@ -241,6 +247,14 @@ def test_sort(t):
c | 3 | 2 | 6
z | 1 | 10 | 10
""")
test = t.sort(3)
assert_equal(test, """
letter | count | points | totals
b | 3 | 2 | 6
c | 3 | 2 | 6
a | 9 | 1 | 9
z | 1 | 10 | 10
""")


def test_sort_args(t):
Expand All @@ -267,10 +281,10 @@ def test_sort_syntax(t):
def test_group(t):
test = t.group('points')
assert_equal(test, """
points | letter | count | totals
1 | ['a'] | [9] | [9]
2 | ['b' 'c'] | [3 3] | [6 6]
10 | ['z'] | [1] | [10]
points | letter len | count len | totals len
1 | 1 | 1 | 1
2 | 2 | 2 | 2
10 | 1 | 1 | 1
""")


Expand All @@ -288,7 +302,7 @@ def test_groups(t):
t = t.copy()
t.append(('e', 12, 1, 12))
t['early'] = t['letter'] < 'd'
test = t.groups(['points', 'early'])
test = t.groups(['points', 'early'], lambda s: s)
assert_equal(test, """
points | early | letter | count | totals
1 | False | ['e'] | [12] | [12]
Expand Down Expand Up @@ -693,7 +707,7 @@ def test_group_by_tuples():
(1, 2, 2, 10) | 3
(1, 2, 2, 10) | 1
""")
table = t.group('tuples')
table = t.group('tuples', lambda s: s)
assert_equal(table, """
tuples | ints
(1, 2, 2, 10) | [3 1]
Expand Down

0 comments on commit b3107e5

Please sign in to comment.