Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 98 additions & 17 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2332,41 +2332,61 @@ def merge(
right_join_ids: typing.Sequence[str],
sort: bool,
suffixes: tuple[str, str] = ("_x", "_y"),
left_index: bool = False,
right_index: bool = False,
) -> Block:
conditions = tuple(
(lid, rid) for lid, rid in zip(left_join_ids, right_join_ids)
)
joined_expr, (get_column_left, get_column_right) = self.expr.relational_join(
other.expr, type=how, conditions=conditions
)
result_columns = []
matching_join_labels = []

left_post_join_ids = tuple(get_column_left[id] for id in left_join_ids)
right_post_join_ids = tuple(get_column_right[id] for id in right_join_ids)

joined_expr, coalesced_ids = coalesce_columns(
joined_expr, left_post_join_ids, right_post_join_ids, how=how, drop=False
)
if left_index or right_index:
# For some reason pandas coalesces two joining columns if one side is an index.
joined_expr, resolved_join_ids = coalesce_columns(
joined_expr, left_post_join_ids, right_post_join_ids
)
else:
joined_expr, resolved_join_ids = resolve_col_join_ids( # type: ignore
joined_expr,
left_post_join_ids,
right_post_join_ids,
how=how,
drop=False,
)

result_columns = []
matching_join_labels = []

# Select left value columns
for col_id in self.value_columns:
if col_id in left_join_ids:
key_part = left_join_ids.index(col_id)
matching_right_id = right_join_ids[key_part]
if (
self.col_id_to_label[col_id]
right_index
or self.col_id_to_label[col_id]
== other.col_id_to_label[matching_right_id]
):
matching_join_labels.append(self.col_id_to_label[col_id])
result_columns.append(coalesced_ids[key_part])
result_columns.append(resolved_join_ids[key_part])
else:
result_columns.append(get_column_left[col_id])
else:
result_columns.append(get_column_left[col_id])

# Select right value columns
for col_id in other.value_columns:
if col_id in right_join_ids:
if other.col_id_to_label[col_id] in matching_join_labels:
pass
elif left_index:
key_part = right_join_ids.index(col_id)
result_columns.append(resolved_join_ids[key_part])
else:
result_columns.append(get_column_right[col_id])
else:
Expand All @@ -2377,11 +2397,22 @@ def merge(
joined_expr = joined_expr.order_by(
[
ordering.OrderingExpression(ex.deref(col_id))
for col_id in coalesced_ids
for col_id in resolved_join_ids
],
)

joined_expr = joined_expr.select_columns(result_columns)
left_idx_id_post_join = [get_column_left[id] for id in self.index_columns]
right_idx_id_post_join = [get_column_right[id] for id in other.index_columns]
index_cols = _resolve_index_col(
left_idx_id_post_join,
right_idx_id_post_join,
resolved_join_ids,
left_index,
right_index,
how,
)

joined_expr = joined_expr.select_columns(result_columns + index_cols)
labels = utils.merge_column_labels(
self.column_labels,
other.column_labels,
Expand All @@ -2400,13 +2431,13 @@ def merge(
or other.index.is_null
or self.session._default_index_type == bigframes.enums.DefaultIndexKind.NULL
):
expr = joined_expr
index_columns = []
return Block(joined_expr, index_columns=[], column_labels=labels)
elif index_cols:
return Block(joined_expr, index_columns=index_cols, column_labels=labels)
else:
expr, offset_index_id = joined_expr.promote_offsets()
index_columns = [offset_index_id]

return Block(expr, index_columns=index_columns, column_labels=labels)
return Block(expr, index_columns=index_columns, column_labels=labels)

def _align_both_axes(
self, other: Block, how: str
Expand Down Expand Up @@ -3115,7 +3146,7 @@ def join_mono_indexed(
left_index = get_column_left[left.index_columns[0]]
right_index = get_column_right[right.index_columns[0]]
# Drop original indices from each side. and used the coalesced combination generated by the join.
combined_expr, coalesced_join_cols = coalesce_columns(
combined_expr, coalesced_join_cols = resolve_col_join_ids(
combined_expr, [left_index], [right_index], how=how
)
if sort:
Expand Down Expand Up @@ -3180,7 +3211,7 @@ def join_multi_indexed(
left_ids_post_join = [get_column_left[id] for id in left_join_ids]
right_ids_post_join = [get_column_right[id] for id in right_join_ids]
# Drop original indices from each side. and used the coalesced combination generated by the join.
combined_expr, coalesced_join_cols = coalesce_columns(
combined_expr, coalesced_join_cols = resolve_col_join_ids(
combined_expr, left_ids_post_join, right_ids_post_join, how=how
)
if sort:
Expand Down Expand Up @@ -3223,13 +3254,17 @@ def resolve_label_id(label: Label) -> str:


# TODO: Rewrite just to return expressions
def coalesce_columns(
def resolve_col_join_ids(
expr: core.ArrayValue,
left_ids: typing.Sequence[str],
right_ids: typing.Sequence[str],
how: str,
drop: bool = True,
) -> Tuple[core.ArrayValue, Sequence[str]]:
"""
Collapses and selects the joining column IDs, with the assumption that
the ids are all belong to value columns.
"""
result_ids = []
for left_id, right_id in zip(left_ids, right_ids):
if how == "left" or how == "inner" or how == "cross":
Expand All @@ -3241,7 +3276,6 @@ def coalesce_columns(
if drop:
expr = expr.drop_columns([left_id])
elif how == "outer":
coalesced_id = guid.generate_guid()
expr, coalesced_id = expr.project_to_id(
ops.coalesce_op.as_expr(left_id, right_id)
)
Expand All @@ -3253,6 +3287,21 @@ def coalesce_columns(
return expr, result_ids


def coalesce_columns(
expr: core.ArrayValue,
left_ids: typing.Sequence[str],
right_ids: typing.Sequence[str],
) -> tuple[core.ArrayValue, list[str]]:
result_ids = []
for left_id, right_id in zip(left_ids, right_ids):
expr, coalesced_id = expr.project_to_id(
ops.coalesce_op.as_expr(left_id, right_id)
)
result_ids.append(coalesced_id)

return expr, result_ids


def _cast_index(block: Block, dtypes: typing.Sequence[bigframes.dtypes.Dtype]):
original_block = block
result_ids = []
Expand Down Expand Up @@ -3468,3 +3517,35 @@ def _pd_index_to_array_value(
rows.append(row)

return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session)


def _resolve_index_col(
left_index_cols: list[str],
right_index_cols: list[str],
resolved_join_ids: list[str],
left_index: bool,
right_index: bool,
how: typing.Literal[
"inner",
"left",
"outer",
"right",
"cross",
],
) -> list[str]:
if left_index and right_index:
if how == "inner" or how == "left":
return left_index_cols
if how == "right":
return right_index_cols
if how == "outer":
return resolved_join_ids
else:
return []
elif left_index and not right_index:
return right_index_cols
elif right_index and not left_index:
return left_index_cols
else:
# Joining with value columns only. Existing indices will be discarded.
return []
Loading