Skip to content

Commit

Permalink
Merge pull request #546 from GavinHuttley/develop
Browse files Browse the repository at this point in the history
MAINT: tidy of Table.joined
  • Loading branch information
GavinHuttley committed Feb 24, 2020
2 parents 078f7a8 + 0a3eb2e commit e15b8a9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 33 deletions.
72 changes: 40 additions & 32 deletions src/cogent3/util/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pickle
import warnings

from collections import defaultdict
from collections.abc import Callable
from xml.sax.saxutils import escape

Expand Down Expand Up @@ -1127,34 +1128,38 @@ def joined(
if self.title == other_title:
raise RuntimeError("Cannot join if a table.Title's are equal")

columns_self = [columns_self, [columns_self]][type(columns_self) == str]
columns_other = [columns_other, [columns_other]][type(columns_other) == str]
columns_self = [columns_self] if isinstance(columns_self, str) else columns_self
columns_other = (
[columns_other] if isinstance(columns_other, str) else columns_other
)
if not inner_join:
assert columns_self is None and columns_other is None, (
"Cannot " "specify column indices for an outer join"
)
assert (
columns_self is None and columns_other is None
), "Cannot specify column indices for an outer join"
columns_self = []
columns_other = []

if columns_self is None and columns_other is None:
# we do the natural inner join
columns_self = []
columns_other = []
for col_head in self.header:
shared = set(self.header) & set(other_table.header)
for col_head in shared:
if col_head in other_table.header:
columns_self.append(self.header.index(col_head))
columns_other.append(other_table.header.index(col_head))
elif columns_self is None or columns_other is None:
# the same column labels will be used for both tables
columns_self = columns_self or columns_other
columns_other = columns_self or columns_other
elif len(columns_self) != len(columns_other):

if len(columns_self) != len(columns_other):
raise RuntimeError(
"Error during table join: key columns have " "different dimensions!"
"Error during table join: key columns have different dimensions!"
)

# create new 2d list for the output
joined_table = []
joined_data = []

# resolve column indices from header, if necessary
columns_self_indices = []
Expand All @@ -1170,41 +1175,44 @@ def joined(
columns_other_indices.append(col)
else:
columns_other_indices.append(other_table.header.index(col))

# create a mask of which columns of the other_table will end up in the
# output
output_mask_other = []
for col in range(0, len(other_table.header)):
if not (col in columns_other_indices):
output_mask_other.append(col)
output_mask_other = [
i
for i in range(0, len(other_table.header))
if i not in columns_other_indices
]
new_header = self.header + [
other_title + "_" + other_table.header[c] for c in output_mask_other
]

# use a dictionary for the key lookup
# key dictionary for other_table.
# key is a tuple made from specified columns; data is the row index
# for lookup...
key_lookup = {}
row_index = 0
for row in other_table:
key_lookup = defaultdict(list)
for row_index, row in enumerate(other_table):
# insert new entry for each row
key = tuple([row[col] for col in columns_other_indices])
if key in key_lookup:
key_lookup[key].append(row_index)
else:
key_lookup[key] = [row_index]
row_index = row_index + 1
key_lookup[key].append(row_index)

for this_row in self:
other_row_indices = []
self_row_indices = []
for row_index, this_row in enumerate(self):
# assemble key for query of other_table
key = tuple([this_row[col] for col in columns_self_indices])
if key in key_lookup:
for output_row_index in key_lookup[key]:
other_row = [
other_table[output_row_index, c] for c in output_mask_other
]
joined_table.append(list(this_row) + other_row)
if key not in key_lookup:
continue

new_header = self.header + [
other_title + "_" + other_table.header[c] for c in output_mask_other
]
return Table(header=new_header, rows=joined_table, **kwargs)
self_row_indices.extend([row_index] * len(key_lookup[key]))
other_row_indices.extend(key_lookup[key])

self_data = self.array[self_row_indices]
other_data = other_table.array[numpy.ix_(other_row_indices, output_mask_other)]
joined_data = numpy.concatenate([self_data, other_data], axis=1)

return Table(header=new_header, rows=joined_data, **kwargs)

def summed(self, indices=None, col_sum=True, strict=True, **kwargs):
"""returns the sum of numerical values for column(s)/row(s)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_util/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def test_joined(self):
t2 = Table(header=self.t2_header, rows=self.t2_rows)
t3 = Table(header=self.t3_header, rows=self.t3_rows)

# inner join with defaults
self.assertEqual(t2.joined(t3).shape[0], 0)

# inner join test
self.assertEqual(
t2.joined(t3, columns_self="foo", columns_other="foo").shape[0], 4
Expand All @@ -189,7 +192,6 @@ def test_joined(self):
self.assertEqual(
t2.joined(t3, columns_self="foo", columns_other="foo").shape[1], 5
)

# non-inner join test (cartesian product of rows)
self.assertEqual(
t2.joined(t3, inner_join=False).shape[0], t2.shape[0] * t3.shape[0]
Expand Down

0 comments on commit e15b8a9

Please sign in to comment.