Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ pip install datascience

This project adheres to [Semantic Versioning](http://semver.org/).

### v0.11.0

- Added `join` for multiple columns.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need to update this before merging.

### v0.10.3

- Fix `OrderedDict` bug in `Table.hist`
Expand Down
69 changes: 67 additions & 2 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,13 +1179,13 @@ def join(self, column_label, other, other_label=None):
rows for all values of a column that appear in both tables.

Args:
``column_label`` (``str``): label of column in self that is used to
``column_label``: label of column or array of labels in self that is used to
join rows of ``other``.
``other``: Table object to join with self on matching values of
``column_label``.

Kwargs:
``other_label`` (``str``): default None, assumes ``column_label``.
``other_label``: default None, assumes ``column_label``.
Otherwise in ``other`` used to join rows.

Returns:
Expand Down Expand Up @@ -1238,12 +1238,31 @@ def join(self, column_label, other, other_label=None):
3 | 2 | 4
3 | 2 | 5
1 | 10 | 6
>>> table.join(['a', 'b'], table2, ['a', 'd']) # joining on multiple columns
a | b | c | e
1 | 10 | 6 | 6
9 | 1 | 3 | 3
"""
if self.num_rows == 0 or other.num_rows == 0:
return None
if not other_label:
other_label = column_label

# checking to see if joining on multiple columns
if _is_non_string_iterable(column_label):
# then we are going to be joining multiple labels
return self._multiple_join(column_label, other, other_label)

# original single column join
return self._join(column_label, other, other_label)

def _join(self, column_label, other, other_label=None):
"""joins when COLUMN_LABEL is a string"""
if self.num_rows == 0 or other.num_rows == 0:
return None
if not other_label:
other_label = column_label

self_rows = self.index_by(column_label)
other_rows = other.index_by(other_label)

Expand Down Expand Up @@ -1273,6 +1292,43 @@ def join(self, column_label, other, other_label=None):

return joined.move_to_start(column_label).sort(column_label)

def _multiple_join(self, column_label, other, other_label=None):
    """Join self with ``other`` on several columns at once.

    Args:
        ``column_label``: non-string sequence of labels in self to join
            on.
        ``other``: table to join with self.
        ``other_label``: sequence of labels in ``other``, paired by
            position with ``column_label``. Defaults to ``column_label``.

    Returns:
        A new table holding one row for every pair of rows (one from
        self, one from ``other``) whose join-column values all match,
        with the join columns of ``other`` removed and the join columns
        of self moved to the front. Returns None when no rows match.
    """
    if other_label is None:
        other_label = column_label
    assert len(column_label) == len(other_label), 'unequal number of columns'

    self_rows = self._multi_index(column_label)
    other_rows = other._multi_index(other_label)

    # Gather joined rows from self_rows that have join values in other_rows
    joined_rows = [
        row + o
        for v, rows in self_rows.items()
        if v in other_rows
        for row in rows
        for o in other_rows[v]
    ]
    if not joined_rows:
        return None

    # Build joined table; labels from other are renamed where they would
    # collide with a label already used by self.
    self_labels = list(self.labels)
    other_labels = [self._unused_label(s) for s in other.labels]
    other_labels_map = dict(zip(other.labels, other_labels))
    joined = type(self)(self_labels + other_labels).with_rows(joined_rows)

    # Copy formats from both tables
    joined._formats.update(self._formats)
    for label in other._formats:
        joined._formats[other_labels_map[label]] = other._formats[label]

    # Remove the redundant join columns that came from other. Each distinct
    # label is deleted once, so a label repeated in other_label is safe.
    for duplicate in set(other_label):
        del joined[other_labels_map[duplicate]]

    # Perhaps save the removed column's formatting: the surviving join
    # column is the one from self, so the format must be stored under the
    # label from self that is paired with the removed label from other.
    for self_col, other_col in zip(column_label, other_label):
        if self_col not in self._formats and other_col in other._formats:
            joined._formats[self_col] = other._formats[other_col]

    # Sort by the last join column first so the first join column ends up
    # as the primary key and the leftmost column.
    # NOTE(review): this ordering assumes Table.sort is stable -- confirm.
    for col in reversed(column_label):
        joined = joined.move_to_start(col).sort(col)

    return joined

def stats(self, ops=(min, max, np.median, sum)):
"""Compute statistics for each column and place them in a table."""
Expand Down Expand Up @@ -1878,6 +1934,15 @@ def index_by(self, column_or_label):
index.setdefault(key, []).append(row)
return index

def _multi_index(self, columns_or_labels):
    """Group the rows of self by their values in several columns.

    Returns a dict keyed by a tuple holding one value from each of the
    selected COLUMNS_OR_LABELS (in the order given), with values that are
    the lists of rows of self carrying that combination of values.
    """
    selected = [self._get_column(c) for c in columns_or_labels]
    grouped = {}
    # zip(*selected) yields, per row position, the tuple of that row's
    # values in the selected columns; pair it with the row itself.
    for combo, row in zip(zip(*selected), self.rows):
        if combo not in grouped:
            grouped[combo] = []
        grouped[combo].append(row)
    return grouped

def to_df(self):
    """Return a Pandas DataFrame built from this table's columns."""
    frame = pandas.DataFrame(self._columns)
    return frame
Expand Down
74 changes: 74 additions & 0 deletions tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def table3():
'letter', ['x', 'y', 'z'],
])

@pytest.fixture(scope='function')
def table4():
    """Build the fourth table fixture used by the multi-column join tests.

    Its columns are 'letter', 'count', 'different label' and 'name'.
    """
    labels_and_values = [
        'letter', ['a', 'b', 'c', '8', 'a'],
        'count', [9, 3, 2, 0, 9],
        'different label', [1, 4, 2, 1, 1],
        'name', ['Gamma', 'Delta', 'Epsilon', 'Alpha', 'Beta']
    ]
    return Table().with_columns(labels_and_values)

@pytest.fixture(scope='function')
def numbers_table():
Expand Down Expand Up @@ -1105,6 +1114,71 @@ def test_join_with_two_labels_one_format(table):
z | 1 | 10 | 1 | $10.00
""")

def test_join_one_list_with_one_label(table, table4):
    """Join on a one-element list of labels, with no label given for other."""
    table['totals'] = table['points'] * table['count']
    right = table4.drop('count', 'different label')
    joined = table.join(['letter'], right)
    expected = """
letter | count | points | totals | name
a | 9 | 1 | 9 | Gamma
a | 9 | 1 | 9 | Beta
b | 3 | 2 | 6 | Delta
c | 3 | 2 | 6 | Epsilon
"""
    assert_equal(joined, expected)

def test_join_two_lists_same_label(table, table4):
    """Join on one-element label lists that name the same label on both sides."""
    table['totals'] = table['points'] * table['count']
    right = table4.drop('count', 'different label')
    joined = table.join(['letter'], right, ['letter'])
    expected = """
letter | count | points | totals | name
a | 9 | 1 | 9 | Gamma
a | 9 | 1 | 9 | Beta
b | 3 | 2 | 6 | Delta
c | 3 | 2 | 6 | Epsilon
"""
    assert_equal(joined, expected)

def test_join_two_lists_different_labels(table, table4):
    """Join on one-element label lists whose labels differ between tables."""
    # also checks for multiple matches on one side
    table['totals'] = table['points'] * table['count']
    right = table4.drop('letter', 'count')
    joined = table.join(['points'], right, ['different label'])
    expected = """
points | letter | count | totals | name
1 | a | 9 | 9 | Gamma
1 | a | 9 | 9 | Alpha
1 | a | 9 | 9 | Beta
2 | b | 3 | 6 | Epsilon
2 | c | 3 | 6 | Epsilon
"""
    assert_equal(joined, expected)

def test_join_two_lists_2_columns(table, table4):
    """Join on two columns, one of which is labelled differently in other."""
    table['totals'] = table['points'] * table['count']
    self_labels = ['letter', 'points']
    other_labels = ['letter', 'different label']
    joined = table.join(self_labels, table4, other_labels)
    expected = """
letter | points | count | totals | count_2 | name
a | 1 | 9 | 9 | 9 | Gamma
a | 1 | 9 | 9 | 9 | Beta
c | 2 | 3 | 6 | 2 | Epsilon
"""
    assert_equal(joined, expected)

def test_join_two_lists_3_columns(table, table4):
    """Join on three columns at once."""
    table['totals'] = table['points'] * table['count']
    self_labels = ['letter', 'count', 'points']
    other_labels = ['letter', 'count', 'different label']
    joined = table.join(self_labels, table4, other_labels)
    expected = """
letter | count | points | totals | name
a | 9 | 1 | 9 | Gamma
a | 9 | 1 | 9 | Beta
"""
    assert_equal(joined, expected)

def test_join_conflicting_column_names(table, table4):
    """A non-join column label present in both tables gets a '_2' suffix."""
    table['totals'] = table['points'] * table['count']
    joined = table.join(['letter'], table4)
    expected = """
letter | count | points | totals | count_2 | different label | name
a | 9 | 1 | 9 | 9 | 1 | Gamma
a | 9 | 1 | 9 | 9 | 1 | Beta
b | 3 | 2 | 6 | 3 | 4 | Delta
c | 3 | 2 | 6 | 2 | 2 | Epsilon
"""
    assert_equal(joined, expected)

def test_percentile(numbers_table):
assert_equal(numbers_table.percentile(76), """
count | points
Expand Down