diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index ed10ef78..c205f92a 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -1,5 +1,5 @@ import asyncio -from collections.abc import Collection, Sequence +from collections.abc import Callable, Collection, Sequence from typing import TYPE_CHECKING, Any from orcapod.channels import ReadableChannel, WritableChannel @@ -224,14 +224,19 @@ def _compute_system_tag_suffixes( n_char = self.orcapod_config.system_tag_hash_n_char hex_strings = [h.to_hex() for h in input_pipeline_hashes] - # Canonical order: sorted by full hex (same as order_input_streams) - sorted_hexes = sorted(hex_strings) + # Canonical order: sorted by full hex (same as order_input_streams). + # Use the original index as a tiebreaker so that inputs with + # identical pipeline hashes still receive distinct positions + # (matching Python's stable sort used by order_input_streams). + ranked = sorted(range(len(hex_strings)), key=lambda i: hex_strings[i]) + canon_position = [0] * len(hex_strings) + for canon_idx, orig_idx in enumerate(ranked): + canon_position[orig_idx] = canon_idx suffixes: list[str] = [] - for orig_idx, hex_str in enumerate(hex_strings): - canon_idx = sorted_hexes.index(hex_str) + for orig_idx in range(len(hex_strings)): truncated = input_pipeline_hashes[orig_idx].to_hex(n_char) - suffixes.append(f"{truncated}:{canon_idx}") + suffixes.append(f"{truncated}:{canon_position[orig_idx]}") return suffixes async def async_execute( @@ -241,17 +246,19 @@ async def async_execute( *, input_pipeline_hashes: Sequence[ContentHash] | None = None, ) -> None: - """Async join with streaming symmetric hash join for two inputs. + """Async streaming join with pairwise iterative semantics. Single input: streams through directly without any buffering. - Two inputs: symmetric hash join — each arriving row is - immediately probed against the opposite side's buffer, emitting - matches as soon as found. System-tag columns are correctly - renamed using the ``input_pipeline_hashes``. + Two inputs: binary symmetric hash join — each arriving row is + probed against the opposite side's buffer, emitting matches as + soon as found. - Three or more inputs: collects all inputs concurrently, then - delegates to ``static_process`` for the Polars N-way join. + Three or more inputs: staggered pairwise binary joins in + canonical order — ``join(join(x, y), z)`` — matching + ``static_process``'s iterative accumulation. Each binary join + uses the per-pair intersection of tag keys, so partially + overlapping tag schemas are handled correctly. Args: inputs: Readable channels, one per upstream. 
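A minimal standalone sketch (not part of the patch) of the canonical-position logic used in `_compute_system_tag_suffixes` above. The hex strings are made up; the repeated "bbbb" stands in for two inputs that share a pipeline hash and shows how the stable sort still hands out distinct positions:

    # Sketch only: rank original indices by hex; Python's sort is stable, so
    # equal hashes keep their original relative order yet get distinct ranks.
    hex_strings = ["bbbb", "aaaa", "bbbb"]  # made-up hashes; inputs 0 and 2 collide
    ranked = sorted(range(len(hex_strings)), key=lambda i: hex_strings[i])  # [1, 0, 2]

    canon_position = [0] * len(hex_strings)
    for canon_idx, orig_idx in enumerate(ranked):
        canon_position[orig_idx] = canon_idx

    print(canon_position)  # [1, 0, 2]; the old sorted_hexes.index(...) gave [1, 0, 1]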
@@ -266,53 +273,162 @@ async def async_execute( await output.send((tag, packet)) return - # TODO: carefully revisit the logic behind system tag handling - if len(inputs) == 2: - suffixes = ( - self._compute_system_tag_suffixes(input_pipeline_hashes) - if input_pipeline_hashes is not None - else ["0", "1"] + n = len(inputs) + if input_pipeline_hashes is not None and len(input_pipeline_hashes) != n: + raise ValueError( + f"input_pipeline_hashes length ({len(input_pipeline_hashes)}) " + f"must match inputs length ({n})" ) - await self._symmetric_hash_join(inputs[0], inputs[1], output, suffixes) - return + suffixes = ( + self._compute_system_tag_suffixes(input_pipeline_hashes) + if input_pipeline_hashes is not None + else [str(i) for i in range(n)] + ) + await self._streaming_join(inputs, output, suffixes) + finally: + await output.close() - # N > 2: concurrent collection + static_process - all_rows = await asyncio.gather(*(ch.collect() for ch in inputs)) + async def _streaming_join( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + suffixes: list[str], + ) -> None: + """Dispatch between binary join (N=2) and staggered chain (N>=3). - # Guard against empty inputs — join with an empty side is empty - if any(len(rows) == 0 for rows in all_rows): - return + Args: + inputs: Readable channels, one per upstream. + output: Output channel for matched rows. + suffixes: Per-input system-tag suffixes (positional). + """ + n = len(inputs) + block_sep = constants.BLOCK_SEPARATOR - streams = [self._materialize_to_stream(rows) for rows in all_rows] - result = self.static_process(*streams) - for tag, packet in result.iter_packets(): - await output.send((tag, packet)) - finally: - await output.close() + if n == 2: + + def merge_fn( + lt: TagProtocol, + lp: PacketProtocol, + rt: TagProtocol, + rp: PacketProtocol, + ) -> tuple[TagProtocol, PacketProtocol]: + return self._merge_pair_rename(lt, lp, rt, rp, suffixes, block_sep) + + await self._binary_streaming_join(inputs[0], inputs[1], output, merge_fn) + return - async def _symmetric_hash_join( + # N >= 3: staggered pairwise joins matching static_process + await self._staggered_join(inputs, output, suffixes) + + async def _staggered_join( self, - left_ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], - right_ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], suffixes: list[str], ) -> None: - """Symmetric hash join for two inputs. + """Staggered pairwise binary joins: ``join(join(x, y), z)``. + + Matches ``static_process``'s iterative pairwise join semantics. + Each input's system tags are pre-renamed, then binary joins are + chained in canonical order. Per-pair join keys are computed + naturally by each binary join (intersection of its two inputs' + tag keys), so partially overlapping tag schemas produce the + same results as the sync path. + + Intermediate results flow through channels, so downstream joins + can start work as soon as earlier joins emit matches — the + pipeline is fully streaming end-to-end. + + Args: + inputs: Readable channels, one per upstream. + output: Output channel for matched rows. + suffixes: Per-input system-tag suffixes (positional). 
+ """ + from orcapod.channels import Channel + + n = len(inputs) + block_sep = constants.BLOCK_SEPARATOR + sys_prefix = constants.SYSTEM_TAG_PREFIX + + # Canonical order: sorted by canonical position encoded in suffixes. + # Suffixes are "hash:position" when pipeline hashes are provided, + # or plain "0", "1", ... otherwise. + def _canon_pos(i: int) -> int: + parts = suffixes[i].rsplit(":", 1) + return int(parts[1]) if len(parts) == 2 else int(parts[0]) + + canon_order = sorted(range(n), key=_canon_pos) + + async with asyncio.TaskGroup() as tg: + # Pre-rename system tags for each input so binary joins + # can pass them through without modification + renamed_readers: list[ + ReadableChannel[tuple[TagProtocol, PacketProtocol]] + ] = [] + for orig_idx in canon_order: + ch: Channel[tuple[TagProtocol, PacketProtocol]] = Channel( + buffer_size=64 + ) + tg.create_task( + self._rename_sys_tags( + inputs[orig_idx], + ch.writer, + suffixes[orig_idx], + block_sep, + sys_prefix, + ) + ) + renamed_readers.append(ch.reader) + + # Chain: renamed[0] ⋈ renamed[1] → intermediate ⋈ renamed[2] → … → output + current_reader = renamed_readers[0] + for i in range(1, len(renamed_readers)): + is_last = i == len(renamed_readers) - 1 + if is_last: + target_writer = output + else: + intermediate: Channel[ + tuple[TagProtocol, PacketProtocol] + ] = Channel(buffer_size=64) + target_writer = intermediate.writer + + tg.create_task( + self._binary_streaming_join( + current_reader, + renamed_readers[i], + target_writer, + self._merge_pair_passthrough, + ) + ) + + if not is_last: + current_reader = intermediate.reader + + async def _binary_streaming_join( + self, + left: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + right: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + merge_fn: Callable[ + [TagProtocol, PacketProtocol, TagProtocol, PacketProtocol], + tuple[TagProtocol, PacketProtocol], + ], + ) -> None: + """Binary symmetric hash join. Both sides are read concurrently via a merged bounded queue. - Each arriving row is added to its side's index and immediately - probed against the opposite side. Matched rows are emitted to - ``output`` as soon as found, so downstream consumers can begin - work before either input is fully consumed. + Each arriving row is indexed and immediately probed against the + opposite side. Matched rows are emitted via ``merge_fn`` as + soon as found, so downstream can begin work before either input + is fully consumed. Args: - left_ch: Left input channel. - right_ch: Right input channel. + left: Left input channel. + right: Right input channel. output: Output channel for matched rows. - suffixes: Per-input system-tag suffixes (positional), - computed from pipeline hashes and canonical ordering. + merge_fn: Callable(left_tag, left_pkt, right_tag, right_pkt) + that produces the merged (Tag, Packet) pair. """ - # Bounded queue preserves backpressure — producers block when full. 
_SENTINEL = object() queue: asyncio.Queue = asyncio.Queue(maxsize=64) @@ -324,106 +440,110 @@ async def _drain( await queue.put((side, item)) await queue.put((side, _SENTINEL)) - block_sep = constants.BLOCK_SEPARATOR + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(_drain(left, 0)) + tg.create_task(_drain(right, 1)) - async with asyncio.TaskGroup() as tg: - tg.create_task(_drain(left_ch, 0)) - tg.create_task(_drain(right_ch, 1)) + buffers: list[list[tuple[TagProtocol, PacketProtocol]]] = [ + [], + [], + ] + indexes: list[dict[tuple, list[int]]] = [{}, {}] - # buffers[i] holds all rows seen so far from input i - buffers: list[list[tuple[TagProtocol, PacketProtocol]]] = [[], []] - # indexes[i] maps shared-key tuple → list of indices into buffers[i] - indexes: list[dict[tuple, list[int]]] = [{}, {}] + shared_keys: tuple[str, ...] | None = None + needs_reindex = False + closed_count = 0 - shared_keys: tuple[str, ...] | None = None - needs_reindex = False - closed_count = 0 + while closed_count < 2: + side, item = await queue.get() - while closed_count < 2: - side, item = await queue.get() + if item is _SENTINEL: + closed_count += 1 + continue - if item is _SENTINEL: - closed_count += 1 - continue + tag, pkt = item + other = 1 - side + + # Determine shared tag keys once we have rows from both sides + if shared_keys is None: + if not buffers[other]: + buffers[side].append((tag, pkt)) + continue + this_keys = set(tag.keys()) + other_keys = set(buffers[other][0][0].keys()) + shared_keys = tuple(sorted(this_keys & other_keys)) + needs_reindex = True + + # One-time re-index of rows buffered before shared_keys + if needs_reindex: + needs_reindex = False + for buf_side in (0, 1): + for j, (bt, _bp) in enumerate(buffers[buf_side]): + btd = bt.as_dict() + k = ( + tuple(btd[sk] for sk in shared_keys) + if shared_keys + else (0,) + ) + indexes[buf_side].setdefault(k, []).append(j) + + # Index the new row + td = tag.as_dict() + key = ( + tuple(td[sk] for sk in shared_keys) + if shared_keys + else (0,) + ) + row_idx = len(buffers[side]) + buffers[side].append((tag, pkt)) + indexes[side].setdefault(key, []).append(row_idx) + + # Probe the opposite side for matches + for mi in indexes[other].get(key, []): + ot, op = buffers[other][mi] + if side == 0: + await output.send(merge_fn(tag, pkt, ot, op)) + else: + await output.send(merge_fn(ot, op, tag, pkt)) + finally: + await output.close() - tag, pkt = item - other = 1 - side + @staticmethod + async def _rename_sys_tags( + ch_in: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + ch_out: WritableChannel[tuple[TagProtocol, PacketProtocol]], + suffix: str, + block_sep: str, + sys_prefix: str, + ) -> None: + """Read rows and rename system-tag keys by appending the per-input suffix. - # Determine shared tag keys once we have rows from both sides - if shared_keys is None: - if not buffers[other]: - # Other side empty — just buffer this row for later - buffers[side].append((tag, pkt)) - continue + Used as a pre-processing step in ``_staggered_join`` so that + downstream binary joins can pass system tags through without + modification. 
+ """ + from orcapod.core.datagrams import Tag - # We have data from both sides; compute shared keys - this_keys = set(tag.keys()) - other_keys = set(buffers[other][0][0].keys()) - shared_keys = tuple(sorted(this_keys & other_keys)) - needs_reindex = True - - # One-time re-index of all rows buffered before shared_keys - if needs_reindex: - needs_reindex = False - for buf_side in (0, 1): - for j, (bt, _bp) in enumerate(buffers[buf_side]): - btd = bt.as_dict() - k = ( - tuple(btd[sk] for sk in shared_keys) - if shared_keys - else (0,) - ) - indexes[buf_side].setdefault(k, []).append(j) - - # Emit matches for all already-buffered rows across sides - for li, (lt, lp) in enumerate(buffers[0]): - ltd = lt.as_dict() - lk = ( - tuple(ltd[sk] for sk in shared_keys) - if shared_keys - else (0,) - ) - for ri in indexes[1].get(lk, []): - rt, rp = buffers[1][ri] - await output.send( - self._merge_row_pair( - lt, lp, rt, rp, suffixes, block_sep - ) - ) - - # Index the new row - td = tag.as_dict() - key = tuple(td[sk] for sk in shared_keys) if shared_keys else (0,) - row_idx = len(buffers[side]) - buffers[side].append((tag, pkt)) - indexes[side].setdefault(key, []).append(row_idx) - - # Probe the opposite buffer for matches - matching_indices = indexes[other].get(key, []) - for mi in matching_indices: - other_tag, other_pkt = buffers[other][mi] - if side == 0: - merged = self._merge_row_pair( - tag, - pkt, - other_tag, - other_pkt, - suffixes, - block_sep, - ) - else: - merged = self._merge_row_pair( - other_tag, - other_pkt, - tag, - pkt, - suffixes, - block_sep, + try: + async for tag, pkt in ch_in: + sys_tags = tag.system_tags() + if sys_tags: + renamed: dict = {} + for k, v in sys_tags.items(): + new_key = ( + f"{k}{block_sep}{suffix}" + if k.startswith(sys_prefix) + else k ) - await output.send(merged) + renamed[new_key] = v + tag = Tag(tag.as_dict(), system_tags=renamed) + await ch_out.send((tag, pkt)) + finally: + await ch_out.close() @staticmethod - def _merge_row_pair( + def _merge_pair_rename( left_tag: TagProtocol, left_pkt: PacketProtocol, right_tag: TagProtocol, @@ -431,48 +551,150 @@ def _merge_row_pair( suffixes: list[str], block_sep: str, ) -> tuple[TagProtocol, PacketProtocol]: - """Merge a matched pair of rows into one joined (Tag, Packet). + """Merge a matched pair, renaming system tags with per-side suffixes. - System-tag keys are renamed by appending - ``{block_sep}{suffix}`` to match the canonical name-extending - scheme used by ``static_process``. System-tag values sharing - the same provenance path are sorted for commutativity. + Used for direct 2-input joins where system tags are renamed + during the merge (not pre-renamed). 
""" from orcapod.core.datagrams import Packet, Tag sys_prefix = constants.SYSTEM_TAG_PREFIX - # Merge tag dicts (shared keys come from left) + # Merge tag dicts — shared keys come from left merged_tag_d: dict = {} - merged_tag_d.update(left_tag.as_dict()) + for k, v in left_tag.as_dict().items(): + merged_tag_d[k] = v for k, v in right_tag.as_dict().items(): if k not in merged_tag_d: merged_tag_d[k] = v - # Rename and merge system tags with canonical suffixes + # Rename and merge system tags merged_sys: dict = {} - for k, v in left_tag.system_tags().items(): - new_key = f"{k}{block_sep}{suffixes[0]}" if k.startswith(sys_prefix) else k - merged_sys[new_key] = v - for k, v in right_tag.system_tags().items(): - new_key = f"{k}{block_sep}{suffixes[1]}" if k.startswith(sys_prefix) else k - merged_sys[new_key] = v + for i, tag in enumerate((left_tag, right_tag)): + for k, v in tag.system_tags().items(): + new_key = ( + f"{k}{block_sep}{suffixes[i]}" + if k.startswith(sys_prefix) + else k + ) + merged_sys[new_key] = v + merged_sys = Join._sort_merged_system_tags(merged_sys) merged_tag = Tag(merged_tag_d, system_tags=merged_sys) # Merge packet dicts (non-overlapping by Join's validation) merged_pkt_d: dict = {} + merged_si: dict = {} merged_pkt_d.update(left_pkt.as_dict()) merged_pkt_d.update(right_pkt.as_dict()) + merged_si.update(left_pkt.source_info()) + merged_si.update(right_pkt.source_info()) + + merged_pkt = Packet(merged_pkt_d, source_info=merged_si) + return merged_tag, merged_pkt + @staticmethod + def _merge_pair_passthrough( + left_tag: TagProtocol, + left_pkt: PacketProtocol, + right_tag: TagProtocol, + right_pkt: PacketProtocol, + ) -> tuple[TagProtocol, PacketProtocol]: + """Merge a matched pair, passing system tags through without renaming. + + Used in the staggered chain where system tags have already been + pre-renamed by ``_rename_sys_tags``. + """ + from orcapod.core.datagrams import Packet, Tag + + # Merge tag dicts — shared keys come from left + merged_tag_d: dict = {} + for k, v in left_tag.as_dict().items(): + merged_tag_d[k] = v + for k, v in right_tag.as_dict().items(): + if k not in merged_tag_d: + merged_tag_d[k] = v + + # Combine system tags (already renamed) + merged_sys: dict = {} + merged_sys.update(left_tag.system_tags()) + merged_sys.update(right_tag.system_tags()) + + # Sort within same-provenance-path groups for commutativity + merged_sys = Join._sort_merged_system_tags(merged_sys) + merged_tag = Tag(merged_tag_d, system_tags=merged_sys) + + # Merge packet dicts (non-overlapping by Join's validation) + merged_pkt_d: dict = {} merged_si: dict = {} + merged_pkt_d.update(left_pkt.as_dict()) + merged_pkt_d.update(right_pkt.as_dict()) merged_si.update(left_pkt.source_info()) merged_si.update(right_pkt.source_info()) merged_pkt = Packet(merged_pkt_d, source_info=merged_si) - return merged_tag, merged_pkt + @staticmethod + def _sort_merged_system_tags(merged_sys: dict) -> dict: + """Sort system tag values within same-provenance-path groups. + + When two joined inputs share a pipeline_hash, their system tag + columns share a provenance path but occupy different canonical + positions. Sorting the paired (source_id, record_id) values + across positions ensures commutativity — mirroring what + ``sort_system_tag_values`` does on Arrow tables in + ``static_process``. 
+ """ + sys_prefix = constants.SYSTEM_TAG_PREFIX + block_sep = constants.BLOCK_SEPARATOR + field_sep = constants.FIELD_SEPARATOR + + # Parse keys → groups[provenance_path][position][field_type] = key + groups: dict[str, dict[str, dict[str, str]]] = {} + for key in merged_sys: + if not key.startswith(sys_prefix): + continue + base, sep, position = key.rpartition(field_sep) + if not sep or not position.isdigit(): + continue + after_prefix = base[len(sys_prefix) :] + field_type, bsep, prov_path = after_prefix.partition(block_sep) + if not bsep: + continue + groups.setdefault(prov_path, {}).setdefault(position, {})[ + field_type + ] = key + + sid_field = constants.SYSTEM_TAG_SOURCE_ID_PREFIX[len(sys_prefix) :] + rid_field = constants.SYSTEM_TAG_RECORD_ID_PREFIX[len(sys_prefix) :] + + for _prov_path, positions in groups.items(): + if len(positions) <= 1: + continue + + sorted_pos_keys = sorted(positions.keys(), key=int) + + # Collect (sort_key, {field_type: value}) per position + entries: list[tuple[tuple, dict[str, object]]] = [] + for pos in sorted_pos_keys: + fmap = positions[pos] + sid_val = merged_sys.get(fmap.get(sid_field, "")) + rid_val = merged_sys.get(fmap.get(rid_field, "")) + vals = {ft: merged_sys[k] for ft, k in fmap.items()} + entries.append(((sid_val or "", rid_val or ""), vals)) + + entries.sort(key=lambda e: e[0]) + + # Write sorted values back to the original position keys + for pos, (_, sorted_vals) in zip(sorted_pos_keys, entries): + fmap = positions[pos] + for field_type, key in fmap.items(): + if field_type in sorted_vals: + merged_sys[key] = sorted_vals[field_type] + + return merged_sys + def identity_structure(self) -> Any: return self.__class__.__name__ diff --git a/tests/test_channels/test_native_async_operators.py b/tests/test_channels/test_native_async_operators.py index de38bf71..5339c941 100644 --- a/tests/test_channels/test_native_async_operators.py +++ b/tests/test_channels/test_native_async_operators.py @@ -14,7 +14,7 @@ - MapPackets streaming: per-row packet column renaming - Batch streaming: accumulate-and-emit full batches, partial batch handling - SemiJoin build-probe: collect right, stream left through hash lookup -- Join: single-input passthrough, concurrent binary/N-ary collection +- Join: single-input passthrough, streaming N-way MJoin - Sync / async equivalence for every operator - Empty input handling - Multi-stage pipeline integration @@ -955,7 +955,7 @@ async def test_large_input_streaming(self): class TestJoinNativeAsync: - """Tests for Join.async_execute (symmetric hash join + N>2 barrier).""" + """Tests for Join.async_execute (N-way streaming MJoin).""" @pytest.mark.asyncio async def test_single_input_passthrough(self): @@ -1103,6 +1103,405 @@ async def test_matches_sync_two_way(self): ) assert async_data == sync_data + @pytest.mark.asyncio + async def test_matches_sync_three_way(self): + """Three-way MJoin must produce the same data as sync static_process.""" + t1 = pa.table( + {"id": pa.array([1, 2, 3], type=pa.int64()), "a": pa.array([10, 20, 30], type=pa.int64())} + ) + t2 = pa.table( + {"id": pa.array([1, 2, 3], type=pa.int64()), "b": pa.array([100, 200, 300], type=pa.int64())} + ) + t3 = pa.table( + {"id": pa.array([1, 2, 3], type=pa.int64()), "c": pa.array([1000, 2000, 3000], type=pa.int64())} + ) + s1 = ArrowTableStream(t1, tag_columns=["id"]) + s2 = ArrowTableStream(t2, tag_columns=["id"]) + s3 = ArrowTableStream(t3, tag_columns=["id"]) + + op = Join() + sync_results = sync_process_to_rows(op, s1, s2, s3) + + ch1 = 
Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await feed(s3, ch3) + await op.async_execute([ch1.reader, ch2.reader, ch3.reader], out.writer) + async_results = await out.reader.collect() + + assert len(async_results) == len(sync_results) + async_data = sorted( + (t.as_dict()["id"], p.as_dict()) for t, p in async_results + ) + sync_data = sorted( + (t.as_dict()["id"], p.as_dict()) for t, p in sync_results + ) + assert async_data == sync_data + + @pytest.mark.asyncio + async def test_three_way_streams_before_all_closed(self): + """MJoin emits matched rows before all input channels are closed. + + This is the key behavioral difference from the old collect-based + fallback: downstream can start work as soon as all N sides have + contributed a matching row for a tag key. + """ + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + + from orcapod.core.datagrams import Packet, Tag + + # Start the join in the background + join_task = asyncio.create_task( + op.async_execute([ch1.reader, ch2.reader, ch3.reader], out.writer) + ) + + # Send one matching row from each side + await ch1.writer.send((Tag({"id": 1}), Packet({"a": 10}))) + await ch2.writer.send((Tag({"id": 1}), Packet({"b": 100}))) + await ch3.writer.send((Tag({"id": 1}), Packet({"c": 1000}))) + + # The match should be emitted while channels are still open + tag, pkt = await asyncio.wait_for(out.reader.receive(), timeout=2.0) + assert tag.as_dict()["id"] == 1 + assert pkt.as_dict() == {"a": 10, "b": 100, "c": 1000} + + # Close all inputs and let join finish + await ch1.writer.close() + await ch2.writer.close() + await ch3.writer.close() + await join_task + + @pytest.mark.asyncio + async def test_four_way_join(self): + """Four-input MJoin produces correct results.""" + tables = [ + pa.table({"id": pa.array([1, 2], type=pa.int64()), f"v{i}": pa.array([i * 10, i * 20], type=pa.int64())}) + for i in range(4) + ] + streams = [ArrowTableStream(t, tag_columns=["id"]) for t in tables] + + op = Join() + channels = [Channel(buffer_size=64) for _ in range(4)] + out = Channel(buffer_size=64) + for s, ch in zip(streams, channels): + await feed(s, ch) + await op.async_execute( + [ch.reader for ch in channels], out.writer + ) + results = await out.reader.collect() + + assert len(results) == 2 + result_map = {tag.as_dict()["id"]: pkt.as_dict() for tag, pkt in results} + assert result_map[1] == {"v0": 0, "v1": 10, "v2": 20, "v3": 30} + assert result_map[2] == {"v0": 0, "v1": 20, "v2": 40, "v3": 60} + + @pytest.mark.asyncio + async def test_three_way_partial_match_no_premature_emit(self): + """Rows should not be emitted until all N sides have a match.""" + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + + from orcapod.core.datagrams import Packet, Tag + + join_task = asyncio.create_task( + op.async_execute([ch1.reader, ch2.reader, ch3.reader], out.writer) + ) + + # Send matching rows from only 2 of 3 sides + await ch1.writer.send((Tag({"id": 1}), Packet({"a": 10}))) + await ch2.writer.send((Tag({"id": 1}), Packet({"b": 100}))) + + # Output should be empty — side 3 hasn't contributed yet + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(out.reader.receive(), timeout=0.05) + + # Now send the third side — match should complete + await 
ch3.writer.send((Tag({"id": 1}), Packet({"c": 1000}))) + tag, pkt = await asyncio.wait_for(out.reader.receive(), timeout=2.0) + assert pkt.as_dict() == {"a": 10, "b": 100, "c": 1000} + + await ch1.writer.close() + await ch2.writer.close() + await ch3.writer.close() + await join_task + + @pytest.mark.asyncio + async def test_three_way_empty_side_produces_nothing(self): + """If any input channel is empty, join produces no output.""" + t1 = pa.table({"id": pa.array([1], type=pa.int64()), "a": pa.array([10], type=pa.int64())}) + t2 = pa.table({"id": pa.array([1], type=pa.int64()), "b": pa.array([100], type=pa.int64())}) + s1 = ArrowTableStream(t1, tag_columns=["id"]) + s2 = ArrowTableStream(t2, tag_columns=["id"]) + + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await ch3.writer.close() # empty third input + await op.async_execute([ch1.reader, ch2.reader, ch3.reader], out.writer) + results = await out.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_three_way_no_shared_tags_cartesian(self): + """Three-way join with disjoint tag keys produces cartesian product.""" + t1 = pa.table({"a": pa.array([1, 2], type=pa.int64()), "x": pa.array([10, 20], type=pa.int64())}) + t2 = pa.table({"b": pa.array([3], type=pa.int64()), "y": pa.array([30], type=pa.int64())}) + t3 = pa.table({"c": pa.array([5, 6], type=pa.int64()), "z": pa.array([50, 60], type=pa.int64())}) + s1 = ArrowTableStream(t1, tag_columns=["a"]) + s2 = ArrowTableStream(t2, tag_columns=["b"]) + s3 = ArrowTableStream(t3, tag_columns=["c"]) + + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out_ch = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await feed(s3, ch3) + await op.async_execute([ch1.reader, ch2.reader, ch3.reader], out_ch.writer) + results = await out_ch.reader.collect() + + # 2 × 1 × 2 = 4 cartesian product + assert len(results) == 4 + + @pytest.mark.asyncio + async def test_duplicate_tag_keys_cross_product(self): + """Multiple rows per tag key per side should produce a cross-product. + + Side 0 has 2 rows with id=1, side 1 has 3 rows with id=1. + The join should emit 2 × 3 = 6 rows. + """ + t1 = pa.table( + { + "id": pa.array([1, 1], type=pa.int64()), + "a": pa.array([10, 11], type=pa.int64()), + } + ) + t2 = pa.table( + { + "id": pa.array([1, 1, 1], type=pa.int64()), + "b": pa.array([100, 101, 102], type=pa.int64()), + } + ) + s1 = ArrowTableStream(t1, tag_columns=["id"]) + s2 = ArrowTableStream(t2, tag_columns=["id"]) + + op = Join() + + # Sync reference + sync_results = sync_process_to_rows(op, s1, s2) + + # Async + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await op.async_execute([ch1.reader, ch2.reader], out.writer) + async_results = await out.reader.collect() + + assert len(async_results) == 6 + assert len(async_results) == len(sync_results) + + async_data = sorted( + (p.as_dict()["a"], p.as_dict()["b"]) for _, p in async_results + ) + sync_data = sorted( + (p.as_dict()["a"], p.as_dict()["b"]) for _, p in sync_results + ) + assert async_data == sync_data + + @pytest.mark.asyncio + async def test_three_way_duplicate_keys_cross_product(self): + """Three-way join with duplicate tag keys produces correct cross-product. 
+ + Side 0: 2 rows, side 1: 1 row, side 2: 2 rows → 2 × 1 × 2 = 4 rows. + """ + t1 = pa.table( + { + "id": pa.array([1, 1], type=pa.int64()), + "a": pa.array([10, 11], type=pa.int64()), + } + ) + t2 = pa.table( + { + "id": pa.array([1], type=pa.int64()), + "b": pa.array([100], type=pa.int64()), + } + ) + t3 = pa.table( + { + "id": pa.array([1, 1], type=pa.int64()), + "c": pa.array([1000, 1001], type=pa.int64()), + } + ) + s1 = ArrowTableStream(t1, tag_columns=["id"]) + s2 = ArrowTableStream(t2, tag_columns=["id"]) + s3 = ArrowTableStream(t3, tag_columns=["id"]) + + op = Join() + sync_results = sync_process_to_rows(op, s1, s2, s3) + + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await feed(s3, ch3) + await op.async_execute([ch1.reader, ch2.reader, ch3.reader], out.writer) + async_results = await out.reader.collect() + + assert len(async_results) == 4 + assert len(async_results) == len(sync_results) + + async_data = sorted( + (p.as_dict()["a"], p.as_dict()["b"], p.as_dict()["c"]) + for _, p in async_results + ) + sync_data = sorted( + (p.as_dict()["a"], p.as_dict()["b"], p.as_dict()["c"]) + for _, p in sync_results + ) + assert async_data == sync_data + + @pytest.mark.asyncio + async def test_partially_overlapping_tags_matches_sync(self): + """Staggered join with partially overlapping tags matches sync. + + S0: tag={a}, S1: tag={b}, S2: tag={a}. + static_process joins iteratively: (S0 ⋈ S1) cartesian, then + result ⋈ S2 on shared tag 'a'. The async staggered chain must + produce the same 2 rows (not 4 from a full cartesian). + """ + t1 = pa.table( + { + "a": pa.array([1, 2], type=pa.int64()), + "x": pa.array([10, 20], type=pa.int64()), + } + ) + t2 = pa.table( + { + "b": pa.array([5], type=pa.int64()), + "y": pa.array([50], type=pa.int64()), + } + ) + t3 = pa.table( + { + "a": pa.array([1, 2], type=pa.int64()), + "z": pa.array([100, 200], type=pa.int64()), + } + ) + s1 = ArrowTableStream(t1, tag_columns=["a"]) + s2 = ArrowTableStream(t2, tag_columns=["b"]) + s3 = ArrowTableStream(t3, tag_columns=["a"]) + + op = Join() + sync_results = sync_process_to_rows(op, s1, s2, s3) + + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await feed(s3, ch3) + hashes = [s1.pipeline_hash(), s2.pipeline_hash(), s3.pipeline_hash()] + await op.async_execute( + [ch1.reader, ch2.reader, ch3.reader], + out.writer, + input_pipeline_hashes=hashes, + ) + async_results = await out.reader.collect() + + # Sync produces 2 rows (constrained on 'a'); async must match + assert len(sync_results) == 2 + assert len(async_results) == len(sync_results) + + async_data = sorted( + (t.as_dict()["a"], p.as_dict()) for t, p in async_results + ) + sync_data = sorted( + (t.as_dict()["a"], p.as_dict()) for t, p in sync_results + ) + assert async_data == sync_data + + @pytest.mark.asyncio + async def test_input_pipeline_hashes_length_mismatch(self): + """async_execute raises ValueError if input_pipeline_hashes length != inputs.""" + t1 = pa.table({"id": pa.array([1], type=pa.int64()), "a": pa.array([10], type=pa.int64())}) + t2 = pa.table({"id": pa.array([1], type=pa.int64()), "b": pa.array([20], type=pa.int64())}) + s1 = ArrowTableStream(t1, tag_columns=["id"]) + s2 = ArrowTableStream(t2, tag_columns=["id"]) + + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + 
out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + with pytest.raises(ValueError, match="must match inputs length"): + await op.async_execute( + [ch1.reader, ch2.reader], + out.writer, + input_pipeline_hashes=[s1.pipeline_hash()], # only 1, need 2 + ) + + @pytest.mark.asyncio + async def test_buffered_rows_both_sides_emit_on_reindex(self): + """Rows buffered before shared_keys is known are matched after the one-time re-index. + + Sends multiple rows to one side before any row arrives from the other, + so the buffered side must be re-indexed once the opposite side's first + row determines shared_keys. + """ + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + + from orcapod.core.datagrams import Packet, Tag + + join_task = asyncio.create_task( + op.async_execute([ch1.reader, ch2.reader], out.writer) + ) + + # Buffer 2 rows on side 0 before side 1 sends anything + await ch1.writer.send((Tag({"id": 1}), Packet({"a": 10}))) + await ch1.writer.send((Tag({"id": 2}), Packet({"a": 20}))) + + # Give the event loop a chance to process + await asyncio.sleep(0.01) + + # Now side 1 sends — triggers shared_keys computation and re-index + await ch2.writer.send((Tag({"id": 1}), Packet({"b": 100}))) + await ch2.writer.send((Tag({"id": 2}), Packet({"b": 200}))) + + # Close and collect + await ch1.writer.close() + await ch2.writer.close() + await join_task + results = await out.reader.collect() + + assert len(results) == 2 + result_map = {t.as_dict()["id"]: p.as_dict() for t, p in results} + assert result_map[1] == {"a": 10, "b": 100} + assert result_map[2] == {"a": 20, "b": 200} + # =================================================================== # Multi-stage pipeline integration @@ -1400,7 +1799,7 @@ async def test_commutativity_system_tags_identical(self): @pytest.mark.asyncio async def test_three_way_system_tags_match_sync(self): - """N>2 barrier fallback should produce the same system tags as sync.""" + """N-way MJoin should produce the same system tags as sync.""" s1 = _make_source("id", "a", {"id": ["m", "n"], "a": [1, 2]}) s2 = _make_source("id", "b", {"id": ["m", "n"], "b": [10, 20]}) s3 = _make_source("id", "c", {"id": ["m", "n"], "c": [100, 200]}) @@ -1410,7 +1809,7 @@ async def test_three_way_system_tags_match_sync(self): sync_result = op.static_process(s1, s2, s3) sync_rows = list(sync_result.iter_packets()) - # Async (N>2 barrier path) + # Async (N-way MJoin path) op.validate_inputs(s1, s2, s3) ch1 = Channel(buffer_size=64) ch2 = Channel(buffer_size=64) @@ -1438,6 +1837,48 @@ async def test_three_way_system_tags_match_sync(self): assert sync_sys == async_sys +class TestSortMergedSystemTags: + """Unit tests for Join._sort_merged_system_tags.""" + + def test_sorts_same_provenance_path_by_value(self): + """System tag values at different positions within the same provenance + path should be sorted by (source_id, record_id) tuple.""" + merged_sys = { + # Provenance path "abc123", position 0 — higher values + "_tag_source_id::abc123:0": "z_source", + "_tag_record_id::abc123:0": "z_record", + # Provenance path "abc123", position 1 — lower values + "_tag_source_id::abc123:1": "a_source", + "_tag_record_id::abc123:1": "a_record", + } + result = Join._sort_merged_system_tags(merged_sys) + + # After sorting, position 0 should have the smaller values + assert result["_tag_source_id::abc123:0"] == "a_source" + assert result["_tag_record_id::abc123:0"] == "a_record" + assert result["_tag_source_id::abc123:1"] == "z_source" + assert result["_tag_record_id::abc123:1"] == 
"z_record" + + def test_single_position_unchanged(self): + """Groups with only one position should not be modified.""" + merged_sys = { + "_tag_source_id::abc123:0": "only_source", + "_tag_record_id::abc123:0": "only_record", + } + result = Join._sort_merged_system_tags(merged_sys) + assert result == merged_sys + + def test_non_system_tag_keys_ignored(self): + """Non-system-tag keys should pass through unchanged.""" + merged_sys = { + "regular_key": "value", + "_tag_source_id::abc123:0": "src0", + "_tag_record_id::abc123:0": "rec0", + } + result = Join._sort_merged_system_tags(merged_sys) + assert result["regular_key"] == "value" + + class TestSemiJoinSystemTagEquivalence: """Verify SemiJoin system-tag handling matches between sync and async."""