diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 945d6279..9dcab19b 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -315,6 +315,25 @@ which column groups (meta, source, system_tags) are returned. --- +## `src/orcapod/core/cached_function_pod.py` / `src/orcapod/core/packet_function.py` + +### CFP1 — Extract shared result caching logic from CachedPacketFunction and CachedFunctionPod +**Status:** resolved +**Severity:** medium + +`CachedPacketFunction` and `CachedFunctionPod` implement nearly identical result caching +logic: DB lookup by `INPUT_PACKET_HASH_COL`, conflict resolution by most-recent timestamp, +record storage with variation/execution/timestamp columns, and a `RESULT_COMPUTED_FLAG` +meta column. The match tier / matching policy design (P6) will also need to apply to both. + +**Fix:** Extracted `ResultCache` class (`src/orcapod/core/result_cache.py`) that owns the DB, +record path, lookup (with `additional_constraints` for future match tiers), store, conflict +resolution, and auto-flush logic. Both `CachedPacketFunction` and `CachedFunctionPod` now +delegate to a `ResultCache` instance. The match tier strategy (P6) can be implemented once +on `ResultCache.lookup` via `additional_constraints`. + +--- + ## `src/orcapod/core/nodes/function_node.py` ### FN1 — `FunctionNode.async_execute` Phase 2 was fully sequential diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py new file mode 100644 index 00000000..76d2a27d --- /dev/null +++ b/src/orcapod/core/cached_function_pod.py @@ -0,0 +1,154 @@ +"""CachedFunctionPod — pod-level caching wrapper that intercepts process_packet().""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.core.function_pod import WrappedFunctionPod +from orcapod.core.result_cache import ResultCache +from orcapod.protocols.core_protocols import ( + FunctionPodProtocol, + PacketProtocol, + StreamProtocol, + TagProtocol, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +if TYPE_CHECKING: + import pyarrow as pa + +logger = logging.getLogger(__name__) + + +class CachedFunctionPod(WrappedFunctionPod): + """Pod-level caching wrapper that intercepts ``process_packet()``. + + Caches at the ``process_packet(tag, packet)`` level using only the + **input packet content hash** as the cache key — the output of a + packet function depends solely on the packet, not the tag. + + Tag-level provenance tracking (tag + system tags + packet hash) is + handled separately by ``FunctionNode.add_pipeline_record``. + + Uses a shared ``ResultCache`` for lookup/store/conflict-resolution + logic (same mechanism as ``CachedPacketFunction``). + """ + + # Expose RESULT_COMPUTED_FLAG from the shared ResultCache + RESULT_COMPUTED_FLAG = ResultCache.RESULT_COMPUTED_FLAG + + def __init__( + self, + function_pod: FunctionPodProtocol, + result_database: ArrowDatabaseProtocol, + record_path_prefix: tuple[str, ...] 
= (), + auto_flush: bool = True, + **kwargs, + ) -> None: + super().__init__(function_pod, **kwargs) + self._record_path_prefix = record_path_prefix + self._cache = ResultCache( + result_database=result_database, + record_path=record_path_prefix + self.uri, + auto_flush=auto_flush, + ) + + @property + def _result_database(self) -> ArrowDatabaseProtocol: + """The underlying result database (for FunctionNode access).""" + return self._cache.result_database + + @property + def record_path(self) -> tuple[str, ...]: + """Return the path to the cached records in the result store.""" + return self._cache.record_path + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a packet with pod-level caching. + + The cache key is the input packet content hash only — the function + output depends solely on the packet, not the tag. The output + packet carries a ``RESULT_COMPUTED_FLAG`` meta value: ``True`` if + freshly computed, ``False`` if retrieved from cache. + + Args: + tag: The tag associated with the packet. + packet: The input packet to process. + + Returns: + A ``(tag, output_packet)`` tuple; output_packet is ``None`` + if the inner function filters the packet out. + """ + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached + + tag, output = self._function_pod.process_packet(tag, packet) + if output is not None: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output + + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``. + + DB lookup and store are synchronous (DB protocol is sync), but the + actual computation uses the inner pod's ``async_process_packet`` + for true async execution. + """ + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached + + tag, output = await self._function_pod.async_process_packet(tag, packet) + if output is not None: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output + + def get_all_cached_outputs( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store for this pod.""" + return self._cache.get_all_records( + include_system_columns=include_system_columns + ) + + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + """Invoke the inner pod but with pod-level caching on process_packet. + + The stream returned uses *this* pod's ``process_packet`` (which + includes caching) rather than the inner pod's. 
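+
+        Example — a minimal sketch (``pod`` and ``stream`` stand in for an
+        existing ``FunctionPod`` and input stream; the in-memory database
+        is used purely for illustration)::
+
+            db = InMemoryArrowDatabase()
+            cached = CachedFunctionPod(function_pod=pod, result_database=db)
+            first = list(cached.process(stream).iter_packets())   # computes + stores
+            again = list(cached.process(stream).iter_packets())   # served from cache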
+ """ + from orcapod.core.function_pod import FunctionPodStream + + # Validate and prepare the input stream + input_stream = self._function_pod.handle_input_streams(*streams) + self._function_pod.validate_inputs(*streams) + + return FunctionPodStream( + function_pod=self, + input_stream=input_stream, + label=label, + ) diff --git a/src/orcapod/core/executors/base.py b/src/orcapod/core/executors/base.py index 8bb2fd23..386cbe1a 100644 --- a/src/orcapod/core/executors/base.py +++ b/src/orcapod/core/executors/base.py @@ -1,6 +1,8 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod +from collections.abc import Callable from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -77,13 +79,52 @@ def supports_concurrent_execution(self) -> bool: return False def with_options(self, **opts: Any) -> "PacketFunctionExecutorBase": - """Return an executor configured with the given per-node options. + """Return a **new** executor instance configured with the given per-node options. - The default implementation ignores *opts* and returns *self*. - Subclasses that support resource options (e.g. ``RayExecutor``) - should override to return a new instance with the merged options. + The default implementation returns a shallow copy of *self*. + Subclasses that carry mutable state (e.g. ``RayExecutor``) should + override to produce a properly configured new instance. """ - return self + return copy.copy(self) + + # ------------------------------------------------------------------ + # Callable-level execution (PythonFunctionExecutorProtocol) + # ------------------------------------------------------------------ + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Synchronously execute *fn* with *kwargs*. + + Default implementation calls ``fn(**kwargs)`` in-process. + Subclasses should override for remote/distributed execution. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options. + + Returns: + The raw return value of *fn*. + """ + return fn(**kwargs) + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Asynchronously execute *fn* with *kwargs*. + + Default implementation delegates to ``execute_callable`` + synchronously. Subclasses should override for truly async + execution. + """ + return self.execute_callable(fn, kwargs, executor_options) def get_execution_data(self) -> dict[str, Any]: """Return metadata describing the execution environment. 
diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py index 92289955..fc97e477 100644 --- a/src/orcapod/core/executors/local.py +++ b/src/orcapod/core/executors/local.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import asyncio +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from orcapod.core.executors.base import PacketFunctionExecutorBase @@ -35,3 +38,46 @@ async def async_execute( packet: PacketProtocol, ) -> PacketProtocol | None: return await packet_function.direct_async_call(packet) + + # -- PythonFunctionExecutorProtocol -- + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + if inspect.iscoroutinefunction(fn): + return self._run_async_sync(fn, kwargs) + return fn(**kwargs) + + @staticmethod + def _run_async_sync(fn: Callable[..., Any], kwargs: dict[str, Any]) -> Any: + """Run an async function synchronously, handling nested event loops.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(fn(**kwargs)) + else: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(1) as pool: + return pool.submit(lambda: asyncio.run(fn(**kwargs))).result() + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + if inspect.iscoroutinefunction(fn): + return await fn(**kwargs) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: fn(**kwargs)) + + def with_options(self, **opts: Any) -> "LocalExecutor": + """Return a new ``LocalExecutor``. + + ``LocalExecutor`` carries no state, so options are ignored. + """ + return LocalExecutor() diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py index 9f54812e..998c2193 100644 --- a/src/orcapod/core/executors/ray.py +++ b/src/orcapod/core/executors/ray.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable from typing import TYPE_CHECKING, Any from orcapod.core.executors.base import PacketFunctionExecutorBase @@ -157,6 +158,34 @@ async def async_execute( raw_result = await asyncio.wrap_future(ref.future()) return pf._build_output_packet(raw_result) + # -- PythonFunctionExecutorProtocol -- + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + import ray + + self._ensure_ray_initialized() + remote_fn = ray.remote(**self._build_remote_opts())(fn) + ref = remote_fn.remote(**kwargs) + return ray.get(ref) + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + import ray + + self._ensure_ray_initialized() + remote_fn = ray.remote(**self._build_remote_opts())(fn) + ref = remote_fn.remote(**kwargs) + return await asyncio.wrap_future(ref.future()) + def with_options(self, **opts: Any) -> "RayExecutor": """Return a new ``RayExecutor`` with the given options merged in. 
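
        Example (sketch; option names follow ``ray.remote``, e.g.
        ``num_cpus``/``num_gpus``)::

            gpu_executor = RayExecutor().with_options(num_gpus=1)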
diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 2272c616..5a1db639 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -652,6 +652,7 @@ def function_pod( version: str = "v0.0", label: str | None = None, result_database: ArrowDatabaseProtocol | None = None, + pod_cache_database: ArrowDatabaseProtocol | None = None, executor: PacketFunctionExecutorProtocol | None = None, **kwargs, ) -> Callable[..., CallableWithPod]: @@ -662,7 +663,11 @@ def function_pod( function_name: Name of the function pod; defaults to ``func.__name__``. version: Version string for the packet function. label: Optional label for tracking. - result_database: Optional database for caching results. + result_database: Optional database for packet-level caching + (wraps the packet function in ``CachedPacketFunction``). + pod_cache_database: Optional database for pod-level caching + (wraps the pod in ``CachedFunctionPod``, which caches at the + ``process_packet`` level using input packet content hash). executor: Optional executor for running the packet function. **kwargs: Forwarded to ``PythonPacketFunction``. @@ -692,10 +697,19 @@ def decorator(func: Callable) -> CallableWithPod: ) # Create a simple typed function pod - pod = FunctionPod( + pod: _FunctionPodBase = FunctionPod( packet_function=packet_function, ) + # if pod_cache_database is provided, wrap in CachedFunctionPod + if pod_cache_database is not None: + from orcapod.core.cached_function_pod import CachedFunctionPod + + pod = CachedFunctionPod( + function_pod=pod, + result_database=pod_cache_database, + ) + @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index 140c7c7b..f451a59f 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -10,7 +10,7 @@ from orcapod import contexts from orcapod.channels import ReadableChannel, WritableChannel from orcapod.config import Config -from orcapod.core.packet_function import CachedPacketFunction +from orcapod.core.cached_function_pod import CachedFunctionPod from orcapod.core.streams.arrow_table_stream import ArrowTableStream from orcapod.core.streams.base import StreamBase from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER @@ -60,7 +60,7 @@ class FunctionNode(StreamBase): When constructed without database parameters, provides the core stream interface (identity, schema, iteration) without any persistence. When databases are provided (either at construction or via ``attach_databases``), - adds result caching via ``CachedPacketFunction``, pipeline record storage, + adds result caching via ``CachedFunctionPod``, pipeline record storage, and two-phase iteration (cached first, then compute missing). """ @@ -101,11 +101,6 @@ def __init__( self._input_stream = input_stream - # Per-node Ray (or other engine) resource overrides. When a pipeline - # is run with an execution_engine, these opts are merged on top of the - # pipeline-level execution_engine_opts for this node only. 
-        self.execution_engine_opts: dict[str, Any] | None = None
-
         # stream-level caching state (iterator acquired lazily on first use)
         self._cached_input_iterator: (
             Iterator[tuple[TagProtocol, PacketProtocol]] | None
@@ -119,6 +114,7 @@ def __init__(
         # DB persistence state (initially None; set via __init__ params or attach_databases)
         self._pipeline_database: ArrowDatabaseProtocol | None = None
+        self._cached_function_pod: CachedFunctionPod | None = None
         self._pipeline_path_prefix: tuple[str, ...] = ()
         self._pipeline_node_hash: str | None = None
         self._output_schema_hash: str | None = None
@@ -144,6 +140,10 @@ def attach_databases(
     ) -> None:
         """Attach databases for persistent caching and pipeline records.
+
+        Creates a ``CachedFunctionPod`` wrapping the original function pod
+        for result caching. The pipeline database is used separately for
+        pipeline-level provenance records (tag + packet hash).
+
         Args:
             pipeline_database: Database for pipeline records.
             result_database: Database for cached results. Defaults to
@@ -162,13 +162,9 @@ def attach_databases(
         elif result_path_prefix is not None:
             computed_result_path_prefix = result_path_prefix

-        # Guard against double-wrapping
-        pf = self._packet_function
-        if isinstance(pf, CachedPacketFunction):
-            pf = pf._packet_function
-
-        self._packet_function = CachedPacketFunction(
-            pf,
+        # Always wrap the original function_pod (not a previous cached wrapper)
+        self._cached_function_pod = CachedFunctionPod(
+            self._function_pod,
             result_database=result_database,
             record_path_prefix=computed_result_path_prefix,
         )
@@ -282,7 +278,6 @@ def from_descriptor(
         node._packet_function = None
         node._input_stream = None
         node.tracker_manager = DEFAULT_TRACKER_MANAGER
-        node.execution_engine_opts = descriptor.get("execution_engine_opts")
         node._cached_input_iterator = None
         node._needs_iterator = True
         node._cached_output_packets = {}
@@ -291,6 +286,7 @@ def from_descriptor(

         # DB persistence state
         node._pipeline_database = pipeline_db
+        node._cached_function_pod = None
         node._pipeline_path_prefix = ()
         node._pipeline_node_hash = None
         node._output_schema_hash = None
@@ -473,33 +469,29 @@ def process_packet(
         self,
         tag: TagProtocol,
         packet: PacketProtocol,
-        skip_cache_lookup: bool = False,
-        skip_cache_insert: bool = False,
     ) -> tuple[TagProtocol, PacketProtocol | None]:
         """Process a single packet, optionally recording to the pipeline database.
+
+        When a database is attached, uses ``CachedFunctionPod`` for result
+        caching (keyed by the input packet content hash only) and separately
+        records a pipeline provenance entry (tag + system tags + packet hash).
+
         Args:
             tag: The tag associated with the packet.
             packet: The input packet to process.
-            skip_cache_lookup: If True, bypass DB lookup for existing result.
-            skip_cache_insert: If True, skip writing result to DB.

        Returns:
            A ``(tag, output_packet)`` tuple; output_packet is ``None``
            if the function filters the packet out.
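+
+        Example (sketch; ``node`` is a ``FunctionNode`` with databases
+        attached)::
+
+            tag, out = node.process_packet(tag, packet)
+            if out is not None:
+                # True if freshly computed, False if served from cache
+                computed = out.get_meta_value(
+                    node._cached_function_pod.RESULT_COMPUTED_FLAG, False
+                )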
""" - if self._pipeline_database is not None: - # Persistent mode: use CachedPacketFunction + pipeline record - output_packet = self._packet_function.call( - packet, - skip_cache_lookup=skip_cache_lookup, - skip_cache_insert=skip_cache_insert, - ) + if self._cached_function_pod is not None: + # Persistent mode: CachedFunctionPod for result caching + pipeline record + tag, output_packet = self._cached_function_pod.process_packet(tag, packet) if output_packet is not None: result_computed = bool( output_packet.get_meta_value( - self._packet_function.RESULT_COMPUTED_FLAG, False + self._cached_function_pod.RESULT_COMPUTED_FLAG, False ) ) self.add_pipeline_record( @@ -517,26 +509,22 @@ async def async_process_packet( self, tag: TagProtocol, packet: PacketProtocol, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``. - Uses the CachedPacketFunction's async_call for computation + result - caching when a database is attached. Pipeline record storage is - synchronous (DB protocol is sync). + Uses ``CachedFunctionPod.async_process_packet`` (sync DB caching + + async computation) when a database is attached. Pipeline record + storage is synchronous (DB protocol is sync). """ - if self._pipeline_database is not None: - output_packet = await self._packet_function.async_call( - packet, - skip_cache_lookup=skip_cache_lookup, - skip_cache_insert=skip_cache_insert, + if self._cached_function_pod is not None: + tag, output_packet = await self._cached_function_pod.async_process_packet( + tag, packet ) if output_packet is not None: result_computed = bool( output_packet.get_meta_value( - self._packet_function.RESULT_COMPUTED_FLAG, False + self._cached_function_pod.RESULT_COMPUTED_FLAG, False ) ) self.add_pipeline_record( @@ -550,6 +538,27 @@ async def async_process_packet( else: return await self._function_pod.async_process_packet(tag, packet) + def compute_pipeline_entry_id( + self, tag: TagProtocol, input_packet: PacketProtocol + ) -> str: + """Compute a unique pipeline entry ID from tag + system tags + input packet hash. + + This ID uniquely identifies a (tag, system_tags, input_packet) combination + and is used as the record ID in the pipeline database. + + Args: + tag: The tag (including system tags). + input_packet: The input packet. + + Returns: + A hash string uniquely identifying this combination. + """ + tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), + ) + return self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() + def add_pipeline_record( self, tag: TagProtocol, @@ -558,19 +567,18 @@ def add_pipeline_record( computed: bool, skip_cache_lookup: bool = False, ) -> None: - """Add a pipeline record to the database for a processed packet.""" - # combine TagProtocol with packet content hash to compute entry hash - # TODO: add system tag columns - # TODO: consider using bytes instead of string representation - tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( - constants.INPUT_PACKET_HASH_COL, - pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), - ) - - # unique entry ID is determined by the combination of tags, system_tags, and input_packet hash - entry_id = self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() + """Add a pipeline record to the database for a processed packet. 
+ + The pipeline record stores: + - Tag columns (including system tags) + - All source columns of the input packet (provenance, not data) + - Output packet record ID (for joining with result records) + - Input packet data context key + - Whether the result was freshly computed or cached + """ + entry_id = self.compute_pipeline_entry_id(tag, input_packet) - # check presence of an existing entry with the same entry_id + # Check for existing entry existing_record = None if not skip_cache_lookup: existing_record = self._pipeline_database.get_record_by_id( @@ -579,35 +587,40 @@ def add_pipeline_record( ) if existing_record is not None: - # if the record already exists, then skip adding logger.debug( f"Record with entry_id {entry_id} already exists. Skipping addition." ) return - # rename all keys to avoid potential collision with result columns - renamed_input_packet = input_packet.rename( - {k: f"_input_{k}" for k in input_packet.keys()} - ) - input_packet_info = ( - renamed_input_packet.as_table(columns={"source": True}) - .append_column( - constants.PACKET_RECORD_ID, # record ID for the packet function output packet - pa.array([packet_record_id], type=pa.large_string()), - ) - .append_column( - f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}", # data context key for the input packet - pa.array([input_packet.data_context_key], type=pa.large_string()), - ) - .append_column( - f"{constants.META_PREFIX}computed", - pa.array([computed], type=pa.bool_()), - ) - .drop_columns(list(renamed_input_packet.keys())) + # Extract source columns only (no data columns) from the input packet + input_table_with_source = input_packet.as_table(columns={"source": True}) + source_col_names = [ + c + for c in input_table_with_source.column_names + if c.startswith(constants.SOURCE_PREFIX) + ] + input_source_table = input_table_with_source.select(source_col_names) + + # Build the meta columns table + meta_table = pa.table( + { + constants.PACKET_RECORD_ID: pa.array( + [packet_record_id], type=pa.large_string() + ), + f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}": pa.array( + [input_packet.data_context_key], type=pa.large_string() + ), + f"{constants.META_PREFIX}computed": pa.array( + [computed], type=pa.bool_() + ), + } ) + # Combine: tag (with system tags) + input source columns + meta columns combined_record = arrow_utils.hstack_tables( - tag.as_table(columns={"system_tags": True}), input_packet_info + tag.as_table(columns={"system_tags": True}), + input_source_table, + meta_table, ) self._pipeline_database.add_record( @@ -636,11 +649,11 @@ def get_all_records( A PyArrow table of joined results, or ``None`` if no database is attached or no records exist. 
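
        Example (sketch; column layout depends on the attached databases)::

            table = node.get_all_records()
            if table is not None:
                print(table.column_names)  # tag columns + output packet columns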
""" - if self._pipeline_database is None: + if self._cached_function_pod is None: return None - results = self._packet_function._result_database.get_all_records( - self._packet_function.record_path, + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, record_id_column=constants.PACKET_RECORD_ID, ) taginfo = self._pipeline_database.get_all_records(self.pipeline_path) @@ -706,31 +719,70 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: self.clear_cache() self._ensure_iterator() - if self._pipeline_database is not None: + if self._cached_function_pod is not None: # Two-phase iteration with DB backing if self._cached_input_iterator is not None: input_iter = self._cached_input_iterator # --- Phase 1: yield already-computed results from the databases --- - existing = self.get_all_records(columns={"meta": True}) - computed_hashes: set[str] = set() - if existing is not None and existing.num_rows > 0: - tag_keys = self._input_stream.keys()[0] - # Strip the meta column before handing to ArrowTableStream so it only - # sees tag + output-packet columns. - hash_col = constants.INPUT_PACKET_HASH_COL - hash_values = cast(list[str], existing.column(hash_col).to_pylist()) - computed_hashes = set(hash_values) - data_table = existing.drop([hash_col]) - existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) - for i, (tag, packet) in enumerate(existing_stream.iter_packets()): - self._cached_output_packets[i] = (tag, packet) - yield tag, packet + # Retrieve pipeline records with their entry_ids (record IDs) + # and join with result records to reconstruct (tag, output_packet). + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + existing_entry_ids: set[str] = set() + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() + ) + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + # Collect pipeline entry_ids for Phase 2 skip check + existing_entry_ids = set( + cast( + list[str], + joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), + ) + ) + # Drop internal columns before yielding as stream + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for i, (tag, packet) in enumerate( + existing_stream.iter_packets() + ): + self._cached_output_packets[i] = (tag, packet) + yield tag, packet # --- Phase 2: process only missing input packets --- + # Skip inputs whose pipeline entry_id (tag+system_tags+packet_hash) + # already exists in the pipeline database. 
next_idx = len(self._cached_output_packets) for tag, packet in input_iter: - input_hash = packet.content_hash().to_string() - if input_hash in computed_hashes: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in existing_entry_ids: continue tag, output_packet = self.process_packet(tag, packet) self._cached_output_packets[next_idx] = (tag, output_packet) @@ -951,21 +1003,53 @@ async def async_execute( node_config = getattr(self._function_pod, "node_config", NodeConfig()) max_concurrency = resolve_concurrency(node_config, pipeline_config) - if self._pipeline_database is not None: + if self._cached_function_pod is not None: # Two-phase async execution with DB backing # Phase 1: emit existing results from DB - existing = self.get_all_records(columns={"meta": True}) - computed_hashes: set[str] = set() - if existing is not None and existing.num_rows > 0: - tag_keys = self._input_stream.keys()[0] - hash_col = constants.INPUT_PACKET_HASH_COL - computed_hashes = set( - cast(list[str], existing.column(hash_col).to_pylist()) + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + existing_entry_ids: set[str] = set() + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() ) - data_table = existing.drop([hash_col]) - existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - await output.send((tag, packet)) + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + existing_entry_ids = set( + cast( + list[str], + joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), + ) + ) + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for tag, packet in existing_stream.iter_packets(): + await output.send((tag, packet)) # Phase 2: process new packets concurrently sem = ( @@ -989,8 +1073,8 @@ async def process_one_db( async with asyncio.TaskGroup() as tg: async for tag, packet in inputs[0]: - input_hash = packet.content_hash().to_string() - if input_hash in computed_hashes: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in existing_entry_ids: continue if sem is not None: await sem.acquire() diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index c8b73c7f..bd26ea11 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -7,7 +7,8 @@ from abc import abstractmethod from collections.abc import Callable, Iterable, Sequence from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal +import typing +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar from uuid_utils import uuid7 @@ -19,8 +20,12 @@ get_function_components, get_function_signature, ) +from orcapod.core.result_cache import ResultCache from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol -from orcapod.protocols.core_protocols.executor import 
PacketFunctionExecutorProtocol +from orcapod.protocols.core_protocols.executor import ( + PacketFunctionExecutorProtocol, + PythonFunctionExecutorProtocol, +) from orcapod.protocols.database_protocols import ArrowDatabaseProtocol from orcapod.system_constants import constants from orcapod.types import DataValue, Schema, SchemaLike @@ -101,8 +106,29 @@ def parse_function_outputs( return dict(zip(output_keys, output_values)) -class PacketFunctionBase(TraceableBase): - """Abstract base class for PacketFunctionProtocol.""" +E = TypeVar("E", bound=PacketFunctionExecutorProtocol) + + +class PacketFunctionBase(TraceableBase, Generic[E]): + """Abstract base class for PacketFunctionProtocol. + + Type-parameterized with the executor protocol ``E``. Concrete + subclasses that bind ``E`` (e.g. ``class Foo(PacketFunctionBase[SomeProto])``) + get automatic ``isinstance`` validation in ``set_executor`` at class + definition time via ``__init_subclass__``. + """ + + _resolved_executor_protocol: ClassVar[type | None] = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + for base in getattr(cls, "__orig_bases__", ()): + origin = typing.get_origin(base) + if origin is PacketFunctionBase: + args = typing.get_args(base) + if args and not isinstance(args[0], TypeVar): + cls._resolved_executor_protocol = args[0] + return def __init__( self, @@ -115,7 +141,7 @@ def __init__( super().__init__(label=label, data_context=data_context, config=config) self._active = True self._version = version - self._executor: PacketFunctionExecutorProtocol | None = None + self._executor: E | None = None # Parse version string to extract major and minor versions # 0.5.2 -> 0 and 5.2, 1.3rc -> 1 and 3rc @@ -218,24 +244,43 @@ def get_execution_data(self) -> dict[str, Any]: # ==================== Executor ==================== @property - def executor(self) -> PacketFunctionExecutorProtocol | None: + def executor(self) -> E | None: """Return the executor used to run this packet function, or ``None`` for direct execution.""" return self._executor @executor.setter - def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + def executor(self, executor: E | None) -> None: """Set or clear the executor for this packet function. + Delegates to ``set_executor`` for validation. + """ + self.set_executor(executor) + + def set_executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor, validating type compatibility. + + Performs two checks: + 1. The executor supports this function's ``packet_function_type_id``. + 2. If the subclass bound ``E`` via ``Generic[E]``, the executor is an + instance of the resolved protocol (checked once at assignment time, + not in the hot path). + Raises: - TypeError: If *executor* does not support this function's - ``packet_function_type_id``. + TypeError: If *executor* fails either compatibility check. """ - if executor is not None and not executor.supports(self.packet_function_type_id): - raise TypeError( - f"Executor {executor.executor_type_id!r} does not support " - f"packet function type {self.packet_function_type_id!r}. " - f"Supported types: {executor.supported_function_type_ids()}" - ) + if executor is not None: + if not executor.supports(self.packet_function_type_id): + raise TypeError( + f"Executor {executor.executor_type_id!r} does not support " + f"packet function type {self.packet_function_type_id!r}. 
" + f"Supported types: {executor.supported_function_type_ids()}" + ) + proto = getattr(type(self), "_resolved_executor_protocol", None) + if proto is not None and not isinstance(executor, proto): + raise TypeError( + f"{type(self).__name__} requires an executor implementing " + f"{proto.__name__}, got {type(executor).__name__}" + ) self._executor = executor # ==================== Execution ==================== @@ -273,7 +318,7 @@ async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | No ... -class PythonPacketFunction(PacketFunctionBase): +class PythonPacketFunction(PacketFunctionBase[PythonFunctionExecutorProtocol]): @property def packet_function_type_id(self) -> str: """Unique function type identifier.""" @@ -463,6 +508,31 @@ def _call_async_function_sync(self, packet: PacketProtocol) -> Any: _get_sync_executor().submit(lambda: asyncio.run(fn(**kwargs))).result() ) + def call(self, packet: PacketProtocol) -> PacketProtocol | None: + """Process a single packet, routing through the executor if one is set. + + When an executor implementing ``PythonFunctionExecutorProtocol`` is + set, the raw callable and kwargs are handed to + ``execute_callable`` and the result is wrapped into an output packet. + """ + if self._executor is not None: + if not self._active: + return None + raw = self._executor.execute_callable(self._function, packet.as_dict()) + return self._build_output_packet(raw) + return self.direct_call(packet) + + async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + """Async counterpart of ``call``.""" + if self._executor is not None: + if not self._active: + return None + raw = await self._executor.async_execute_callable( + self._function, packet.as_dict() + ) + return self._build_output_packet(raw) + return await self.direct_async_call(packet) + def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: """Execute the function on *packet* synchronously. @@ -544,8 +614,13 @@ def from_config(cls, config: dict[str, Any]) -> "PythonPacketFunction": ) -class PacketFunctionWrapper(PacketFunctionBase): - """Wrapper around a PacketFunctionProtocol to modify or extend its behavior.""" +class PacketFunctionWrapper(PacketFunctionBase[E]): + """Wrapper around a PacketFunctionProtocol to modify or extend its behavior. + + Remains generic over ``E`` — the executor protocol is not bound here + so that wrappers inherit the executor type constraint of the wrapped + function. + """ def __init__(self, packet_function: PacketFunctionProtocol, **kwargs) -> None: super().__init__(**kwargs) @@ -650,10 +725,14 @@ async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | No class CachedPacketFunction(PacketFunctionWrapper): - """Wrapper around a PacketFunctionProtocol that caches results for identical input packets.""" + """Wrapper around a PacketFunctionProtocol that caches results for identical input packets. - # cloumn name containing indication of whether the result was computed - RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + Uses a shared ``ResultCache`` for lookup/store/conflict-resolution + logic (same mechanism as ``CachedFunctionPod``). 
+ """ + + # Expose RESULT_COMPUTED_FLAG from the shared ResultCache + RESULT_COMPUTED_FLAG = ResultCache.RESULT_COMPUTED_FLAG def __init__( self, @@ -665,16 +744,20 @@ def __init__( super().__init__(packet_function, **kwargs) self._result_database = result_database self._record_path_prefix = record_path_prefix - self._auto_flush = True + self._cache = ResultCache( + result_database=result_database, + record_path=record_path_prefix + self.uri, + auto_flush=True, + ) def set_auto_flush(self, on: bool = True) -> None: """Set auto-flush behavior. If True, the database flushes after each record.""" - self._auto_flush = on + self._cache.set_auto_flush(on) @property def record_path(self) -> tuple[str, ...]: """Return the path to the record in the result store.""" - return self._record_path_prefix + self.uri + return self._cache.record_path def call( self, @@ -683,20 +766,22 @@ def call( skip_cache_lookup: bool = False, skip_cache_insert: bool = False, ) -> PacketProtocol | None: - # execution_engine_hash = execution_engine.name if execution_engine else "default" output_packet = None if not skip_cache_lookup: logger.info("Checking for cache...") - # lookup stored result for the input packet - output_packet = self.get_cached_output_for_packet(packet) + output_packet = self._cache.lookup(packet) if output_packet is not None: logger.info(f"Cache hit for {packet}!") if output_packet is None: output_packet = self._packet_function.call(packet) if output_packet is not None: if not skip_cache_insert: - self.record_packet(packet, output_packet) - # add meta column to indicate that this was computed + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + ) output_packet = output_packet.with_meta_columns( **{self.RESULT_COMPUTED_FLAG: True} ) @@ -714,14 +799,19 @@ async def async_call( output_packet = None if not skip_cache_lookup: logger.info("Checking for cache...") - output_packet = self.get_cached_output_for_packet(packet) + output_packet = self._cache.lookup(packet) if output_packet is not None: logger.info(f"Cache hit for {packet}!") if output_packet is None: output_packet = await self._packet_function.async_call(packet) if output_packet is not None: if not skip_cache_insert: - self.record_packet(packet, output_packet) + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + ) output_packet = output_packet.with_meta_columns( **{self.RESULT_COMPUTED_FLAG: True} ) @@ -737,45 +827,7 @@ def get_cached_output_for_packet( Returns: The cached output packet, or ``None`` if no entry was found. 
""" - - # get all records with matching the input packet hash - # TODO: add match based on match_tier if specified - - # TODO: implement matching policy/strategy - constraints = { - constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string() - } - - RECORD_ID_COLUMN = "_record_id" - result_table = self._result_database.get_records_with_column_value( - self.record_path, - constraints, - record_id_column=RECORD_ID_COLUMN, - ) - - if result_table is None or result_table.num_rows == 0: - return None - - if result_table.num_rows > 1: - logger.info( - f"Performing conflict resolution for multiple records for {input_packet.content_hash().display_name()}" - ) - result_table = result_table.sort_by( - [(constants.POD_TIMESTAMP, "descending")] - ).take([0]) - - # extract the record_id column - record_id = result_table.to_pylist()[0][RECORD_ID_COLUMN] - result_table = result_table.drop_columns( - [RECORD_ID_COLUMN, constants.INPUT_PACKET_HASH_COL] - ) - - # note that data context will be loaded from the result store - return Packet( - result_table, - record_id=record_id, - meta_info={self.RESULT_COMPUTED_FLAG: False}, - ) + return self._cache.lookup(input_packet) def record_packet( self, @@ -784,54 +836,13 @@ def record_packet( skip_duplicates: bool = False, ) -> PacketProtocol: """Record the output packet against the input packet in the result store.""" - - # TODO: consider incorporating execution_engine_opts into the record - data_table = output_packet.as_table(columns={"source": True, "context": True}) - - i = 0 - for k, v in self.get_function_variation_data().items(): - # add the tiered pod ID to the data table - data_table = data_table.add_column( - i, - f"{constants.PF_VARIATION_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - i += 1 - - for k, v in self.get_execution_data().items(): - # add the tiered pod ID to the data table - data_table = data_table.add_column( - i, - f"{constants.PF_EXECUTION_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - i += 1 - - # add the input packet hash as a column - data_table = data_table.add_column( - 0, - constants.INPUT_PACKET_HASH_COL, - pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), - ) - - # add computation timestamp - timestamp = datetime.now(timezone.utc) - data_table = data_table.append_column( - constants.POD_TIMESTAMP, - pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), - ) - - self._result_database.add_record( - self.record_path, - output_packet.datagram_id, # output packet datagram ID (uuid) is used as a unique identification - data_table, + self._cache.store( + input_packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), skip_duplicates=skip_duplicates, ) - - if self._auto_flush: - self._result_database.flush() - - # TODO: make store return retrieved table return output_packet def get_all_cached_outputs( @@ -846,13 +857,6 @@ def get_all_cached_outputs( Returns: A PyArrow table of cached results, or ``None`` if empty. 
""" - record_id_column = ( - constants.PACKET_RECORD_ID if include_system_columns else None + return self._cache.get_all_records( + include_system_columns=include_system_columns ) - result_table = self._result_database.get_all_records( - self.record_path, record_id_column=record_id_column - ) - if result_table is None or result_table.num_rows == 0: - return None - - return result_table diff --git a/src/orcapod/core/result_cache.py b/src/orcapod/core/result_cache.py new file mode 100644 index 00000000..8afa8b8f --- /dev/null +++ b/src/orcapod/core/result_cache.py @@ -0,0 +1,226 @@ +"""ResultCache — shared result caching logic for CachedPacketFunction and CachedFunctionPod. + +Owns the database, record path, lookup (with match strategy), store, +conflict resolution, and auto-flush behavior. Both ``CachedPacketFunction`` +and ``CachedFunctionPod`` delegate to a ``ResultCache`` instance. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from orcapod.protocols.core_protocols import PacketProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +class ResultCache: + """Shared result caching backed by an ``ArrowDatabaseProtocol``. + + Provides lookup (by input packet hash + optional additional constraints), + store (output data + variation/execution metadata + timestamp), conflict + resolution (most-recent-timestamp wins), and auto-flush. + + The match strategy is extensible: the default lookup matches on + ``INPUT_PACKET_HASH_COL`` only, but callers can supply additional + constraints (e.g. function variation columns) to narrow the match. + This is the hook for future match-tier support (see DESIGN_ISSUES P6). + + Args: + result_database: The database to store/retrieve cached results. + record_path: The record path tuple for scoping records in the database. + auto_flush: If True, flush the database after each store operation. + """ + + # Meta column indicating whether the result was freshly computed + RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + + def __init__( + self, + result_database: ArrowDatabaseProtocol, + record_path: tuple[str, ...], + auto_flush: bool = True, + ) -> None: + self._result_database = result_database + self._record_path = record_path + self._auto_flush = auto_flush + + @property + def result_database(self) -> ArrowDatabaseProtocol: + """The underlying database.""" + return self._result_database + + @property + def record_path(self) -> tuple[str, ...]: + """The record path for scoping records in the database.""" + return self._record_path + + def set_auto_flush(self, on: bool = True) -> None: + """Set auto-flush behavior.""" + self._auto_flush = on + + def lookup( + self, + input_packet: PacketProtocol, + additional_constraints: dict[str, str] | None = None, + ) -> PacketProtocol | None: + """Look up a cached output packet for *input_packet*. + + The default match is by ``INPUT_PACKET_HASH_COL`` only. + *additional_constraints* can narrow the match further (e.g. by + function variation hash for stricter cache invalidation). + + If multiple records match, the most recent (by timestamp) wins. + + Args: + input_packet: The input packet whose content hash is the + primary lookup key. 
+ additional_constraints: Optional extra column-value pairs to + include in the lookup query. + + Returns: + The cached output packet with ``RESULT_COMPUTED_FLAG: False`` + in its meta, or ``None`` if no match was found. + """ + from orcapod.core.datagrams import Packet + + RECORD_ID_COL = "_record_id" + + constraints: dict[str, str] = { + constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string(), + } + if additional_constraints: + constraints.update(additional_constraints) + + result_table = self._result_database.get_records_with_column_value( + self._record_path, + constraints, + record_id_column=RECORD_ID_COL, + ) + + if result_table is None or result_table.num_rows == 0: + return None + + if result_table.num_rows > 1: + logger.info( + "Cache conflict resolution: %d records for constraints %s, " + "taking most recent", + result_table.num_rows, + list(constraints.keys()), + ) + result_table = result_table.sort_by( + [(constants.POD_TIMESTAMP, "descending")] + ).take([0]) + + record_id = result_table.to_pylist()[0][RECORD_ID_COL] + # Drop lookup columns from the returned packet + drop_cols = [RECORD_ID_COL] + [ + c for c in constraints if c in result_table.column_names + ] + result_table = result_table.drop_columns(drop_cols) + + return Packet( + result_table, + record_id=record_id, + meta_info={self.RESULT_COMPUTED_FLAG: False}, + ) + + def store( + self, + input_packet: PacketProtocol, + output_packet: PacketProtocol, + variation_data: dict[str, Any], + execution_data: dict[str, Any], + skip_duplicates: bool = False, + ) -> None: + """Store an output packet in the cache. + + Stores the output packet data alongside function variation data, + execution data, input packet hash, and a timestamp. + + Args: + input_packet: The input packet (used for its content hash). + output_packet: The computed output packet to store. + variation_data: Function variation metadata (e.g. function name, + signature hash, content hash, git hash). + execution_data: Execution environment metadata (e.g. python + version, execution context). + skip_duplicates: If True, silently skip if a record with the + same ID already exists. + """ + data_table = output_packet.as_table(columns={"source": True, "context": True}) + + # Add function variation data columns + i = 0 + for k, v in variation_data.items(): + data_table = data_table.add_column( + i, + f"{constants.PF_VARIATION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add execution data columns + for k, v in execution_data.items(): + data_table = data_table.add_column( + i, + f"{constants.PF_EXECUTION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add input packet hash (position 0) + data_table = data_table.add_column( + 0, + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), + ) + + # Append timestamp + timestamp = datetime.now(timezone.utc) + data_table = data_table.append_column( + constants.POD_TIMESTAMP, + pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), + ) + + self._result_database.add_record( + self._record_path, + output_packet.datagram_id, + data_table, + skip_duplicates=skip_duplicates, + ) + + if self._auto_flush: + self._result_database.flush() + + def get_all_records( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store. + + Args: + include_system_columns: If True, include system columns + (e.g. record_id) in the result. 
+ + Returns: + A PyArrow table of cached results, or ``None`` if empty. + """ + record_id_column = ( + constants.PACKET_RECORD_ID if include_system_columns else None + ) + result_table = self._result_database.get_all_records( + self._record_path, record_id_column=record_id_column + ) + if result_table is None or result_table.num_rows == 0: + return None + return result_table diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index c0bf574d..a800b7ed 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -375,11 +375,10 @@ def run( execution_engine: Optional packet-function executor applied to every function node before execution (e.g. a ``RayExecutor``). Overrides ``config.execution_engine`` when both are provided. - execution_engine_opts: Default resource/options dict forwarded to - the engine for every node (e.g. ``{"num_cpus": 4}``). - Individual nodes may override via their - ``execution_engine_opts`` attribute. Overrides - ``config.execution_engine_opts`` when both are provided. + execution_engine_opts: Resource/options dict forwarded to the + engine via ``with_options()`` (e.g. ``{"num_cpus": 4}``). + Overrides ``config.execution_engine_opts`` when both are + provided. """ from orcapod.types import ExecutorType, PipelineConfig @@ -426,38 +425,34 @@ def _apply_execution_engine( ) -> None: """Apply *execution_engine* to every ``FunctionNode`` in the pipeline. - For each function node, the pipeline-level *execution_engine_opts* are - merged with any per-node ``execution_engine_opts`` override (node opts - win). If the merged opts dict is non-empty, ``engine.with_options`` - is called to produce a node-specific executor; otherwise the engine - instance is used directly. + Each node receives its own executor instance via + ``engine.with_options(**opts)`` — even when *opts* is empty. + The executor's ``with_options`` implementation decides which + components to copy vs share (e.g. connection handles may be + shared while per-node state is copied). Args: execution_engine: Executor to apply (must implement ``PacketFunctionExecutorBase`` or at minimum expose ``with_options``). - execution_engine_opts: Pipeline-level default options dict, or + execution_engine_opts: Pipeline-level options dict, or ``None`` for no defaults. 
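+
+        Example (sketch; normally reached via ``run``)::
+
+            pipeline.run(
+                execution_engine=RayExecutor(),
+                execution_engine_opts={"num_cpus": 2},
+            )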
""" assert self._node_graph is not None, ( "_apply_execution_engine called before compile()" ) - pipeline_opts = execution_engine_opts or {} + opts = execution_engine_opts or {} for node in self._node_graph.nodes: if not isinstance(node, FunctionNode): continue - node_opts = node.execution_engine_opts or {} - merged = {**pipeline_opts, **node_opts} - node.executor = ( - execution_engine.with_options(**merged) if merged else execution_engine - ) + node.executor = execution_engine.with_options(**opts) logger.debug( "Applied execution engine %r to node %r (opts=%r)", type(execution_engine).__name__, node.label, - merged or None, + opts or None, ) def _run_async(self, config: PipelineConfig) -> None: @@ -599,8 +594,7 @@ def _build_function_descriptor(self, node: "FunctionNode") -> dict[str, Any]: return { "function_pod": node._function_pod.to_config(), "pipeline_path": list(node.pipeline_path), - "result_record_path": list(node._packet_function.record_path), - "execution_engine_opts": node.execution_engine_opts, + "result_record_path": list(node._cached_function_pod.record_path), } def _build_operator_descriptor(self, node: OperatorNode) -> dict[str, Any]: diff --git a/src/orcapod/protocols/core_protocols/__init__.py b/src/orcapod/protocols/core_protocols/__init__.py index 76c2720a..1f274033 100644 --- a/src/orcapod/protocols/core_protocols/__init__.py +++ b/src/orcapod/protocols/core_protocols/__init__.py @@ -3,7 +3,7 @@ from .async_executable import AsyncExecutableProtocol from .datagrams import DatagramProtocol, PacketProtocol, TagProtocol -from .executor import PacketFunctionExecutorProtocol +from .executor import PacketFunctionExecutorProtocol, PythonFunctionExecutorProtocol from .function_pod import FunctionPodProtocol from .operator_pod import OperatorPodProtocol from .packet_function import PacketFunctionProtocol @@ -27,6 +27,7 @@ "OperatorPodProtocol", "PacketFunctionProtocol", "PacketFunctionExecutorProtocol", + "PythonFunctionExecutorProtocol", "TrackerProtocol", "TrackerManagerProtocol", ] diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py index 81a88db4..2260898a 100644 --- a/src/orcapod/protocols/core_protocols/executor.py +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable from orcapod.protocols.core_protocols.datagrams import PacketProtocol @@ -63,11 +64,13 @@ def supports_concurrent_execution(self) -> bool: ... def with_options(self, **opts: Any) -> Self: - """Return an executor configured with the given per-node options. + """Return a **new** executor instance configured with the given per-node options. Used by the pipeline to produce node-specific executor instances (e.g. with different CPU/GPU allocations) from a shared base executor. - Implementations that do not support options may return *self*. + Implementations must always return a new instance, even when no + options change, so that executors are effectively immutable value + objects after construction. """ ... @@ -78,3 +81,51 @@ def get_execution_data(self) -> dict[str, Any]: affect content or pipeline hashes. """ ... + + +@runtime_checkable +class PythonFunctionExecutorProtocol(PacketFunctionExecutorProtocol, Protocol): + """Executor protocol for Python callable-based packet functions. + + Extends ``PacketFunctionExecutorProtocol`` with callable-level + execution methods. 
The executor receives the raw Python function + and its keyword arguments directly — the packet function handles + packet construction/deconstruction around the executor call. + """ + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Synchronously execute *fn* with *kwargs*. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options (e.g. resource + overrides). + + Returns: + The raw return value of *fn*. + """ + ... + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Asynchronously execute *fn* with *kwargs*. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options. + + Returns: + The raw return value of *fn*. + """ + ... diff --git a/src/orcapod/types.py b/src/orcapod/types.py index a7ded54e..3f8938de 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -275,9 +275,8 @@ class PipelineConfig: execution_engine: Optional packet-function executor applied to all function nodes (e.g. ``RayExecutor``). ``None`` means in-process execution. - execution_engine_opts: Default resource/options dict forwarded to the - engine for every node (e.g. ``{"num_cpus": 4}``). Individual - nodes may override via their ``execution_engine_opts`` attribute. + execution_engine_opts: Resource/options dict forwarded to the engine + via ``with_options()`` (e.g. ``{"num_cpus": 4}``). """ executor: ExecutorType = ExecutorType.SYNCHRONOUS diff --git a/tests/test_core/function_pod/test_cached_function_pod.py b/tests/test_core/function_pod/test_cached_function_pod.py new file mode 100644 index 00000000..620a63ca --- /dev/null +++ b/tests/test_core/function_pod/test_cached_function_pod.py @@ -0,0 +1,302 @@ +"""Tests for CachedFunctionPod — pod-level caching wrapper.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.cached_function_pod import CachedFunctionPod +from orcapod.core.function_pod import FunctionPod, function_pod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def double(x: int) -> int: + return x * 2 + + +def _make_stream(rows: list[dict] | None = None) -> ArrowTableStream: + if rows is None: + rows = [{"id": 0, "x": 10}, {"id": 1, "x": 20}] + table = pa.table( + {k: pa.array([r[k] for r in rows], type=pa.int64()) for k in rows[0]} + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cache_db(): + return InMemoryArrowDatabase() + + +@pytest.fixture +def double_pod(): + pf = PythonPacketFunction(double, output_keys="result") + return FunctionPod(pf) + + +@pytest.fixture +def cached_pod(double_pod, cache_db): + return CachedFunctionPod( + function_pod=double_pod, + result_database=cache_db, + ) + + +# --------------------------------------------------------------------------- 
+# ---------------------------------------------------------------------------
+# Construction
+# ---------------------------------------------------------------------------
+
+
+class TestConstruction:
+    def test_record_path_ends_with_inner_uri(self, cached_pod, double_pod):
+        assert cached_pod.record_path[-len(double_pod.uri) :] == double_pod.uri
+
+    def test_record_path_prefix_empty_by_default(self, cached_pod, double_pod):
+        assert cached_pod.record_path == double_pod.uri
+
+    def test_record_path_prefix_prepended(self, double_pod, cache_db):
+        pod = CachedFunctionPod(
+            double_pod,
+            result_database=cache_db,
+            record_path_prefix=("my", "prefix"),
+        )
+        assert pod.record_path[:2] == ("my", "prefix")
+
+
+# ---------------------------------------------------------------------------
+# Cache miss
+# ---------------------------------------------------------------------------
+
+
+class TestCacheMiss:
+    def test_returns_non_none_result(self, cached_pod):
+        stream = _make_stream()
+        output = cached_pod.process(stream)
+        results = list(output.iter_packets())
+        assert len(results) == 2
+
+    def test_result_values_correct(self, cached_pod):
+        stream = _make_stream()
+        output = cached_pod.process(stream)
+        results = list(output.iter_packets())
+        assert results[0][1].as_dict()["result"] == 20
+        assert results[1][1].as_dict()["result"] == 40
+
+    def test_result_stored_in_database(self, cached_pod, cache_db):
+        stream = _make_stream()
+        output = cached_pod.process(stream)
+        list(output.iter_packets())  # exhaust the iterator
+
+        records = cache_db.get_all_records(cached_pod.record_path)
+        assert records is not None
+        assert records.num_rows == 2
+
+
+# ---------------------------------------------------------------------------
+# Cache hit
+# ---------------------------------------------------------------------------
+
+
+class TestCacheHit:
+    def test_second_call_returns_same_result(self, cached_pod):
+        stream = _make_stream()
+
+        # First call: cache miss
+        output1 = cached_pod.process(stream)
+        first = [(t.as_dict(), p.as_dict()) for t, p in output1.iter_packets()]
+
+        # Second call: cache hit
+        output2 = cached_pod.process(_make_stream())
+        second = [(t.as_dict(), p.as_dict()) for t, p in output2.iter_packets()]
+
+        assert len(first) == len(second)
+        for (_, p1), (_, p2) in zip(first, second):
+            assert p1["result"] == p2["result"]
+
+    def test_second_call_does_not_add_new_records(self, cached_pod, cache_db):
+        stream = _make_stream()
+        output = cached_pod.process(stream)
+        list(output.iter_packets())
+
+        records_after_first = cache_db.get_all_records(cached_pod.record_path)
+        assert records_after_first is not None
+        count_first = records_after_first.num_rows
+
+        output2 = cached_pod.process(_make_stream())
+        list(output2.iter_packets())
+
+        records_after_second = cache_db.get_all_records(cached_pod.record_path)
+        assert records_after_second is not None
+        assert records_after_second.num_rows == count_first
+
+
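+# ---------------------------------------------------------------------------
+# Cache key sketch (illustrative)
+# ---------------------------------------------------------------------------
+# A minimal sketch of the contract the next section verifies, assuming the
+# fixtures above; `Tag` and `Packet` appear for illustration only and are
+# not imported by this module:
+#
+#     pod.process_packet(Tag({"id": 0}), Packet({"x": 10}))  # miss: compute + store
+#     pod.process_packet(Tag({"id": 1}), Packet({"x": 10}))  # hit: same packet hash
+#     pod.process_packet(Tag({"id": 0}), Packet({"x": 99}))  # miss: new packet hash
+#
+# Only the input packet content hash forms the key; the tag never enters it.
+
+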
_make_stream([{"id": 1, "x": 10}]) + results = list(cached_pod.process(stream2).iter_packets()) + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + # Only 1 record since the cache key is input packet hash only + assert records.num_rows == 1 + # But result is still correct + assert results[0][1].as_dict()["result"] == 20 + + def test_different_packet_data_cached_separately(self, double_pod, cache_db): + """Different packet data should produce separate cache entries.""" + cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) + + stream1 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream1).iter_packets()) + + stream2 = _make_stream([{"id": 0, "x": 99}]) + list(cached_pod.process(stream2).iter_packets()) + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + assert records.num_rows == 2 + + def test_identical_input_is_cache_hit(self, double_pod, cache_db): + """Exact same input is a cache hit (no new record).""" + cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) + + stream1 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream1).iter_packets()) + + stream2 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream2).iter_packets()) + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + assert records.num_rows == 1 + + +# --------------------------------------------------------------------------- +# Decorator integration +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Inner function returns None (inactive / filtering) +# --------------------------------------------------------------------------- + + +class TestInnerReturnsNone: + def test_inactive_function_does_not_store(self, cache_db): + pf = PythonPacketFunction(double, output_keys="result") + pf.set_active(False) + pod = FunctionPod(pf) + cached_pod = CachedFunctionPod(pod, result_database=cache_db) + + stream = _make_stream([{"id": 0, "x": 10}]) + results = list(cached_pod.process(stream).iter_packets()) + assert len(results) == 0 + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is None + + +# --------------------------------------------------------------------------- +# Output schema delegation +# --------------------------------------------------------------------------- + + +class TestOutputSchema: + def test_output_schema_delegates_to_inner_pod(self, cached_pod, double_pod): + stream = _make_stream() + expected = double_pod.output_schema(stream) + actual = cached_pod.output_schema(stream) + assert actual == expected + + +# --------------------------------------------------------------------------- +# Dual caching (CachedPacketFunction + CachedFunctionPod) +# --------------------------------------------------------------------------- + + +class TestDualCaching: + def test_both_result_database_and_pod_cache_database(self): + """Decorator with both caching layers produces correct results.""" + pkt_db = InMemoryArrowDatabase() + pod_db = InMemoryArrowDatabase() + + @function_pod( + output_keys="result", + result_database=pkt_db, + pod_cache_database=pod_db, + ) + def my_double(x: int) -> int: + return x * 2 + + stream = _make_stream() + output = my_double.pod.process(stream) + results = list(output.iter_packets()) + + assert len(results) == 2 + assert results[0][1].as_dict()["result"] == 20 + + # 
Pod-level cache should have entries + pod_records = pod_db.get_all_records(my_double.pod.record_path) + assert pod_records is not None + assert pod_records.num_rows == 2 + + +# --------------------------------------------------------------------------- +# Decorator integration +# --------------------------------------------------------------------------- + + +class TestDecoratorIntegration: + def test_decorator_with_pod_cache_database(self): + db = InMemoryArrowDatabase() + + @function_pod(output_keys="result", pod_cache_database=db) + def my_double(x: int) -> int: + return x * 2 + + assert isinstance(my_double.pod, CachedFunctionPod) + + def test_decorator_pod_cache_produces_correct_results(self): + db = InMemoryArrowDatabase() + + @function_pod(output_keys="result", pod_cache_database=db) + def my_double(x: int) -> int: + return x * 2 + + stream = _make_stream() + output = my_double.pod.process(stream) + results = list(output.iter_packets()) + assert len(results) == 2 + assert results[0][1].as_dict()["result"] == 20 + + def test_decorator_without_pod_cache_returns_plain_pod(self): + @function_pod(output_keys="result") + def my_double(x: int) -> int: + return x * 2 + + assert isinstance(my_double.pod, FunctionPod) diff --git a/tests/test_core/function_pod/test_function_node_attach_db.py b/tests/test_core/function_pod/test_function_node_attach_db.py index 2cd6a955..cef00ef1 100644 --- a/tests/test_core/function_pod/test_function_node_attach_db.py +++ b/tests/test_core/function_pod/test_function_node_attach_db.py @@ -59,13 +59,13 @@ def test_attach_databases_sets_pipeline_db(self): node.attach_databases(pipeline_database=db, result_database=db) assert node._pipeline_database is db - def test_attach_databases_wraps_packet_function(self): - from orcapod.core.packet_function import CachedPacketFunction + def test_attach_databases_creates_cached_function_pod(self): + from orcapod.core.cached_function_pod import CachedFunctionPod node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) db = InMemoryArrowDatabase() node.attach_databases(pipeline_database=db, result_database=db) - assert isinstance(node._packet_function, CachedPacketFunction) + assert isinstance(node._cached_function_pod, CachedFunctionPod) def test_attach_databases_clears_caches(self): node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) @@ -83,17 +83,17 @@ def test_attach_databases_computes_pipeline_path(self): assert len(node.pipeline_path) > 0 def test_double_attach_does_not_double_wrap(self): - from orcapod.core.packet_function import CachedPacketFunction + from orcapod.core.cached_function_pod import CachedFunctionPod node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) db = InMemoryArrowDatabase() node.attach_databases(pipeline_database=db, result_database=db) - assert isinstance(node._packet_function, CachedPacketFunction) - # Second attach should not double-wrap + assert isinstance(node._cached_function_pod, CachedFunctionPod) + # Second attach wraps the original function_pod, not the cached one node.attach_databases(pipeline_database=db, result_database=db) - assert isinstance(node._packet_function, CachedPacketFunction) + assert isinstance(node._cached_function_pod, CachedFunctionPod) assert not isinstance( - node._packet_function._packet_function, CachedPacketFunction + node._cached_function_pod._function_pod, CachedFunctionPod ) def test_iter_packets_after_attach_works(self): diff --git a/tests/test_core/function_pod/test_function_node_caching.py 
b/tests/test_core/function_pod/test_function_node_caching.py
new file mode 100644
index 00000000..0a0403f3
--- /dev/null
+++ b/tests/test_core/function_pod/test_function_node_caching.py
@@ -0,0 +1,391 @@
+"""Tests for FunctionNode caching: pipeline DB vs result DB interaction.
+
+Covers:
+- compute_pipeline_entry_id behavior
+- Pipeline entry_id-based Phase 2 skip (tag + system_tags + packet_hash)
+- CachedFunctionPod result cache hit with novel pipeline entry_id
+- Same packet data, different tags → 1 result record, N pipeline records
+- System tag awareness in pipeline entry_id computation
+- Phase 1 yields existing records, Phase 2 processes only novel entry_ids
+"""
+
+from __future__ import annotations
+
+import pyarrow as pa
+
+from orcapod.core.datagrams import Packet, Tag
+from orcapod.core.function_pod import FunctionPod
+from orcapod.core.nodes import FunctionNode
+from orcapod.core.packet_function import PythonPacketFunction
+from orcapod.core.sources import ArrowTableSource
+from orcapod.core.streams.arrow_table_stream import ArrowTableStream
+from orcapod.databases import InMemoryArrowDatabase
+from orcapod.system_constants import constants
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def double(x: int) -> int:
+    return x * 2
+
+
+def _make_pod():
+    pf = PythonPacketFunction(double, output_keys="result")
+    return FunctionPod(pf)
+
+
+def _make_stream(
+    rows: list[dict], tag_columns: list[str] | None = None
+) -> ArrowTableStream:
+    if tag_columns is None:
+        tag_columns = ["id"]
+    table = pa.table(
+        {k: pa.array([r[k] for r in rows], type=pa.int64()) for k in rows[0]}
+    )
+    return ArrowTableStream(table, tag_columns=tag_columns)
+
+
+def _make_source_stream(
+    rows: list[dict], tag_columns: list[str] | None = None, source_id: str = "src_a"
+) -> ArrowTableSource:
+    """Create an ArrowTableSource (used as a stream) so packets carry
+    system tag columns."""
+    if tag_columns is None:
+        tag_columns = ["id"]
+    table = pa.table(
+        {k: pa.array([r[k] for r in rows], type=pa.int64()) for k in rows[0]}
+    )
+    source = ArrowTableSource(table, tag_columns=tag_columns, source_id=source_id)
+    return source
+
+
+def _make_node(stream, db=None):
+    pod = _make_pod()
+    if db is None:
+        db = InMemoryArrowDatabase()
+    return FunctionNode(
+        function_pod=pod,
+        input_stream=stream,
+        pipeline_database=db,
+        result_database=db,
+    ), db
+
+
+# ---------------------------------------------------------------------------
+# compute_pipeline_entry_id
+# ---------------------------------------------------------------------------
+
+
+class TestComputePipelineEntryId:
+    def test_returns_non_empty_string(self):
+        stream = _make_stream([{"id": 0, "x": 10}])
+        node, _ = _make_node(stream)
+        tag = Tag({"id": 0})
+        packet = Packet({"x": 10})
+        entry_id = node.compute_pipeline_entry_id(tag, packet)
+        assert isinstance(entry_id, str)
+        assert len(entry_id) > 0
+
+    def test_same_inputs_produce_same_id(self):
+        stream = _make_stream([{"id": 0, "x": 10}])
+        node, _ = _make_node(stream)
+        tag = Tag({"id": 0})
+        packet = Packet({"x": 10})
+        id1 = node.compute_pipeline_entry_id(tag, packet)
+        id2 = node.compute_pipeline_entry_id(tag, packet)
+        assert id1 == id2
+
+    def test_different_tags_produce_different_ids(self):
+        stream = _make_stream([{"id": 0, "x": 10}])
+        node, _ = _make_node(stream)
+        packet = Packet({"x": 10})
+        
id_tag0 = node.compute_pipeline_entry_id(Tag({"id": 0}), packet) + id_tag1 = node.compute_pipeline_entry_id(Tag({"id": 1}), packet) + assert id_tag0 != id_tag1 + + def test_different_packets_produce_different_ids(self): + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + tag = Tag({"id": 0}) + id_x10 = node.compute_pipeline_entry_id(tag, Packet({"x": 10})) + id_x99 = node.compute_pipeline_entry_id(tag, Packet({"x": 99})) + assert id_x10 != id_x99 + + +# --------------------------------------------------------------------------- +# System tag awareness in entry_id +# --------------------------------------------------------------------------- + + +class TestSystemTagAwareness: + def test_same_tag_values_different_system_tags_produce_different_ids(self): + """Two tags with identical user values but different system tags + must produce different pipeline entry_ids.""" + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + packet = Packet({"x": 10}) + + # Tags with same user value but different system tag columns + tag_a = Tag( + {"id": 0}, + system_tags={f"{constants.SYSTEM_TAG_PREFIX}source:abc": "row0"}, + ) + tag_b = Tag( + {"id": 0}, + system_tags={f"{constants.SYSTEM_TAG_PREFIX}source:xyz": "row0"}, + ) + + id_a = node.compute_pipeline_entry_id(tag_a, packet) + id_b = node.compute_pipeline_entry_id(tag_b, packet) + assert id_a != id_b + + +# --------------------------------------------------------------------------- +# Result DB vs Pipeline DB record counts +# --------------------------------------------------------------------------- + + +class TestResultVsPipelineRecordCounts: + def test_same_packet_different_tags_one_result_two_pipeline_records(self): + """Same packet data with different tags should produce: + - 1 result record (CachedFunctionPod caches by packet hash only) + - 2 pipeline records (different tag → different entry_id) + """ + rows = [{"id": 0, "x": 10}, {"id": 1, "x": 10}] + stream = _make_stream(rows) + node, db = _make_node(stream) + node.run() + + # Result DB: 1 record (same packet hash, second is cache hit) + result_records = node._cached_function_pod._result_database.get_all_records( + node._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 1 + + # Pipeline DB: 2 records (different tags → different entry_ids) + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + def test_different_packets_same_tag_two_result_two_pipeline_records(self): + """Different packet data with same tag should produce: + - 2 result records (different packet hashes) + - 2 pipeline records (different packet hash → different entry_id) + """ + rows = [{"id": 0, "x": 10}, {"id": 0, "x": 20}] + stream = _make_stream(rows) + node, db = _make_node(stream) + node.run() + + result_records = node._cached_function_pod._result_database.get_all_records( + node._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 2 + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + def test_identical_rows_one_result_one_pipeline_record(self): + """Identical (tag, packet) → 1 result record, 1 pipeline record.""" + # A single row — process once + stream = _make_stream([{"id": 0, "x": 10}]) + node, db = _make_node(stream) + node.run() + + result_records = 
node._cached_function_pod._result_database.get_all_records( + node._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 1 + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 1 + + +# --------------------------------------------------------------------------- +# Phase 1/2 interaction with pipeline entry_ids +# --------------------------------------------------------------------------- + + +class TestPhase1Phase2PipelineEntryId: + def test_phase1_yields_existing_results(self): + """Phase 1 should yield all previously computed results from DB.""" + db = InMemoryArrowDatabase() + stream1 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Create a second node with the same stream and shared DB + stream2 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # All results should come from Phase 1 (DB), not recomputed + assert len(results) == 2 + + def test_phase2_processes_novel_entry_ids_only(self): + """Phase 2 should only process inputs whose pipeline entry_id + is not yet in the pipeline DB.""" + db = InMemoryArrowDatabase() + + # First run: process 2 rows + stream1 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Second run: 3 rows, 2 existing + 1 new + stream2 = _make_stream( + [{"id": 0, "x": 10}, {"id": 1, "x": 20}, {"id": 2, "x": 30}] + ) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # Should yield 3 total: 2 from Phase 1 + 1 from Phase 2 + assert len(results) == 3 + # The new result should have the correct value + result_values = sorted(p.as_dict()["result"] for _, p in results) + assert result_values == [20, 40, 60] + + def test_same_packet_new_tag_triggers_phase2(self): + """Same packet data but new tag should trigger Phase 2 processing + because the pipeline entry_id (tag+system_tags+packet) is novel, + even though CachedFunctionPod has a cache hit for the packet.""" + db = InMemoryArrowDatabase() + + # First run: tag=0, x=10 + stream1 = _make_stream([{"id": 0, "x": 10}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + pipeline_count_after_first = db.get_all_records(node1.pipeline_path).num_rows + assert pipeline_count_after_first == 1 + + # Second run: tag=1, x=10 (same packet, different tag) + stream2 = _make_stream([{"id": 1, "x": 10}]) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # Should yield 1 result from Phase 1 (tag=0) + 1 from Phase 2 (tag=1) + assert len(results) == 2 + + # Pipeline DB should now have 2 records + pipeline_records = db.get_all_records(node2.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + # Result DB should still have only 1 record (same packet hash) + result_records = node2._cached_function_pod._result_database.get_all_records( + node2._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 1 + + def test_all_existing_entry_ids_skipped_in_phase2(self): + """When all inputs already have pipeline records, Phase 2 + should not call process_packet at all.""" + db = InMemoryArrowDatabase() + + stream1 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Re-run 
with identical stream + stream2 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node2, _ = _make_node(stream2, db=db) + + pipeline_count_before = db.get_all_records(node2.pipeline_path).num_rows + + results = list(node2.iter_packets()) + assert len(results) == 2 + + # No new pipeline records should be added + pipeline_count_after = db.get_all_records(node2.pipeline_path).num_rows + assert pipeline_count_after == pipeline_count_before + + +# --------------------------------------------------------------------------- +# CachedFunctionPod cache hit + novel pipeline entry +# --------------------------------------------------------------------------- + + +class TestResultCacheHitPipelineNovel: + def test_cached_result_reused_for_new_tag(self): + """When CachedFunctionPod has a cache hit (same packet hash) but + the pipeline entry_id is novel (different tag), the cached result + should be reused and a new pipeline record created.""" + db = InMemoryArrowDatabase() + + # Process tag=0, x=10 + stream1 = _make_stream([{"id": 0, "x": 10}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Process tag=1, x=10 — same packet, different tag + stream2 = _make_stream([{"id": 1, "x": 10}]) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # Both tags should produce the same result value + result_values = [p.as_dict()["result"] for _, p in results] + assert all(v == 20 for v in result_values) + + def test_pipeline_records_reference_same_result_uuid(self): + """Two pipeline records for the same packet (different tags) + should reference the same output packet UUID in the result DB.""" + db = InMemoryArrowDatabase() + + stream = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 10}]) + node, _ = _make_node(stream, db=db) + node.run() + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + # Both pipeline records should reference the same PACKET_RECORD_ID + record_ids = pipeline_records.column(constants.PACKET_RECORD_ID).to_pylist() + assert record_ids[0] == record_ids[1] + + +# --------------------------------------------------------------------------- +# Source columns in pipeline records +# --------------------------------------------------------------------------- + + +class TestPipelineRecordSourceColumns: + def test_pipeline_record_contains_source_columns(self): + """Pipeline records should include source columns of the input packet.""" + stream = _make_source_stream([{"id": 0, "x": 10}], source_id="my_source") + node, db = _make_node(stream) + node.run() + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + + source_cols = [ + c + for c in pipeline_records.column_names + if c.startswith(constants.SOURCE_PREFIX) + ] + assert len(source_cols) > 0 + + def test_pipeline_record_excludes_data_columns_of_input(self): + """Pipeline records should NOT include data columns of the input packet.""" + stream = _make_source_stream([{"id": 0, "x": 10}]) + node, db = _make_node(stream) + node.run() + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + + # "x" is the input packet data column — should not appear + assert "x" not in pipeline_records.column_names + assert "_input_x" not in pipeline_records.column_names diff --git a/tests/test_core/function_pod/test_function_pod_node.py b/tests/test_core/function_pod/test_function_pod_node.py index 00460bba..29bd5514 100644 --- 
a/tests/test_core/function_pod/test_function_pod_node.py +++ b/tests/test_core/function_pod/test_function_pod_node.py @@ -665,7 +665,7 @@ def test_result_records_stored_under_result_suffix_path(self, double_pf): node.process_packet(tag, packet) db.flush() - result_path = node._packet_function.record_path + result_path = node._cached_function_pod.record_path assert result_path[-1] == "_result" or any( "_result" in part for part in result_path ) diff --git a/tests/test_core/packet_function/test_cached_packet_function.py b/tests/test_core/packet_function/test_cached_packet_function.py index ba968dd0..9a4218a1 100644 --- a/tests/test_core/packet_function/test_cached_packet_function.py +++ b/tests/test_core/packet_function/test_cached_packet_function.py @@ -96,11 +96,11 @@ def test_record_path_prefix_prepended(self, inner_pf, db): assert cpf.record_path == ("org", "project") + inner_pf.uri def test_auto_flush_true_by_default(self, cached_pf): - assert cached_pf._auto_flush is True + assert cached_pf._cache._auto_flush is True def test_set_auto_flush_false(self, cached_pf): cached_pf.set_auto_flush(False) - assert cached_pf._auto_flush is False + assert cached_pf._cache._auto_flush is False # --------------------------------------------------------------------------- diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index 0ac1e9f3..317a2cb1 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -26,6 +26,7 @@ PacketFunctionExecutorProtocol, PacketFunctionProtocol, PacketProtocol, + PythonFunctionExecutorProtocol, ) # --------------------------------------------------------------------------- @@ -63,6 +64,10 @@ def execute( self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.calls.append((fn, kwargs)) + return fn(**kwargs) + class PythonOnlyExecutor(PacketFunctionExecutorBase): """Executor that only supports python.function.v0.""" @@ -147,6 +152,19 @@ def test_get_execution_data_returns_type(self): data = executor.get_execution_data() assert data["executor_type"] == "spy" + def test_with_options_returns_new_instance(self): + executor = SpyExecutor() + new_executor = executor.with_options() + assert new_executor is not executor + assert isinstance(new_executor, SpyExecutor) + + def test_with_options_preserves_state(self): + executor = SpyExecutor(supported_types=frozenset({"python.function.v0"})) + new_executor = executor.with_options() + assert new_executor.supported_function_type_ids() == frozenset( + {"python.function.v0"} + ) + # --------------------------------------------------------------------------- # 2. LocalExecutor @@ -175,6 +193,11 @@ def test_get_execution_data(self, local_executor: LocalExecutor): data = local_executor.get_execution_data() assert data["executor_type"] == "local" + def test_with_options_returns_new_instance(self, local_executor: LocalExecutor): + new_executor = local_executor.with_options() + assert new_executor is not local_executor + assert isinstance(new_executor, LocalExecutor) + # --------------------------------------------------------------------------- # 3. 
Executor as property on PacketFunctionBase @@ -238,7 +261,6 @@ def test_call_with_executor_routes_through_executor( assert result is not None assert result.as_dict()["result"] == 3 assert len(spy_executor.calls) == 1 - assert spy_executor.calls[0][0] is add_pf def test_direct_call_bypasses_executor( self, @@ -559,8 +581,8 @@ class ConcurrentSpyExecutor(PacketFunctionExecutorBase): """Executor that supports concurrent execution and tracks sync vs async calls.""" def __init__(self) -> None: - self.sync_calls: list[PacketProtocol] = [] - self.async_calls: list[PacketProtocol] = [] + self.sync_calls: list[Any] = [] + self.async_calls: list[Any] = [] @property def executor_type_id(self) -> str: @@ -589,6 +611,14 @@ async def async_execute( self.async_calls.append(packet) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.sync_calls.append(kwargs) + return fn(**kwargs) + + async def async_execute_callable(self, fn, kwargs, executor_options=None): + self.async_calls.append(kwargs) + return fn(**kwargs) + class TestConcurrentIteration: def test_function_pod_stream_uses_async_path(self): @@ -688,3 +718,152 @@ def test_second_iteration_uses_cache(self): second = list(output_stream.iter_packets()) assert len(spy.async_calls) == 2 # unchanged assert len(first) == len(second) + + +# --------------------------------------------------------------------------- +# 11. PythonFunctionExecutorProtocol conformance +# --------------------------------------------------------------------------- + + +class TestPythonFunctionExecutorProtocol: + def test_local_executor_satisfies_protocol(self): + executor = LocalExecutor() + assert isinstance(executor, PythonFunctionExecutorProtocol) + + def test_spy_executor_satisfies_protocol(self): + executor = SpyExecutor() + assert isinstance(executor, PythonFunctionExecutorProtocol) + + def test_execute_callable_runs_function(self): + executor = LocalExecutor() + result = executor.execute_callable(add, {"x": 3, "y": 4}) + assert result == 7 + + def test_execute_callable_with_executor_options(self): + executor = LocalExecutor() + result = executor.execute_callable( + add, {"x": 1, "y": 2}, executor_options={"num_cpus": 1} + ) + assert result == 3 + + +# --------------------------------------------------------------------------- +# 12. 
Type-safe executor dispatch via Generic[E] + __init_subclass__
+# ---------------------------------------------------------------------------
+
+
+class _NotAnExecutor:
+    """A class that does NOT satisfy PythonFunctionExecutorProtocol."""
+
+    @property
+    def executor_type_id(self) -> str:
+        return "fake"
+
+    def supported_function_type_ids(self) -> frozenset[str]:
+        return frozenset()
+
+    def supports(self, packet_function_type_id: str) -> bool:
+        return True
+
+
+class TestGenericExecutorDispatch:
+    def test_python_pf_resolves_executor_protocol(self):
+        """PythonPacketFunction should have resolved PythonFunctionExecutorProtocol."""
+        assert (
+            PythonPacketFunction._resolved_executor_protocol
+            is PythonFunctionExecutorProtocol
+        )
+
+    def test_wrapper_does_not_resolve_protocol(self):
+        """PacketFunctionWrapper[E] should NOT resolve a protocol (E is unbound)."""
+        assert PacketFunctionWrapper._resolved_executor_protocol is None
+
+    def test_set_executor_accepts_compatible_protocol(self):
+        pf = PythonPacketFunction(add, output_keys="result")
+        executor = LocalExecutor()
+        pf.set_executor(executor)
+        assert pf.executor is executor
+
+    def test_set_executor_accepts_none(self):
+        pf = PythonPacketFunction(add, output_keys="result", executor=LocalExecutor())
+        pf.set_executor(None)
+        assert pf.executor is None
+
+    def test_set_executor_rejects_non_conforming_protocol(self):
+        """An object that doesn't implement PythonFunctionExecutorProtocol is rejected."""
+        pf = PythonPacketFunction(add, output_keys="result")
+        fake = _NotAnExecutor()
+        with pytest.raises(TypeError, match="requires an executor implementing"):
+            pf.set_executor(fake)
+
+    def test_call_routes_through_execute_callable(self):
+        """PythonPacketFunction.call() should use execute_callable, not execute."""
+        spy = SpyExecutor()
+        pf = PythonPacketFunction(add, output_keys="result", executor=spy)
+        packet = Packet({"x": 1, "y": 2})
+        result = pf.call(packet)
+        assert result is not None
+        assert result.as_dict()["result"] == 3
+        assert len(spy.calls) == 1
+
+    def test_call_with_inactive_function_returns_none(self):
+        """When the function is inactive and executor is set, call returns None."""
+        spy = SpyExecutor()
+        pf = PythonPacketFunction(add, output_keys="result", executor=spy)
+        pf.set_active(False)
+        packet = Packet({"x": 1, "y": 2})
+        result = pf.call(packet)
+        assert result is None
+        assert len(spy.calls) == 0
+
+    def test_async_call_routes_through_async_execute_callable(self):
+        """PythonPacketFunction.async_call() should use async_execute_callable."""
+        import asyncio
+
+        spy = ConcurrentSpyExecutor()
+        pf = PythonPacketFunction(add, output_keys="result", executor=spy)
+        packet = Packet({"x": 1, "y": 2})
+        result = asyncio.run(pf.async_call(packet))
+        assert result is not None
+        assert result.as_dict()["result"] == 3
+        assert len(spy.async_calls) == 1
+        assert len(spy.sync_calls) == 0
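+
+
+# A minimal sketch of the resolution mechanism the tests above assume; this
+# illustrates the Generic[E] + __init_subclass__ pattern and is not a copy
+# of the actual implementation:
+#
+#     from typing import Generic, TypeVar, get_args
+#
+#     E = TypeVar("E")
+#
+#     class PFBase(Generic[E]):
+#         _resolved_executor_protocol: type | None = None
+#
+#         def __init_subclass__(cls, **kwargs):
+#             super().__init_subclass__(**kwargs)
+#             for base in getattr(cls, "__orig_bases__", ()):
+#                 args = get_args(base)
+#                 # A concrete protocol class resolves; an unbound TypeVar
+#                 # (as in PacketFunctionWrapper[E]) is not a type and
+#                 # leaves the attribute None.
+#                 if args and isinstance(args[0], type):
+#                     cls._resolved_executor_protocol = args[0]
+#
+#     class PythonPF(PFBase[PythonFunctionExecutorProtocol]):
+#         ...  # resolves to PythonFunctionExecutorProtocol
+#
+# set_executor can then check isinstance(executor, resolved_protocol) for a
+# runtime-checkable protocol and raise TypeError with the message matched in
+# test_set_executor_rejects_non_conforming_protocol.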
+
+
+# ---------------------------------------------------------------------------
+# 13. LocalExecutor execute_callable with async function
+# ---------------------------------------------------------------------------
+
+
+class TestLocalExecutorCallable:
+    def test_execute_callable_with_async_fn(self):
+        """LocalExecutor.execute_callable handles async functions via asyncio.run."""
+
+        async def async_add(x: int, y: int) -> int:
+            return x + y
+
+        executor = LocalExecutor()
+        result = executor.execute_callable(async_add, {"x": 5, "y": 3})
+        assert result == 8
+
+    def test_async_execute_callable_with_sync_fn(self):
+        """LocalExecutor.async_execute_callable handles sync functions via run_in_executor."""
+        import asyncio
+
+        executor = LocalExecutor()
+        result = asyncio.run(executor.async_execute_callable(add, {"x": 10, "y": 20}))
+        assert result == 30
+
+    def test_async_execute_callable_with_async_fn(self):
+        """LocalExecutor.async_execute_callable awaits async functions directly."""
+        import asyncio
+
+        async def async_add(x: int, y: int) -> int:
+            return x + y
+
+        executor = LocalExecutor()
+        result = asyncio.run(
+            executor.async_execute_callable(async_add, {"x": 7, "y": 8})
+        )
+        assert result == 15
diff --git a/tests/test_core/test_regression_fixes.py b/tests/test_core/test_regression_fixes.py
index 809b5a7a..793f0efe 100644
--- a/tests/test_core/test_regression_fixes.py
+++ b/tests/test_core/test_regression_fixes.py
@@ -82,6 +82,10 @@ def execute(
         self.calls.append((packet_function, packet))
         return packet_function.direct_call(packet)
 
+    def execute_callable(self, fn, kwargs, executor_options=None):
+        self.calls.append((fn, kwargs))
+        return fn(**kwargs)
+
 
 # ===========================================================================
 # 1. async_execute output channel closed on exception (try/finally)
diff --git a/tests/test_core/test_result_cache.py b/tests/test_core/test_result_cache.py
new file mode 100644
index 00000000..cda78270
--- /dev/null
+++ b/tests/test_core/test_result_cache.py
@@ -0,0 +1,306 @@
+"""Tests for ResultCache — shared result caching logic.
+ +Covers: +- Lookup: cache miss, cache hit, conflict resolution (most recent wins) +- Store: variation/execution columns, input packet hash, timestamp +- additional_constraints for narrowing match +- auto_flush behavior +- get_all_records +""" + +from __future__ import annotations + +from typing import Any + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams import Packet +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.result_cache import ResultCache +from orcapod.databases import InMemoryArrowDatabase +from orcapod.system_constants import constants + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def double(x: int) -> int: + return x * 2 + + +def _make_pf() -> PythonPacketFunction: + return PythonPacketFunction(double, output_keys="result") + + +def _make_cache( + db=None, record_path=("test",) +) -> tuple[ResultCache, InMemoryArrowDatabase]: + if db is None: + db = InMemoryArrowDatabase() + cache = ResultCache(result_database=db, record_path=record_path) + return cache, db + + +def _compute_and_store( + cache: ResultCache, pf: PythonPacketFunction, input_packet: Packet +): + """Helper: compute output and store in cache.""" + output = pf.direct_call(input_packet) + assert output is not None + cache.store( + input_packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + return output + + +# --------------------------------------------------------------------------- +# Lookup +# --------------------------------------------------------------------------- + + +class TestLookupMiss: + def test_empty_db_returns_none(self): + cache, _ = _make_cache() + packet = Packet({"x": 10}) + assert cache.lookup(packet) is None + + def test_different_packet_returns_none(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + assert cache.lookup(Packet({"x": 99})) is None + + +class TestLookupHit: + def test_returns_cached_result(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + cached = cache.lookup(input_pkt) + assert cached is not None + assert cached.as_dict()["result"] == 20 + + def test_sets_computed_flag_false(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + cached = cache.lookup(input_pkt) + assert cached is not None + assert cached.get_meta_value(ResultCache.RESULT_COMPUTED_FLAG) is False + + def test_same_packet_different_record_path_is_miss(self): + db = InMemoryArrowDatabase() + cache_a = ResultCache(result_database=db, record_path=("path_a",)) + cache_b = ResultCache(result_database=db, record_path=("path_b",)) + pf = _make_pf() + input_pkt = Packet({"x": 10}) + + output = pf.direct_call(input_pkt) + cache_a.store( + input_pkt, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + + assert cache_a.lookup(input_pkt) is not None + assert cache_b.lookup(input_pkt) is None + + +class TestConflictResolution: + def test_most_recent_wins(self): + import time + + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + + # Store first result + _compute_and_store(cache, pf, input_pkt) + time.sleep(0.01) # ensure different timestamp + + # Store a second result for the same input 
(simulating recomputation) + output2 = pf.direct_call(input_pkt) + cache.store( + input_pkt, + output2, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + + # Lookup should return the most recent + cached = cache.lookup(input_pkt) + assert cached is not None + assert cached.datagram_id == output2.datagram_id + + +# --------------------------------------------------------------------------- +# Additional constraints +# --------------------------------------------------------------------------- + + +class TestAdditionalConstraints: + def test_narrower_match_filters_results(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + # Lookup with a constraint that doesn't match any stored column value + result = cache.lookup( + input_pkt, + additional_constraints={ + f"{constants.PF_VARIATION_PREFIX}function_name": "nonexistent" + }, + ) + assert result is None + + def test_matching_constraint_returns_result(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + # Lookup with a constraint that matches + result = cache.lookup( + input_pkt, + additional_constraints={ + f"{constants.PF_VARIATION_PREFIX}function_name": "double" + }, + ) + assert result is not None + assert result.as_dict()["result"] == 20 + + +# --------------------------------------------------------------------------- +# Store +# --------------------------------------------------------------------------- + + +class TestStore: + def test_stores_input_packet_hash_column(self): + cache, db = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + records = db.get_all_records(cache.record_path) + assert records is not None + assert constants.INPUT_PACKET_HASH_COL in records.column_names + + def test_stores_variation_columns(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + variation_cols = [ + c + for c in records.column_names + if c.startswith(constants.PF_VARIATION_PREFIX) + ] + assert len(variation_cols) > 0 + + def test_stores_execution_columns(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + exec_cols = [ + c + for c in records.column_names + if c.startswith(constants.PF_EXECUTION_PREFIX) + ] + assert len(exec_cols) > 0 + + def test_stores_timestamp(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + assert constants.POD_TIMESTAMP in records.column_names + + def test_stores_output_data(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + assert "result" in records.column_names + assert records.column("result").to_pylist() == [20] + + +# --------------------------------------------------------------------------- +# Auto flush +# --------------------------------------------------------------------------- + + +class TestAutoFlush: + def test_auto_flush_true_by_default(self): + cache, _ = _make_cache() + assert cache._auto_flush is True + + def 
test_set_auto_flush_false(self): + cache, _ = _make_cache() + cache.set_auto_flush(False) + assert cache._auto_flush is False + + def test_auto_flush_false_in_constructor(self): + db = InMemoryArrowDatabase() + cache = ResultCache(result_database=db, record_path=("t",), auto_flush=False) + assert cache._auto_flush is False + + +# --------------------------------------------------------------------------- +# get_all_records +# --------------------------------------------------------------------------- + + +class TestGetAllRecords: + def test_empty_returns_none(self): + cache, _ = _make_cache() + assert cache.get_all_records() is None + + def test_returns_stored_records(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + _compute_and_store(cache, pf, Packet({"x": 20})) + + records = cache.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_include_system_columns_adds_record_id(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = cache.get_all_records(include_system_columns=True) + assert records is not None + assert constants.PACKET_RECORD_ID in records.column_names + + def test_exclude_system_columns_by_default(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = cache.get_all_records(include_system_columns=False) + assert records is not None + assert constants.PACKET_RECORD_ID not in records.column_names diff --git a/tests/test_pipeline/test_node_descriptors.py b/tests/test_pipeline/test_node_descriptors.py index ed86f18c..94d5c7a0 100644 --- a/tests/test_pipeline/test_node_descriptors.py +++ b/tests/test_pipeline/test_node_descriptors.py @@ -153,8 +153,7 @@ def _make_function_node_descriptor(self): }, "function_pod": pod.to_config(), "pipeline_path": list(node.pipeline_path), - "result_record_path": list(node._packet_function.record_path), - "execution_engine_opts": None, + "result_record_path": list(node._cached_function_pod.record_path), } return node, descriptor, db diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 31af0195..99ffd956 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -264,9 +264,9 @@ def test_function_database_none_uses_results_subfolder(self, pipeline_db): node = pipeline.compiled_nodes["adder"] assert isinstance(node, FunctionNode) - # The CachedPacketFunction's record_path should start with + # The CachedFunctionPod's record_path should start with # (pipeline_name, "_results", ...) 
- record_path = node._packet_function.record_path + record_path = node._cached_function_pod.record_path assert record_path[0] == "my_pipe" assert record_path[1] == "_results" @@ -289,8 +289,8 @@ def test_separate_function_database(self, pipeline_db, function_db): node = pipeline.compiled_nodes["adder"] assert isinstance(node, FunctionNode) - # The CachedPacketFunction should use function_db - assert node._packet_function._result_database is function_db + # The CachedFunctionPod should use function_db + assert node._cached_function_pod._result_database is function_db # --------------------------------------------------------------------------- @@ -1146,8 +1146,8 @@ class _MockExecutor(PacketFunctionExecutorBase): def __init__(self, opts: dict[str, Any] | None = None) -> None: self.opts: dict[str, Any] = opts or {} - self.sync_calls: list[PacketProtocol] = [] - self.async_calls: list[PacketProtocol] = [] + self.sync_calls: list[Any] = [] + self.async_calls: list[Any] = [] @property def executor_type_id(self) -> str: @@ -1176,6 +1176,14 @@ async def async_execute( self.async_calls.append(packet) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.sync_calls.append(kwargs) + return fn(**kwargs) + + async def async_execute_callable(self, fn, kwargs, executor_options=None): + self.async_calls.append(kwargs) + return fn(**kwargs) + def with_options(self, **opts: Any) -> "_MockExecutor": return _MockExecutor(opts={**self.opts, **opts}) @@ -1196,7 +1204,7 @@ def test_engine_is_applied_to_all_function_nodes(self, pipeline_db): pipeline.run(execution_engine=mock) - assert pipeline.doubler.executor is mock + assert isinstance(pipeline.doubler.executor, _MockExecutor) def test_engine_without_config_triggers_async_mode(self, pipeline_db): """No config + execution_engine → async channels mode by default.""" @@ -1211,8 +1219,10 @@ def test_engine_without_config_triggers_async_mode(self, pipeline_db): pipeline.run(execution_engine=mock) - assert len(mock.async_calls) > 0 - assert len(mock.sync_calls) == 0 + # Each node gets its own copy; check the node's executor + node_executor = pipeline.doubler.executor + assert len(node_executor.async_calls) > 0 + assert len(node_executor.sync_calls) == 0 def test_explicit_sync_config_overrides_async_default(self, pipeline_db): """Explicit config=PipelineConfig(executor=SYNCHRONOUS) takes priority @@ -1233,29 +1243,46 @@ def test_explicit_sync_config_overrides_async_default(self, pipeline_db): config=PipelineConfig(executor=ExecutorType.SYNCHRONOUS), ) - assert len(mock.sync_calls) > 0 - assert len(mock.async_calls) == 0 + node_executor = pipeline.doubler.executor + assert len(node_executor.sync_calls) > 0 + assert len(node_executor.async_calls) == 0 - def test_per_node_opts_override_pipeline_opts(self, pipeline_db): - """Node-level execution_engine_opts win over pipeline-level defaults.""" + def test_pipeline_opts_applied_via_with_options(self, pipeline_db): + """Pipeline-level execution_engine_opts are applied via with_options.""" src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) pf = PythonPacketFunction(double_value, output_keys="result") pod = FunctionPod(packet_function=pf) mock = _MockExecutor() - pipeline = Pipeline(name="test_node_opts", pipeline_database=pipeline_db) + pipeline = Pipeline(name="test_pipeline_opts", pipeline_database=pipeline_db) with pipeline: pod(src, label="doubler") - pipeline.doubler.execution_engine_opts = {"num_cpus": 2} - pipeline.run( 
execution_engine=mock, - execution_engine_opts={"num_cpus": 1}, + execution_engine_opts={"num_cpus": 4}, ) - # Node opts win: executor should have been created with num_cpus=2 - assert pipeline.doubler.executor.opts.get("num_cpus") == 2 + # Executor should have been created with the pipeline opts + assert pipeline.doubler.executor.opts.get("num_cpus") == 4 + + def test_no_opts_produces_per_node_copy(self, pipeline_db): + """Without execution_engine_opts, each node gets its own executor copy.""" + src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(packet_function=pf) + mock = _MockExecutor() + + pipeline = Pipeline(name="test_no_opts", pipeline_database=pipeline_db) + with pipeline: + pod(src, label="doubler") + + pipeline.run(execution_engine=mock) + + # Each node gets a copy via with_options(), not the original + assert pipeline.doubler.executor is not mock + assert isinstance(pipeline.doubler.executor, _MockExecutor) + assert pipeline.doubler.executor.opts == {} class TestSourceNodesInPipeline: