diff --git a/README.md b/README.md index f90aa88ac..1743cc6d8 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,10 @@ pip install datascience This project adheres to [Semantic Versioning](http://semver.org/). +### v0.11.0 + +- Added `join` for multiple columns. + ### v0.10.3 - Fix `OrderedDict` bug in `Table.hist` diff --git a/datascience/tables.py b/datascience/tables.py index c71efe273..da386ff6d 100644 --- a/datascience/tables.py +++ b/datascience/tables.py @@ -1179,13 +1179,13 @@ def join(self, column_label, other, other_label=None): rows for all values of a column that appear in both tables. Args: - ``column_label`` (``str``): label of column in self that is used to + ``column_label``: label of column or array of labels in self that is used to join rows of ``other``. ``other``: Table object to join with self on matching values of ``column_label``. Kwargs: - ``other_label`` (``str``): default None, assumes ``column_label``. + ``other_label``: default None, assumes ``column_label``. Otherwise in ``other`` used to join rows. Returns: @@ -1238,12 +1238,31 @@ def join(self, column_label, other, other_label=None): 3 | 2 | 4 3 | 2 | 5 1 | 10 | 6 + >>> table.join(['a', 'b'], table2, ['a', 'd']) # joining on multiple columns + a | b | c | e + 1 | 10 | 6 | 6 + 9 | 1 | 3 | 3 """ if self.num_rows == 0 or other.num_rows == 0: return None if not other_label: other_label = column_label + # checking to see if joining on multiple columns + if _is_non_string_iterable(column_label): + # then we are going to be joining multiple labels + return self._multiple_join(column_label, other, other_label) + + # original single column join + return self._join(column_label, other, other_label) + + def _join(self, column_label, other, other_label=None): + """joins when COLUMN_LABEL is a string""" + if self.num_rows == 0 or other.num_rows == 0: + return None + if not other_label: + other_label = column_label + self_rows = self.index_by(column_label) other_rows = other.index_by(other_label) @@ -1273,6 +1292,43 @@ def join(self, column_label, other, other_label=None): return joined.move_to_start(column_label).sort(column_label) + def _multiple_join(self, column_label, other, other_label=None): + """joins when column_label is a non-string iterable""" + assert len(column_label) == len(other_label), 'unequal number of columns' + + self_rows = self._multi_index(column_label) + other_rows = other._multi_index(other_label) + + # Gather joined rows from self_rows that have join values in other_rows + joined_rows = [] + for v, rows in self_rows.items(): + if v in other_rows: + joined_rows += [row + o for row in rows for o in other_rows[v]] + if not joined_rows: + return None + + # Build joined table + self_labels = list(self.labels) + other_labels = [self._unused_label(s) for s in other.labels] + other_labels_map = dict(zip(other.labels, other_labels)) + joined = type(self)(self_labels + other_labels).with_rows(joined_rows) + + # Copy formats from both tables + joined._formats.update(self._formats) + for label in other._formats: + joined._formats[other_labels_map[label]] = other._formats[label] + + # Remove redundant column, but perhaps save its formatting + for duplicate in other_label: + del joined[other_labels_map[duplicate]] + for duplicate in other_label: + if duplicate not in self._formats and duplicate in other._formats: + joined._formats[duplicate] = other._formats[duplicate] + + for col in column_label[::-1]: + joined = joined.move_to_start(col).sort(col) + + return joined def stats(self, ops=(min, max, np.median, sum)): """Compute statistics for each column and place them in a table.""" @@ -1878,6 +1934,15 @@ def index_by(self, column_or_label): index.setdefault(key, []).append(row) return index + def _multi_index(self, columns_or_labels): + """Returns a dict keyed by a tuple of the values that correspond to + the selected COLUMNS_OR_LABELS, with values corresponding to """ + columns = [self._get_column(col) for col in columns_or_labels] + index = {} + for key, row in zip(zip(*columns), self.rows): + index.setdefault(key, []).append(row) + return index + def to_df(self): """Convert the table to a Pandas DataFrame.""" return pandas.DataFrame(self._columns) diff --git a/tests/test_tables.py b/tests/test_tables.py index 44764ad3f..7bc574bf1 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -40,6 +40,15 @@ def table3(): 'letter', ['x', 'y', 'z'], ]) +@pytest.fixture(scope='function') +def table4(): + """Setup fourth table; three overlapping columns with table.""" + return Table().with_columns([ + 'letter', ['a', 'b', 'c', '8', 'a'], + 'count', [9, 3, 2, 0, 9], + 'different label', [1, 4, 2, 1, 1], + 'name', ['Gamma', 'Delta', 'Epsilon', 'Alpha', 'Beta'] + ]) @pytest.fixture(scope='function') def numbers_table(): @@ -1105,6 +1114,71 @@ def test_join_with_two_labels_one_format(table): z | 1 | 10 | 1 | $10.00 """) +def test_join_one_list_with_one_label(table, table4): + table['totals'] = table['points'] * table['count'] + test = table.join(['letter'], table4.drop('count', 'different label')) + assert_equal(test, """ + letter | count | points | totals | name + a | 9 | 1 | 9 | Gamma + a | 9 | 1 | 9 | Beta + b | 3 | 2 | 6 | Delta + c | 3 | 2 | 6 | Epsilon + """) + +def test_join_two_lists_same_label(table, table4): + table['totals'] = table['points'] * table['count'] + test = table.join(['letter'], table4.drop('count', 'different label'), ['letter']) + assert_equal(test, """ + letter | count | points | totals | name + a | 9 | 1 | 9 | Gamma + a | 9 | 1 | 9 | Beta + b | 3 | 2 | 6 | Delta + c | 3 | 2 | 6 | Epsilon + """) + +def test_join_two_lists_different_labels(table, table4): + # also checks for multiple matches on one side + table['totals'] = table['points'] * table['count'] + test = table.join(['points'], table4.drop('letter', 'count'), ['different label']) + assert_equal(test, """ + points | letter | count | totals | name + 1 | a | 9 | 9 | Gamma + 1 | a | 9 | 9 | Alpha + 1 | a | 9 | 9 | Beta + 2 | b | 3 | 6 | Epsilon + 2 | c | 3 | 6 | Epsilon + """) + +def test_join_two_lists_2_columns(table, table4): + table['totals'] = table['points'] * table['count'] + test = table.join(['letter', 'points'], table4, ['letter', 'different label']) + assert_equal(test, """ + letter | points | count | totals | count_2 | name + a | 1 | 9 | 9 | 9 | Gamma + a | 1 | 9 | 9 | 9 | Beta + c | 2 | 3 | 6 | 2 | Epsilon + """) + +def test_join_two_lists_3_columns(table, table4): + table['totals'] = table['points'] * table['count'] + test = table.join(['letter', 'count', 'points'], table4, ['letter', 'count', 'different label']) + assert_equal(test, """ + letter | count | points | totals | name + a | 9 | 1 | 9 | Gamma + a | 9 | 1 | 9 | Beta + """) + +def test_join_conflicting_column_names(table, table4): + table['totals'] = table['points'] * table['count'] + test = table.join(['letter'], table4) + assert_equal(test, """ + letter | count | points | totals | count_2 | different label | name + a | 9 | 1 | 9 | 9 | 1 | Gamma + a | 9 | 1 | 9 | 9 | 1 | Beta + b | 3 | 2 | 6 | 3 | 4 | Delta + c | 3 | 2 | 6 | 2 | 2 | Epsilon + """) + def test_percentile(numbers_table): assert_equal(numbers_table.percentile(76), """ count | points