From 3abb1bb090fbc9925332fa89bbce81ddee090f23 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 05:25:28 +0000 Subject: [PATCH 1/8] refactor(executors): improve function execution chain MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Make executors immutable: with_options() always returns a new instance (copy.copy for base, new LocalExecutor() for local, RayExecutor already did this) - Remove execution_engine_opts from FunctionNode — pipeline executor assignment logic is now the sole owner of per-node configuration - Add type-safe executor dispatch via Generic[E] + __init_subclass__ on PacketFunctionBase — resolves executor protocol once at class definition time, validates at set_executor() instead of in the hot path - Add PythonFunctionExecutorProtocol with execute_callable/ async_execute_callable — executors receive raw callables + kwargs instead of packet_function + packet objects - PythonPacketFunction now routes call()/async_call() through execute_callable, keeping packet construction in the function - Add CachedFunctionPod — pod-level caching wrapper that intercepts process_packet() with tag+packet content hash as cache key - Add pod_cache_database parameter to function_pod decorator Co-Authored-By: Claude Opus 4.6 (1M context) --- src/orcapod/core/cached_function_pod.py | 192 +++++++++++++++++ src/orcapod/core/executors/base.py | 51 ++++- src/orcapod/core/executors/local.py | 35 ++- src/orcapod/core/executors/ray.py | 29 +++ src/orcapod/core/function_pod.py | 18 +- src/orcapod/core/nodes/function_node.py | 6 - src/orcapod/core/packet_function.py | 104 +++++++-- src/orcapod/pipeline/graph.py | 31 ++- .../protocols/core_protocols/__init__.py | 3 +- .../protocols/core_protocols/executor.py | 74 ++++++- src/orcapod/types.py | 5 +- .../function_pod/test_cached_function_pod.py | 202 ++++++++++++++++++ .../packet_function/test_executor.py | 104 ++++++++- tests/test_core/test_regression_fixes.py | 4 + tests/test_pipeline/test_node_descriptors.py | 1 - tests/test_pipeline/test_pipeline.py | 22 +- 16 files changed, 816 insertions(+), 65 deletions(-) create mode 100644 src/orcapod/core/cached_function_pod.py create mode 100644 tests/test_core/function_pod/test_cached_function_pod.py diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py new file mode 100644 index 00000000..29ec2de2 --- /dev/null +++ b/src/orcapod/core/cached_function_pod.py @@ -0,0 +1,192 @@ +"""CachedFunctionPod — pod-level caching wrapper that intercepts process_packet().""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.core.function_pod import WrappedFunctionPod +from orcapod.protocols.core_protocols import ( + FunctionPodProtocol, + PacketProtocol, + StreamProtocol, + TagProtocol, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +class CachedFunctionPod(WrappedFunctionPod): + """Pod-level caching wrapper that intercepts ``process_packet()``. + + Unlike ``CachedPacketFunction`` (which caches at the ``call(packet)`` + level using only the packet content hash as the cache key), this + wrapper operates at the ``process_packet(tag, packet)`` level and + incorporates *both* the tag content hash and the packet content hash + into the cache key. + + This is useful when the same packet data may produce different results + depending on tag metadata, or when tag-level deduplication is desired. + """ + + def __init__( + self, + function_pod: FunctionPodProtocol, + result_database: ArrowDatabaseProtocol, + record_path_prefix: tuple[str, ...] = (), + **kwargs, + ) -> None: + super().__init__(function_pod, **kwargs) + self._result_database = result_database + self._record_path_prefix = record_path_prefix + + @property + def record_path(self) -> tuple[str, ...]: + """Return the path to the cached records in the result store.""" + return self._record_path_prefix + self.uri + + def _compute_cache_key(self, tag: TagProtocol, packet: PacketProtocol) -> str: + """Compute a cache key from tag and packet content hashes. + + Args: + tag: The tag associated with the packet. + packet: The input packet. + + Returns: + A string combining tag and packet content hashes. + """ + tag_hash = tag.content_hash().to_string() + packet_hash = packet.content_hash().to_string() + return f"{tag_hash}:{packet_hash}" + + def _lookup(self, cache_key: str) -> PacketProtocol | None: + """Look up a cached output packet by cache key. + + Args: + cache_key: The combined tag+packet hash key. + + Returns: + The cached output packet, or ``None`` if not found. + """ + from orcapod.core.datagrams import Packet + + CACHE_KEY_COL = f"{constants.META_PREFIX}pod_cache_key" + RECORD_ID_COL = "_record_id" + + result_table = self._result_database.get_records_with_column_value( + self.record_path, + {CACHE_KEY_COL: cache_key}, + record_id_column=RECORD_ID_COL, + ) + + if result_table is None or result_table.num_rows == 0: + return None + + if result_table.num_rows > 1: + logger.info( + "Pod-level cache: multiple records for key %s, taking most recent", + cache_key, + ) + result_table = result_table.sort_by( + [(constants.POD_TIMESTAMP, "descending")] + ).take([0]) + + record_id = result_table.to_pylist()[0][RECORD_ID_COL] + result_table = result_table.drop_columns([RECORD_ID_COL, CACHE_KEY_COL]) + + return Packet(result_table, record_id=record_id) + + def _store( + self, + cache_key: str, + output_packet: PacketProtocol, + ) -> None: + """Store an output packet in the cache. + + Args: + cache_key: The combined tag+packet hash key. + output_packet: The computed output packet to store. + """ + from datetime import datetime, timezone + + CACHE_KEY_COL = f"{constants.META_PREFIX}pod_cache_key" + + data_table = output_packet.as_table(columns={"source": True, "context": True}) + + # Prepend cache key column + data_table = data_table.add_column( + 0, + CACHE_KEY_COL, + pa.array([cache_key], type=pa.large_string()), + ) + + # Append timestamp + timestamp = datetime.now(timezone.utc) + data_table = data_table.append_column( + constants.POD_TIMESTAMP, + pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), + ) + + self._result_database.add_record( + self.record_path, + output_packet.datagram_id, + data_table, + skip_duplicates=False, + ) + self._result_database.flush() + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a packet with pod-level caching. + + The cache key incorporates both tag and packet content hashes. + On a cache hit, the stored output packet is returned without + invoking the inner pod's computation. + + Args: + tag: The tag associated with the packet. + packet: The input packet to process. + + Returns: + A ``(tag, output_packet)`` tuple; output_packet is ``None`` + if the inner function filters the packet out. + """ + cache_key = self._compute_cache_key(tag, packet) + cached = self._lookup(cache_key) + if cached is not None: + logger.info("Pod-level cache hit for key %s", cache_key) + return tag, cached + + tag, output = self._function_pod.process_packet(tag, packet) + if output is not None: + self._store(cache_key, output) + return tag, output + + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + """Invoke the inner pod but with pod-level caching on process_packet. + + The stream returned uses *this* pod's ``process_packet`` (which + includes caching) rather than the inner pod's. + """ + from orcapod.core.function_pod import FunctionPod, FunctionPodStream + + # Validate and prepare the input stream + input_stream = self._function_pod.handle_input_streams(*streams) + self._function_pod.validate_inputs(*streams) + + return FunctionPodStream( + function_pod=self, + input_stream=input_stream, + label=label, + ) diff --git a/src/orcapod/core/executors/base.py b/src/orcapod/core/executors/base.py index 8bb2fd23..386cbe1a 100644 --- a/src/orcapod/core/executors/base.py +++ b/src/orcapod/core/executors/base.py @@ -1,6 +1,8 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod +from collections.abc import Callable from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -77,13 +79,52 @@ def supports_concurrent_execution(self) -> bool: return False def with_options(self, **opts: Any) -> "PacketFunctionExecutorBase": - """Return an executor configured with the given per-node options. + """Return a **new** executor instance configured with the given per-node options. - The default implementation ignores *opts* and returns *self*. - Subclasses that support resource options (e.g. ``RayExecutor``) - should override to return a new instance with the merged options. + The default implementation returns a shallow copy of *self*. + Subclasses that carry mutable state (e.g. ``RayExecutor``) should + override to produce a properly configured new instance. """ - return self + return copy.copy(self) + + # ------------------------------------------------------------------ + # Callable-level execution (PythonFunctionExecutorProtocol) + # ------------------------------------------------------------------ + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Synchronously execute *fn* with *kwargs*. + + Default implementation calls ``fn(**kwargs)`` in-process. + Subclasses should override for remote/distributed execution. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options. + + Returns: + The raw return value of *fn*. + """ + return fn(**kwargs) + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Asynchronously execute *fn* with *kwargs*. + + Default implementation delegates to ``execute_callable`` + synchronously. Subclasses should override for truly async + execution. + """ + return self.execute_callable(fn, kwargs, executor_options) def get_execution_data(self) -> dict[str, Any]: """Return metadata describing the execution environment. diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py index 92289955..f56242e4 100644 --- a/src/orcapod/core/executors/local.py +++ b/src/orcapod/core/executors/local.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import asyncio +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from orcapod.core.executors.base import PacketFunctionExecutorBase @@ -35,3 +38,33 @@ async def async_execute( packet: PacketProtocol, ) -> PacketProtocol | None: return await packet_function.direct_async_call(packet) + + # -- PythonFunctionExecutorProtocol -- + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + if inspect.iscoroutinefunction(fn): + return asyncio.run(fn(**kwargs)) + return fn(**kwargs) + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + if inspect.iscoroutinefunction(fn): + return await fn(**kwargs) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: fn(**kwargs)) + + def with_options(self, **opts: Any) -> "LocalExecutor": + """Return a new ``LocalExecutor``. + + ``LocalExecutor`` carries no state, so options are ignored. + """ + return LocalExecutor() diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py index 9f54812e..998c2193 100644 --- a/src/orcapod/core/executors/ray.py +++ b/src/orcapod/core/executors/ray.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable from typing import TYPE_CHECKING, Any from orcapod.core.executors.base import PacketFunctionExecutorBase @@ -157,6 +158,34 @@ async def async_execute( raw_result = await asyncio.wrap_future(ref.future()) return pf._build_output_packet(raw_result) + # -- PythonFunctionExecutorProtocol -- + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + import ray + + self._ensure_ray_initialized() + remote_fn = ray.remote(**self._build_remote_opts())(fn) + ref = remote_fn.remote(**kwargs) + return ray.get(ref) + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + import ray + + self._ensure_ray_initialized() + remote_fn = ray.remote(**self._build_remote_opts())(fn) + ref = remote_fn.remote(**kwargs) + return await asyncio.wrap_future(ref.future()) + def with_options(self, **opts: Any) -> "RayExecutor": """Return a new ``RayExecutor`` with the given options merged in. diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 2272c616..5c96a05c 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -652,6 +652,7 @@ def function_pod( version: str = "v0.0", label: str | None = None, result_database: ArrowDatabaseProtocol | None = None, + pod_cache_database: ArrowDatabaseProtocol | None = None, executor: PacketFunctionExecutorProtocol | None = None, **kwargs, ) -> Callable[..., CallableWithPod]: @@ -662,7 +663,11 @@ def function_pod( function_name: Name of the function pod; defaults to ``func.__name__``. version: Version string for the packet function. label: Optional label for tracking. - result_database: Optional database for caching results. + result_database: Optional database for packet-level caching + (wraps the packet function in ``CachedPacketFunction``). + pod_cache_database: Optional database for pod-level caching + (wraps the pod in ``CachedFunctionPod``, which caches at the + ``process_packet(tag, packet)`` level using tag+packet hash). executor: Optional executor for running the packet function. **kwargs: Forwarded to ``PythonPacketFunction``. @@ -692,10 +697,19 @@ def decorator(func: Callable) -> CallableWithPod: ) # Create a simple typed function pod - pod = FunctionPod( + pod: _FunctionPodBase = FunctionPod( packet_function=packet_function, ) + # if pod_cache_database is provided, wrap in CachedFunctionPod + if pod_cache_database is not None: + from orcapod.core.cached_function_pod import CachedFunctionPod + + pod = CachedFunctionPod( + function_pod=pod, + result_database=pod_cache_database, + ) + @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index 140c7c7b..af456d05 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -101,11 +101,6 @@ def __init__( self._input_stream = input_stream - # Per-node Ray (or other engine) resource overrides. When a pipeline - # is run with an execution_engine, these opts are merged on top of the - # pipeline-level execution_engine_opts for this node only. - self.execution_engine_opts: dict[str, Any] | None = None - # stream-level caching state (iterator acquired lazily on first use) self._cached_input_iterator: ( Iterator[tuple[TagProtocol, PacketProtocol]] | None @@ -282,7 +277,6 @@ def from_descriptor( node._packet_function = None node._input_stream = None node.tracker_manager = DEFAULT_TRACKER_MANAGER - node.execution_engine_opts = descriptor.get("execution_engine_opts") node._cached_input_iterator = None node._needs_iterator = True node._cached_output_packets = {} diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index c8b73c7f..851a66b7 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -7,7 +7,8 @@ from abc import abstractmethod from collections.abc import Callable, Iterable, Sequence from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal +import typing +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar from uuid_utils import uuid7 @@ -20,7 +21,10 @@ get_function_signature, ) from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol -from orcapod.protocols.core_protocols.executor import PacketFunctionExecutorProtocol +from orcapod.protocols.core_protocols.executor import ( + PacketFunctionExecutorProtocol, + PythonFunctionExecutorProtocol, +) from orcapod.protocols.database_protocols import ArrowDatabaseProtocol from orcapod.system_constants import constants from orcapod.types import DataValue, Schema, SchemaLike @@ -101,8 +105,29 @@ def parse_function_outputs( return dict(zip(output_keys, output_values)) -class PacketFunctionBase(TraceableBase): - """Abstract base class for PacketFunctionProtocol.""" +E = TypeVar("E", bound=PacketFunctionExecutorProtocol) + + +class PacketFunctionBase(TraceableBase, Generic[E]): + """Abstract base class for PacketFunctionProtocol. + + Type-parameterized with the executor protocol ``E``. Concrete + subclasses that bind ``E`` (e.g. ``class Foo(PacketFunctionBase[SomeProto])``) + get automatic ``isinstance`` validation in ``set_executor`` at class + definition time via ``__init_subclass__``. + """ + + _resolved_executor_protocol: ClassVar[type | None] = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + for base in getattr(cls, "__orig_bases__", ()): + origin = typing.get_origin(base) + if origin is PacketFunctionBase: + args = typing.get_args(base) + if args and not isinstance(args[0], TypeVar): + cls._resolved_executor_protocol = args[0] + return def __init__( self, @@ -226,16 +251,35 @@ def executor(self) -> PacketFunctionExecutorProtocol | None: def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: """Set or clear the executor for this packet function. + Delegates to ``set_executor`` for validation. + """ + self.set_executor(executor) + + def set_executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor, validating type compatibility. + + Performs two checks: + 1. The executor supports this function's ``packet_function_type_id``. + 2. If the subclass bound ``E`` via ``Generic[E]``, the executor is an + instance of the resolved protocol (checked once at assignment time, + not in the hot path). + Raises: - TypeError: If *executor* does not support this function's - ``packet_function_type_id``. + TypeError: If *executor* fails either compatibility check. """ - if executor is not None and not executor.supports(self.packet_function_type_id): - raise TypeError( - f"Executor {executor.executor_type_id!r} does not support " - f"packet function type {self.packet_function_type_id!r}. " - f"Supported types: {executor.supported_function_type_ids()}" - ) + if executor is not None: + if not executor.supports(self.packet_function_type_id): + raise TypeError( + f"Executor {executor.executor_type_id!r} does not support " + f"packet function type {self.packet_function_type_id!r}. " + f"Supported types: {executor.supported_function_type_ids()}" + ) + proto = getattr(type(self), "_resolved_executor_protocol", None) + if proto is not None and not isinstance(executor, proto): + raise TypeError( + f"{type(self).__name__} requires an executor implementing " + f"{proto.__name__}, got {type(executor).__name__}" + ) self._executor = executor # ==================== Execution ==================== @@ -273,7 +317,7 @@ async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | No ... -class PythonPacketFunction(PacketFunctionBase): +class PythonPacketFunction(PacketFunctionBase[PythonFunctionExecutorProtocol]): @property def packet_function_type_id(self) -> str: """Unique function type identifier.""" @@ -463,6 +507,31 @@ def _call_async_function_sync(self, packet: PacketProtocol) -> Any: _get_sync_executor().submit(lambda: asyncio.run(fn(**kwargs))).result() ) + def call(self, packet: PacketProtocol) -> PacketProtocol | None: + """Process a single packet, routing through the executor if one is set. + + When an executor implementing ``PythonFunctionExecutorProtocol`` is + set, the raw callable and kwargs are handed to + ``execute_callable`` and the result is wrapped into an output packet. + """ + if self._executor is not None: + if not self._active: + return None + raw = self._executor.execute_callable(self._function, packet.as_dict()) + return self._build_output_packet(raw) + return self.direct_call(packet) + + async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + """Async counterpart of ``call``.""" + if self._executor is not None: + if not self._active: + return None + raw = await self._executor.async_execute_callable( + self._function, packet.as_dict() + ) + return self._build_output_packet(raw) + return await self.direct_async_call(packet) + def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: """Execute the function on *packet* synchronously. @@ -544,8 +613,13 @@ def from_config(cls, config: dict[str, Any]) -> "PythonPacketFunction": ) -class PacketFunctionWrapper(PacketFunctionBase): - """Wrapper around a PacketFunctionProtocol to modify or extend its behavior.""" +class PacketFunctionWrapper(PacketFunctionBase[E]): + """Wrapper around a PacketFunctionProtocol to modify or extend its behavior. + + Remains generic over ``E`` — the executor protocol is not bound here + so that wrappers inherit the executor type constraint of the wrapped + function. + """ def __init__(self, packet_function: PacketFunctionProtocol, **kwargs) -> None: super().__init__(**kwargs) diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index c0bf574d..6251074a 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -375,11 +375,10 @@ def run( execution_engine: Optional packet-function executor applied to every function node before execution (e.g. a ``RayExecutor``). Overrides ``config.execution_engine`` when both are provided. - execution_engine_opts: Default resource/options dict forwarded to - the engine for every node (e.g. ``{"num_cpus": 4}``). - Individual nodes may override via their - ``execution_engine_opts`` attribute. Overrides - ``config.execution_engine_opts`` when both are provided. + execution_engine_opts: Resource/options dict forwarded to the + engine via ``with_options()`` (e.g. ``{"num_cpus": 4}``). + Overrides ``config.execution_engine_opts`` when both are + provided. """ from orcapod.types import ExecutorType, PipelineConfig @@ -426,38 +425,35 @@ def _apply_execution_engine( ) -> None: """Apply *execution_engine* to every ``FunctionNode`` in the pipeline. - For each function node, the pipeline-level *execution_engine_opts* are - merged with any per-node ``execution_engine_opts`` override (node opts - win). If the merged opts dict is non-empty, ``engine.with_options`` - is called to produce a node-specific executor; otherwise the engine + If *execution_engine_opts* is non-empty, ``engine.with_options`` + is called to produce a configured executor; otherwise the engine instance is used directly. Args: execution_engine: Executor to apply (must implement ``PacketFunctionExecutorBase`` or at minimum expose ``with_options``). - execution_engine_opts: Pipeline-level default options dict, or + execution_engine_opts: Pipeline-level options dict, or ``None`` for no defaults. """ assert self._node_graph is not None, ( "_apply_execution_engine called before compile()" ) - pipeline_opts = execution_engine_opts or {} + opts = execution_engine_opts or {} + configured_executor = ( + execution_engine.with_options(**opts) if opts else execution_engine + ) for node in self._node_graph.nodes: if not isinstance(node, FunctionNode): continue - node_opts = node.execution_engine_opts or {} - merged = {**pipeline_opts, **node_opts} - node.executor = ( - execution_engine.with_options(**merged) if merged else execution_engine - ) + node.executor = configured_executor logger.debug( "Applied execution engine %r to node %r (opts=%r)", type(execution_engine).__name__, node.label, - merged or None, + opts or None, ) def _run_async(self, config: PipelineConfig) -> None: @@ -600,7 +596,6 @@ def _build_function_descriptor(self, node: "FunctionNode") -> dict[str, Any]: "function_pod": node._function_pod.to_config(), "pipeline_path": list(node.pipeline_path), "result_record_path": list(node._packet_function.record_path), - "execution_engine_opts": node.execution_engine_opts, } def _build_operator_descriptor(self, node: OperatorNode) -> dict[str, Any]: diff --git a/src/orcapod/protocols/core_protocols/__init__.py b/src/orcapod/protocols/core_protocols/__init__.py index 76c2720a..1f274033 100644 --- a/src/orcapod/protocols/core_protocols/__init__.py +++ b/src/orcapod/protocols/core_protocols/__init__.py @@ -3,7 +3,7 @@ from .async_executable import AsyncExecutableProtocol from .datagrams import DatagramProtocol, PacketProtocol, TagProtocol -from .executor import PacketFunctionExecutorProtocol +from .executor import PacketFunctionExecutorProtocol, PythonFunctionExecutorProtocol from .function_pod import FunctionPodProtocol from .operator_pod import OperatorPodProtocol from .packet_function import PacketFunctionProtocol @@ -27,6 +27,7 @@ "OperatorPodProtocol", "PacketFunctionProtocol", "PacketFunctionExecutorProtocol", + "PythonFunctionExecutorProtocol", "TrackerProtocol", "TrackerManagerProtocol", ] diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py index 81a88db4..f3055805 100644 --- a/src/orcapod/protocols/core_protocols/executor.py +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable from orcapod.protocols.core_protocols.datagrams import PacketProtocol @@ -63,11 +64,13 @@ def supports_concurrent_execution(self) -> bool: ... def with_options(self, **opts: Any) -> Self: - """Return an executor configured with the given per-node options. + """Return a **new** executor instance configured with the given per-node options. Used by the pipeline to produce node-specific executor instances (e.g. with different CPU/GPU allocations) from a shared base executor. - Implementations that do not support options may return *self*. + Implementations must always return a new instance, even when no + options change, so that executors are effectively immutable value + objects after construction. """ ... @@ -78,3 +81,70 @@ def get_execution_data(self) -> dict[str, Any]: affect content or pipeline hashes. """ ... + + +@runtime_checkable +class PythonFunctionExecutorProtocol(Protocol): + """Executor protocol for Python callable-based packet functions. + + Unlike ``PacketFunctionExecutorProtocol`` which operates on + (packet_function, packet) pairs, this protocol operates on raw + Python callables — the executor receives the function and its + keyword arguments directly. The packet function handles + packet construction/deconstruction around the executor call. + """ + + @property + def executor_type_id(self) -> str: + """Unique identifier for this executor type.""" + ... + + @property + def supports_concurrent_execution(self) -> bool: + """Whether this executor can run multiple calls concurrently.""" + ... + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Synchronously execute *fn* with *kwargs*. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options (e.g. resource + overrides). + + Returns: + The raw return value of *fn*. + """ + ... + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> Any: + """Asynchronously execute *fn* with *kwargs*. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options. + + Returns: + The raw return value of *fn*. + """ + ... + + def with_options(self, **opts: Any) -> Self: + """Return a **new** executor instance with the given options merged in.""" + ... + + def get_execution_data(self) -> dict[str, Any]: + """Return metadata describing the execution environment.""" + ... diff --git a/src/orcapod/types.py b/src/orcapod/types.py index a7ded54e..3f8938de 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -275,9 +275,8 @@ class PipelineConfig: execution_engine: Optional packet-function executor applied to all function nodes (e.g. ``RayExecutor``). ``None`` means in-process execution. - execution_engine_opts: Default resource/options dict forwarded to the - engine for every node (e.g. ``{"num_cpus": 4}``). Individual - nodes may override via their ``execution_engine_opts`` attribute. + execution_engine_opts: Resource/options dict forwarded to the engine + via ``with_options()`` (e.g. ``{"num_cpus": 4}``). """ executor: ExecutorType = ExecutorType.SYNCHRONOUS diff --git a/tests/test_core/function_pod/test_cached_function_pod.py b/tests/test_core/function_pod/test_cached_function_pod.py new file mode 100644 index 00000000..24f16b37 --- /dev/null +++ b/tests/test_core/function_pod/test_cached_function_pod.py @@ -0,0 +1,202 @@ +"""Tests for CachedFunctionPod — pod-level caching wrapper.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.cached_function_pod import CachedFunctionPod +from orcapod.core.function_pod import FunctionPod, function_pod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def double(x: int) -> int: + return x * 2 + + +def _make_stream(rows: list[dict] | None = None) -> ArrowTableStream: + if rows is None: + rows = [{"id": 0, "x": 10}, {"id": 1, "x": 20}] + table = pa.table( + {k: pa.array([r[k] for r in rows], type=pa.int64()) for k in rows[0]} + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cache_db(): + return InMemoryArrowDatabase() + + +@pytest.fixture +def double_pod(): + pf = PythonPacketFunction(double, output_keys="result") + return FunctionPod(pf) + + +@pytest.fixture +def cached_pod(double_pod, cache_db): + return CachedFunctionPod( + function_pod=double_pod, + result_database=cache_db, + ) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_record_path_ends_with_inner_uri(self, cached_pod, double_pod): + assert cached_pod.record_path[-len(double_pod.uri) :] == double_pod.uri + + def test_record_path_prefix_empty_by_default(self, cached_pod, double_pod): + assert cached_pod.record_path == double_pod.uri + + def test_record_path_prefix_prepended(self, double_pod, cache_db): + pod = CachedFunctionPod( + double_pod, + result_database=cache_db, + record_path_prefix=("my", "prefix"), + ) + assert pod.record_path[:2] == ("my", "prefix") + + +# --------------------------------------------------------------------------- +# Cache miss +# --------------------------------------------------------------------------- + + +class TestCacheMiss: + def test_returns_non_none_result(self, cached_pod): + stream = _make_stream() + output = cached_pod.process(stream) + results = list(output.iter_packets()) + assert len(results) == 2 + + def test_result_values_correct(self, cached_pod): + stream = _make_stream() + output = cached_pod.process(stream) + results = list(output.iter_packets()) + assert results[0][1].as_dict()["result"] == 20 + assert results[1][1].as_dict()["result"] == 40 + + def test_result_stored_in_database(self, cached_pod, cache_db): + stream = _make_stream() + output = cached_pod.process(stream) + list(output.iter_packets()) # exhaust the iterator + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + assert records.num_rows == 2 + + +# --------------------------------------------------------------------------- +# Cache hit +# --------------------------------------------------------------------------- + + +class TestCacheHit: + def test_second_call_returns_same_result(self, cached_pod): + stream = _make_stream() + + # First call: cache miss + output1 = cached_pod.process(stream) + first = [(t.as_dict(), p.as_dict()) for t, p in output1.iter_packets()] + + # Second call: cache hit + output2 = cached_pod.process(_make_stream()) + second = [(t.as_dict(), p.as_dict()) for t, p in output2.iter_packets()] + + assert len(first) == len(second) + for (_, p1), (_, p2) in zip(first, second): + assert p1["result"] == p2["result"] + + def test_second_call_does_not_add_new_records(self, cached_pod, cache_db): + stream = _make_stream() + output = cached_pod.process(stream) + list(output.iter_packets()) + + records_after_first = cache_db.get_all_records(cached_pod.record_path) + assert records_after_first is not None + count_first = records_after_first.num_rows + + output2 = cached_pod.process(_make_stream()) + list(output2.iter_packets()) + + records_after_second = cache_db.get_all_records(cached_pod.record_path) + assert records_after_second is not None + assert records_after_second.num_rows == count_first + + +# --------------------------------------------------------------------------- +# Tag-aware caching +# --------------------------------------------------------------------------- + + +class TestTagAwareCaching: + def test_different_tags_same_packet_cached_separately(self, double_pod, cache_db): + """Same packet data with different tag values should be cached separately.""" + cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) + + # Stream with tag=0, x=10 + stream1 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream1).iter_packets()) + + # Stream with tag=1, x=10 (same packet data, different tag) + stream2 = _make_stream([{"id": 1, "x": 10}]) + list(cached_pod.process(stream2).iter_packets()) + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + # Should have 2 records since tags differ + assert records.num_rows == 2 + + +# --------------------------------------------------------------------------- +# Decorator integration +# --------------------------------------------------------------------------- + + +class TestDecoratorIntegration: + def test_decorator_with_pod_cache_database(self): + db = InMemoryArrowDatabase() + + @function_pod(output_keys="result", pod_cache_database=db) + def my_double(x: int) -> int: + return x * 2 + + assert isinstance(my_double.pod, CachedFunctionPod) + + def test_decorator_pod_cache_produces_correct_results(self): + db = InMemoryArrowDatabase() + + @function_pod(output_keys="result", pod_cache_database=db) + def my_double(x: int) -> int: + return x * 2 + + stream = _make_stream() + output = my_double.pod.process(stream) + results = list(output.iter_packets()) + assert len(results) == 2 + assert results[0][1].as_dict()["result"] == 20 + + def test_decorator_without_pod_cache_returns_plain_pod(self): + @function_pod(output_keys="result") + def my_double(x: int) -> int: + return x * 2 + + assert isinstance(my_double.pod, FunctionPod) diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index 0ac1e9f3..8e855654 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -26,6 +26,7 @@ PacketFunctionExecutorProtocol, PacketFunctionProtocol, PacketProtocol, + PythonFunctionExecutorProtocol, ) # --------------------------------------------------------------------------- @@ -63,6 +64,10 @@ def execute( self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.calls.append((fn, kwargs)) + return fn(**kwargs) + class PythonOnlyExecutor(PacketFunctionExecutorBase): """Executor that only supports python.function.v0.""" @@ -147,6 +152,19 @@ def test_get_execution_data_returns_type(self): data = executor.get_execution_data() assert data["executor_type"] == "spy" + def test_with_options_returns_new_instance(self): + executor = SpyExecutor() + new_executor = executor.with_options() + assert new_executor is not executor + assert isinstance(new_executor, SpyExecutor) + + def test_with_options_preserves_state(self): + executor = SpyExecutor(supported_types=frozenset({"python.function.v0"})) + new_executor = executor.with_options() + assert new_executor.supported_function_type_ids() == frozenset( + {"python.function.v0"} + ) + # --------------------------------------------------------------------------- # 2. LocalExecutor @@ -175,6 +193,11 @@ def test_get_execution_data(self, local_executor: LocalExecutor): data = local_executor.get_execution_data() assert data["executor_type"] == "local" + def test_with_options_returns_new_instance(self, local_executor: LocalExecutor): + new_executor = local_executor.with_options() + assert new_executor is not local_executor + assert isinstance(new_executor, LocalExecutor) + # --------------------------------------------------------------------------- # 3. Executor as property on PacketFunctionBase @@ -238,7 +261,6 @@ def test_call_with_executor_routes_through_executor( assert result is not None assert result.as_dict()["result"] == 3 assert len(spy_executor.calls) == 1 - assert spy_executor.calls[0][0] is add_pf def test_direct_call_bypasses_executor( self, @@ -559,8 +581,8 @@ class ConcurrentSpyExecutor(PacketFunctionExecutorBase): """Executor that supports concurrent execution and tracks sync vs async calls.""" def __init__(self) -> None: - self.sync_calls: list[PacketProtocol] = [] - self.async_calls: list[PacketProtocol] = [] + self.sync_calls: list[Any] = [] + self.async_calls: list[Any] = [] @property def executor_type_id(self) -> str: @@ -589,6 +611,14 @@ async def async_execute( self.async_calls.append(packet) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.sync_calls.append(kwargs) + return fn(**kwargs) + + async def async_execute_callable(self, fn, kwargs, executor_options=None): + self.async_calls.append(kwargs) + return fn(**kwargs) + class TestConcurrentIteration: def test_function_pod_stream_uses_async_path(self): @@ -688,3 +718,71 @@ def test_second_iteration_uses_cache(self): second = list(output_stream.iter_packets()) assert len(spy.async_calls) == 2 # unchanged assert len(first) == len(second) + + +# --------------------------------------------------------------------------- +# 11. PythonFunctionExecutorProtocol conformance +# --------------------------------------------------------------------------- + + +class TestPythonFunctionExecutorProtocol: + def test_local_executor_satisfies_protocol(self): + executor = LocalExecutor() + assert isinstance(executor, PythonFunctionExecutorProtocol) + + def test_spy_executor_satisfies_protocol(self): + executor = SpyExecutor() + assert isinstance(executor, PythonFunctionExecutorProtocol) + + def test_execute_callable_runs_function(self): + executor = LocalExecutor() + result = executor.execute_callable(add, {"x": 3, "y": 4}) + assert result == 7 + + def test_execute_callable_with_executor_options(self): + executor = LocalExecutor() + result = executor.execute_callable( + add, {"x": 1, "y": 2}, executor_options={"num_cpus": 1} + ) + assert result == 3 + + +# --------------------------------------------------------------------------- +# 12. Type-safe executor dispatch via Generic[E] + __init_subclass__ +# --------------------------------------------------------------------------- + + +class _CustomExecutorProtocol: + """A non-executor class to test type-safe dispatch rejection.""" + + pass + + +class TestGenericExecutorDispatch: + def test_python_pf_resolves_executor_protocol(self): + """PythonPacketFunction should have resolved PythonFunctionExecutorProtocol.""" + assert ( + PythonPacketFunction._resolved_executor_protocol + is PythonFunctionExecutorProtocol + ) + + def test_set_executor_accepts_compatible_protocol(self): + pf = PythonPacketFunction(add, output_keys="result") + executor = LocalExecutor() + pf.set_executor(executor) + assert pf.executor is executor + + def test_set_executor_accepts_none(self): + pf = PythonPacketFunction(add, output_keys="result", executor=LocalExecutor()) + pf.set_executor(None) + assert pf.executor is None + + def test_call_routes_through_execute_callable(self): + """PythonPacketFunction.call() should use execute_callable, not execute.""" + spy = SpyExecutor() + pf = PythonPacketFunction(add, output_keys="result", executor=spy) + packet = Packet({"x": 1, "y": 2}) + result = pf.call(packet) + assert result is not None + assert result.as_dict()["result"] == 3 + assert len(spy.calls) == 1 diff --git a/tests/test_core/test_regression_fixes.py b/tests/test_core/test_regression_fixes.py index 809b5a7a..793f0efe 100644 --- a/tests/test_core/test_regression_fixes.py +++ b/tests/test_core/test_regression_fixes.py @@ -82,6 +82,10 @@ def execute( self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.calls.append((fn, kwargs)) + return fn(**kwargs) + # =========================================================================== # 1. async_execute output channel closed on exception (try/finally) diff --git a/tests/test_pipeline/test_node_descriptors.py b/tests/test_pipeline/test_node_descriptors.py index ed86f18c..bd2fb770 100644 --- a/tests/test_pipeline/test_node_descriptors.py +++ b/tests/test_pipeline/test_node_descriptors.py @@ -154,7 +154,6 @@ def _make_function_node_descriptor(self): "function_pod": pod.to_config(), "pipeline_path": list(node.pipeline_path), "result_record_path": list(node._packet_function.record_path), - "execution_engine_opts": None, } return node, descriptor, db diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 31af0195..b5f34aa8 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1176,6 +1176,14 @@ async def async_execute( self.async_calls.append(packet) return packet_function.direct_call(packet) + def execute_callable(self, fn, kwargs, executor_options=None): + self.sync_calls.append(kwargs) + return fn(**kwargs) + + async def async_execute_callable(self, fn, kwargs, executor_options=None): + self.async_calls.append(kwargs) + return fn(**kwargs) + def with_options(self, **opts: Any) -> "_MockExecutor": return _MockExecutor(opts={**self.opts, **opts}) @@ -1236,26 +1244,24 @@ def test_explicit_sync_config_overrides_async_default(self, pipeline_db): assert len(mock.sync_calls) > 0 assert len(mock.async_calls) == 0 - def test_per_node_opts_override_pipeline_opts(self, pipeline_db): - """Node-level execution_engine_opts win over pipeline-level defaults.""" + def test_pipeline_opts_applied_via_with_options(self, pipeline_db): + """Pipeline-level execution_engine_opts are applied via with_options.""" src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) pf = PythonPacketFunction(double_value, output_keys="result") pod = FunctionPod(packet_function=pf) mock = _MockExecutor() - pipeline = Pipeline(name="test_node_opts", pipeline_database=pipeline_db) + pipeline = Pipeline(name="test_pipeline_opts", pipeline_database=pipeline_db) with pipeline: pod(src, label="doubler") - pipeline.doubler.execution_engine_opts = {"num_cpus": 2} - pipeline.run( execution_engine=mock, - execution_engine_opts={"num_cpus": 1}, + execution_engine_opts={"num_cpus": 4}, ) - # Node opts win: executor should have been created with num_cpus=2 - assert pipeline.doubler.executor.opts.get("num_cpus") == 2 + # Executor should have been created with the pipeline opts + assert pipeline.doubler.executor.opts.get("num_cpus") == 4 class TestSourceNodesInPipeline: From 865b9016116ecf5b89690d63ea142ee8bde4ba14 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 05:31:57 +0000 Subject: [PATCH 2/8] test(executors): add thorough coverage for new executor features - Generic[E] dispatch: wrapper doesn't resolve, protocol rejection for non-conforming executor, inactive function + executor returns None, async_call routes through async_execute_callable - LocalExecutor callable: async fn via execute_callable, sync/async fn via async_execute_callable - CachedFunctionPod: same tag+different packet cached separately, same tag+same packet is cache hit, inactive function doesn't store, output_schema delegation, dual caching (result_database + pod_cache_database) - Pipeline: no opts uses engine directly (no with_options call) Co-Authored-By: Claude Opus 4.6 (1M context) --- .../function_pod/test_cached_function_pod.py | 99 +++++++++++++++++++ .../packet_function/test_executor.py | 87 +++++++++++++++- tests/test_pipeline/test_pipeline.py | 17 ++++ 3 files changed, 200 insertions(+), 3 deletions(-) diff --git a/tests/test_core/function_pod/test_cached_function_pod.py b/tests/test_core/function_pod/test_cached_function_pod.py index 24f16b37..ab984611 100644 --- a/tests/test_core/function_pod/test_cached_function_pod.py +++ b/tests/test_core/function_pod/test_cached_function_pod.py @@ -165,6 +165,105 @@ def test_different_tags_same_packet_cached_separately(self, double_pod, cache_db # Should have 2 records since tags differ assert records.num_rows == 2 + def test_same_tag_different_packet_cached_separately(self, double_pod, cache_db): + """Same tag value with different packet data should produce separate entries.""" + cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) + + stream1 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream1).iter_packets()) + + stream2 = _make_stream([{"id": 0, "x": 99}]) + list(cached_pod.process(stream2).iter_packets()) + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + assert records.num_rows == 2 + + def test_same_tag_same_packet_is_cache_hit(self, double_pod, cache_db): + """Exact same tag + packet should be a cache hit (no new record).""" + cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) + + stream1 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream1).iter_packets()) + + # Identical stream + stream2 = _make_stream([{"id": 0, "x": 10}]) + list(cached_pod.process(stream2).iter_packets()) + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is not None + assert records.num_rows == 1 + + +# --------------------------------------------------------------------------- +# Decorator integration +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Inner function returns None (inactive / filtering) +# --------------------------------------------------------------------------- + + +class TestInnerReturnsNone: + def test_inactive_function_does_not_store(self, cache_db): + pf = PythonPacketFunction(double, output_keys="result") + pf.set_active(False) + pod = FunctionPod(pf) + cached_pod = CachedFunctionPod(pod, result_database=cache_db) + + stream = _make_stream([{"id": 0, "x": 10}]) + results = list(cached_pod.process(stream).iter_packets()) + assert len(results) == 0 + + records = cache_db.get_all_records(cached_pod.record_path) + assert records is None + + +# --------------------------------------------------------------------------- +# Output schema delegation +# --------------------------------------------------------------------------- + + +class TestOutputSchema: + def test_output_schema_delegates_to_inner_pod(self, cached_pod, double_pod): + stream = _make_stream() + expected = double_pod.output_schema(stream) + actual = cached_pod.output_schema(stream) + assert actual == expected + + +# --------------------------------------------------------------------------- +# Dual caching (CachedPacketFunction + CachedFunctionPod) +# --------------------------------------------------------------------------- + + +class TestDualCaching: + def test_both_result_database_and_pod_cache_database(self): + """Decorator with both caching layers produces correct results.""" + pkt_db = InMemoryArrowDatabase() + pod_db = InMemoryArrowDatabase() + + @function_pod( + output_keys="result", + result_database=pkt_db, + pod_cache_database=pod_db, + ) + def my_double(x: int) -> int: + return x * 2 + + stream = _make_stream() + output = my_double.pod.process(stream) + results = list(output.iter_packets()) + + assert len(results) == 2 + assert results[0][1].as_dict()["result"] == 20 + + # Pod-level cache should have entries + pod_records = pod_db.get_all_records(my_double.pod.record_path) + assert pod_records is not None + assert pod_records.num_rows == 2 + # --------------------------------------------------------------------------- # Decorator integration diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index 8e855654..317a2cb1 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -752,10 +752,18 @@ def test_execute_callable_with_executor_options(self): # --------------------------------------------------------------------------- -class _CustomExecutorProtocol: - """A non-executor class to test type-safe dispatch rejection.""" +class _NotAnExecutor: + """A class that does NOT satisfy PythonFunctionExecutorProtocol.""" - pass + @property + def executor_type_id(self) -> str: + return "fake" + + def supported_function_type_ids(self) -> frozenset[str]: + return frozenset() + + def supports(self, packet_function_type_id: str) -> bool: + return True class TestGenericExecutorDispatch: @@ -766,6 +774,10 @@ def test_python_pf_resolves_executor_protocol(self): is PythonFunctionExecutorProtocol ) + def test_wrapper_does_not_resolve_protocol(self): + """PacketFunctionWrapper[E] should NOT resolve a protocol (E is unbound).""" + assert PacketFunctionWrapper._resolved_executor_protocol is None + def test_set_executor_accepts_compatible_protocol(self): pf = PythonPacketFunction(add, output_keys="result") executor = LocalExecutor() @@ -777,6 +789,13 @@ def test_set_executor_accepts_none(self): pf.set_executor(None) assert pf.executor is None + def test_set_executor_rejects_non_conforming_protocol(self): + """An object that doesn't implement PythonFunctionExecutorProtocol is rejected.""" + pf = PythonPacketFunction(add, output_keys="result") + fake = _NotAnExecutor() + with pytest.raises(TypeError, match="requires an executor implementing"): + pf.set_executor(fake) + def test_call_routes_through_execute_callable(self): """PythonPacketFunction.call() should use execute_callable, not execute.""" spy = SpyExecutor() @@ -786,3 +805,65 @@ def test_call_routes_through_execute_callable(self): assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.calls) == 1 + + def test_call_with_inactive_function_returns_none(self): + """When the function is inactive and executor is set, call returns None.""" + spy = SpyExecutor() + pf = PythonPacketFunction(add, output_keys="result", executor=spy) + pf.set_active(False) + packet = Packet({"x": 1, "y": 2}) + result = pf.call(packet) + assert result is None + assert len(spy.calls) == 0 + + def test_async_call_routes_through_async_execute_callable(self): + """PythonPacketFunction.async_call() should use async_execute_callable.""" + import asyncio + + spy = ConcurrentSpyExecutor() + pf = PythonPacketFunction(add, output_keys="result", executor=spy) + packet = Packet({"x": 1, "y": 2}) + result = asyncio.run(pf.async_call(packet)) + assert result is not None + assert result.as_dict()["result"] == 3 + assert len(spy.async_calls) == 1 + assert len(spy.sync_calls) == 0 + + +# --------------------------------------------------------------------------- +# 13. LocalExecutor execute_callable with async function +# --------------------------------------------------------------------------- + + +class TestLocalExecutorCallable: + def test_execute_callable_with_async_fn(self): + """LocalExecutor.execute_callable handles async functions via asyncio.run.""" + import asyncio + + async def async_add(x: int, y: int) -> int: + return x + y + + executor = LocalExecutor() + result = executor.execute_callable(async_add, {"x": 5, "y": 3}) + assert result == 8 + + def test_async_execute_callable_with_sync_fn(self): + """LocalExecutor.async_execute_callable handles sync fns via run_in_executor.""" + import asyncio + + executor = LocalExecutor() + result = asyncio.run(executor.async_execute_callable(add, {"x": 10, "y": 20})) + assert result == 30 + + def test_async_execute_callable_with_async_fn(self): + """LocalExecutor.async_execute_callable awaits async functions directly.""" + import asyncio + + async def async_add(x: int, y: int) -> int: + return x + y + + executor = LocalExecutor() + result = asyncio.run( + executor.async_execute_callable(async_add, {"x": 7, "y": 8}) + ) + assert result == 15 diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index b5f34aa8..7c33dcd8 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1263,6 +1263,23 @@ def test_pipeline_opts_applied_via_with_options(self, pipeline_db): # Executor should have been created with the pipeline opts assert pipeline.doubler.executor.opts.get("num_cpus") == 4 + def test_no_opts_uses_engine_directly(self, pipeline_db): + """Without execution_engine_opts, the engine itself is assigned (no with_options).""" + src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(packet_function=pf) + mock = _MockExecutor() + + pipeline = Pipeline(name="test_no_opts", pipeline_database=pipeline_db) + with pipeline: + pod(src, label="doubler") + + pipeline.run(execution_engine=mock) + + # Without opts, the original mock executor is assigned directly + assert pipeline.doubler.executor is mock + assert pipeline.doubler.executor.opts == {} + class TestSourceNodesInPipeline: """Verify that source nodes are first-class pipeline members.""" From aa4c1cf63bd2cab233dd12a845c395138ac4648a Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 06:21:34 +0000 Subject: [PATCH 3/8] refactor(function-node): use CachedFunctionPod for result caching - CachedFunctionPod now aligns with CachedPacketFunction storage: stores function variation data, execution data, cache entry hash, and timestamp - Cache entry hash computed from tag + system tags + input packet hash (matching pipeline record entry_id pattern), ensuring two rows with identical user tags but different system tags are cached separately - FunctionNode.attach_databases() creates CachedFunctionPod wrapping the function pod instead of CachedPacketFunction wrapping the packet function - FunctionNode.process_packet() delegates to CachedFunctionPod for result caching and separately records pipeline provenance entries - CachedFunctionPod.async_process_packet() does sync DB caching + async computation via inner pod's async_process_packet - add_pipeline_record() now explicitly extracts source columns using select() instead of rename-then-drop pattern - iter_packets Phase 2 skip-check uses single cache entry hash - Added DESIGN_ISSUES.md note (CFP1) about potential optimization of reusing entry_hash between CachedFunctionPod and add_pipeline_record Co-Authored-By: Claude Opus 4.6 (1M context) --- DESIGN_ISSUES.md | 19 ++ src/orcapod/core/cached_function_pod.py | 169 +++++++++++++---- src/orcapod/core/nodes/function_node.py | 172 ++++++++++-------- src/orcapod/pipeline/graph.py | 2 +- .../test_function_node_attach_db.py | 16 +- .../function_pod/test_function_pod_node.py | 14 +- tests/test_pipeline/test_node_descriptors.py | 2 +- tests/test_pipeline/test_pipeline.py | 8 +- 8 files changed, 267 insertions(+), 135 deletions(-) diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 945d6279..8cf66bff 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -315,6 +315,25 @@ which column groups (meta, source, system_tags) are returned. --- +## `src/orcapod/core/cached_function_pod.py` + +### CFP1 — Consider single-column lookup optimization for pipeline record entry_id +**Status:** open +**Severity:** low + +`CachedFunctionPod` and `FunctionNode.add_pipeline_record` both compute a combined hash +from `tag.as_table(columns={"system_tags": True}) + input_packet_hash`. The +`CachedFunctionPod` uses this as a single-column DB lookup key (`CACHE_ENTRY_HASH_COL`), +while the pipeline record stores it as `entry_id`. + +Currently `FunctionNode.add_pipeline_record` recomputes this hash independently. Consider +benchmarking whether passing the already-computed `entry_hash` from `CachedFunctionPod` +through to `add_pipeline_record` would yield a meaningful speedup. The hash computation +involves Arrow table construction + semantic hashing, so avoiding the second computation +could be worthwhile for large pipelines. + +--- + ## `src/orcapod/core/nodes/function_node.py` ### FN1 — `FunctionNode.async_execute` Phase 2 was fully sequential diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py index 29ec2de2..365fa0bb 100644 --- a/src/orcapod/core/cached_function_pod.py +++ b/src/orcapod/core/cached_function_pod.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from orcapod.core.function_pod import WrappedFunctionPod @@ -30,60 +31,80 @@ class CachedFunctionPod(WrappedFunctionPod): Unlike ``CachedPacketFunction`` (which caches at the ``call(packet)`` level using only the packet content hash as the cache key), this wrapper operates at the ``process_packet(tag, packet)`` level and - incorporates *both* the tag content hash and the packet content hash - into the cache key. + incorporates the tag (including system tags) and the packet content + hash into a single cache entry hash. - This is useful when the same packet data may produce different results - depending on tag metadata, or when tag-level deduplication is desired. + The cache entry hash is computed the same way as the pipeline record's + ``entry_id``: ``arrow_hasher.hash_table(tag_with_system_tags + input_packet_hash)``. + This ensures two rows with identical user tags but different system + tags (reflecting different source entries) are cached separately. + + Storage format aligns with ``CachedPacketFunction``: each cached + record includes function variation data, execution data, the cache + entry hash, and a timestamp. """ + # Column storing the combined hash of tag + system tags + input packet hash + CACHE_ENTRY_HASH_COL = f"{constants.META_PREFIX}cache_entry_hash" + + # Meta column indicating whether the result was freshly computed + RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + def __init__( self, function_pod: FunctionPodProtocol, result_database: ArrowDatabaseProtocol, record_path_prefix: tuple[str, ...] = (), + auto_flush: bool = True, **kwargs, ) -> None: super().__init__(function_pod, **kwargs) self._result_database = result_database self._record_path_prefix = record_path_prefix + self._auto_flush = auto_flush @property def record_path(self) -> tuple[str, ...]: """Return the path to the cached records in the result store.""" return self._record_path_prefix + self.uri - def _compute_cache_key(self, tag: TagProtocol, packet: PacketProtocol) -> str: - """Compute a cache key from tag and packet content hashes. + def _compute_entry_hash(self, tag: TagProtocol, packet: PacketProtocol) -> str: + """Compute a cache entry hash from tag (with system tags) and packet. + + The hash includes user-facing tag columns, system tag columns, and + the input packet content hash — matching the pipeline record's + entry_id computation. Args: tag: The tag associated with the packet. packet: The input packet. Returns: - A string combining tag and packet content hashes. + A hash string uniquely identifying this (tag, system_tags, packet) + combination. """ - tag_hash = tag.content_hash().to_string() - packet_hash = packet.content_hash().to_string() - return f"{tag_hash}:{packet_hash}" + tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( + constants.INPUT_PACKET_HASH_COL, + pa.array([packet.content_hash().to_string()], type=pa.large_string()), + ) + return self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() - def _lookup(self, cache_key: str) -> PacketProtocol | None: - """Look up a cached output packet by cache key. + def _lookup(self, entry_hash: str) -> PacketProtocol | None: + """Look up a cached output packet by cache entry hash. Args: - cache_key: The combined tag+packet hash key. + entry_hash: The combined tag+system_tags+packet hash. Returns: The cached output packet, or ``None`` if not found. """ from orcapod.core.datagrams import Packet - CACHE_KEY_COL = f"{constants.META_PREFIX}pod_cache_key" RECORD_ID_COL = "_record_id" result_table = self._result_database.get_records_with_column_value( self.record_path, - {CACHE_KEY_COL: cache_key}, + {self.CACHE_ENTRY_HASH_COL: entry_hash}, record_id_column=RECORD_ID_COL, ) @@ -92,40 +113,68 @@ def _lookup(self, cache_key: str) -> PacketProtocol | None: if result_table.num_rows > 1: logger.info( - "Pod-level cache: multiple records for key %s, taking most recent", - cache_key, + "Pod-level cache: multiple records for entry hash, taking most recent" ) result_table = result_table.sort_by( [(constants.POD_TIMESTAMP, "descending")] ).take([0]) record_id = result_table.to_pylist()[0][RECORD_ID_COL] - result_table = result_table.drop_columns([RECORD_ID_COL, CACHE_KEY_COL]) + result_table = result_table.drop_columns( + [RECORD_ID_COL, self.CACHE_ENTRY_HASH_COL] + ) - return Packet(result_table, record_id=record_id) + return Packet( + result_table, + record_id=record_id, + meta_info={self.RESULT_COMPUTED_FLAG: False}, + ) def _store( self, - cache_key: str, + entry_hash: str, + input_packet: PacketProtocol, output_packet: PacketProtocol, ) -> None: """Store an output packet in the cache. + Stores the output packet data alongside function variation data, + execution data, the cache entry hash, and a timestamp — matching + the column structure of ``CachedPacketFunction``. + Args: - cache_key: The combined tag+packet hash key. + entry_hash: The combined tag+system_tags+packet hash. + input_packet: The input packet (used for its content hash). output_packet: The computed output packet to store. """ - from datetime import datetime, timezone + data_table = output_packet.as_table(columns={"source": True, "context": True}) - CACHE_KEY_COL = f"{constants.META_PREFIX}pod_cache_key" + pf = self._function_pod.packet_function - data_table = output_packet.as_table(columns={"source": True, "context": True}) + # Add function variation data columns + i = 0 + for k, v in pf.get_function_variation_data().items(): + data_table = data_table.add_column( + i, + f"{constants.PF_VARIATION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add execution data columns + for k, v in pf.get_execution_data().items(): + data_table = data_table.add_column( + i, + f"{constants.PF_EXECUTION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 - # Prepend cache key column + # Add cache entry hash (position 0) data_table = data_table.add_column( 0, - CACHE_KEY_COL, - pa.array([cache_key], type=pa.large_string()), + self.CACHE_ENTRY_HASH_COL, + pa.array([entry_hash], type=pa.large_string()), ) # Append timestamp @@ -141,16 +190,20 @@ def _store( data_table, skip_duplicates=False, ) - self._result_database.flush() + + if self._auto_flush: + self._result_database.flush() def process_packet( self, tag: TagProtocol, packet: PacketProtocol ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a packet with pod-level caching. - The cache key incorporates both tag and packet content hashes. - On a cache hit, the stored output packet is returned without - invoking the inner pod's computation. + The cache entry hash incorporates tag columns, system tag columns, + and the input packet content hash. On a cache hit, the stored + output packet is returned without invoking the inner pod's + computation. The output packet carries a ``RESULT_COMPUTED_FLAG`` + meta value: ``True`` if freshly computed, ``False`` if from cache. Args: tag: The tag associated with the packet. @@ -160,17 +213,61 @@ def process_packet( A ``(tag, output_packet)`` tuple; output_packet is ``None`` if the inner function filters the packet out. """ - cache_key = self._compute_cache_key(tag, packet) - cached = self._lookup(cache_key) + entry_hash = self._compute_entry_hash(tag, packet) + cached = self._lookup(entry_hash) if cached is not None: - logger.info("Pod-level cache hit for key %s", cache_key) + logger.info("Pod-level cache hit") return tag, cached tag, output = self._function_pod.process_packet(tag, packet) if output is not None: - self._store(cache_key, output) + self._store(entry_hash, packet, output) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``. + + DB lookup and store are synchronous (DB protocol is sync), but the + actual computation uses the inner pod's ``async_process_packet`` + for true async execution. + """ + entry_hash = self._compute_entry_hash(tag, packet) + cached = self._lookup(entry_hash) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached + + tag, output = await self._function_pod.async_process_packet(tag, packet) + if output is not None: + self._store(entry_hash, packet, output) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output + + def get_all_cached_outputs( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store for this pod. + + Args: + include_system_columns: If True, include system columns + (e.g. record_id) in the result. + + Returns: + A PyArrow table of cached results, or ``None`` if empty. + """ + record_id_column = ( + constants.PACKET_RECORD_ID if include_system_columns else None + ) + result_table = self._result_database.get_all_records( + self.record_path, record_id_column=record_id_column + ) + if result_table is None or result_table.num_rows == 0: + return None + return result_table + def process( self, *streams: StreamProtocol, label: str | None = None ) -> StreamProtocol: @@ -179,7 +276,7 @@ def process( The stream returned uses *this* pod's ``process_packet`` (which includes caching) rather than the inner pod's. """ - from orcapod.core.function_pod import FunctionPod, FunctionPodStream + from orcapod.core.function_pod import FunctionPodStream # Validate and prepare the input stream input_stream = self._function_pod.handle_input_streams(*streams) diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index af456d05..a385787e 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -10,7 +10,7 @@ from orcapod import contexts from orcapod.channels import ReadableChannel, WritableChannel from orcapod.config import Config -from orcapod.core.packet_function import CachedPacketFunction +from orcapod.core.cached_function_pod import CachedFunctionPod from orcapod.core.streams.arrow_table_stream import ArrowTableStream from orcapod.core.streams.base import StreamBase from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER @@ -60,7 +60,7 @@ class FunctionNode(StreamBase): When constructed without database parameters, provides the core stream interface (identity, schema, iteration) without any persistence. When databases are provided (either at construction or via ``attach_databases``), - adds result caching via ``CachedPacketFunction``, pipeline record storage, + adds result caching via ``CachedFunctionPod``, pipeline record storage, and two-phase iteration (cached first, then compute missing). """ @@ -114,6 +114,7 @@ def __init__( # DB persistence state (initially None; set via __init__ params or attach_databases) self._pipeline_database: ArrowDatabaseProtocol | None = None + self._cached_function_pod: CachedFunctionPod | None = None self._pipeline_path_prefix: tuple[str, ...] = () self._pipeline_node_hash: str | None = None self._output_schema_hash: str | None = None @@ -139,6 +140,10 @@ def attach_databases( ) -> None: """Attach databases for persistent caching and pipeline records. + Creates a ``CachedFunctionPod`` wrapping the original function pod + for result caching. The pipeline database is used separately for + pipeline-level provenance records (tag + packet hash). + Args: pipeline_database: Database for pipeline records. result_database: Database for cached results. Defaults to @@ -157,13 +162,9 @@ def attach_databases( elif result_path_prefix is not None: computed_result_path_prefix = result_path_prefix - # Guard against double-wrapping - pf = self._packet_function - if isinstance(pf, CachedPacketFunction): - pf = pf._packet_function - - self._packet_function = CachedPacketFunction( - pf, + # Always wrap the original function_pod (not a previous cached wrapper) + self._cached_function_pod = CachedFunctionPod( + self._function_pod, result_database=result_database, record_path_prefix=computed_result_path_prefix, ) @@ -285,6 +286,7 @@ def from_descriptor( # DB persistence state node._pipeline_database = pipeline_db + node._cached_function_pod = None node._pipeline_path_prefix = () node._pipeline_node_hash = None node._output_schema_hash = None @@ -467,33 +469,29 @@ def process_packet( self, tag: TagProtocol, packet: PacketProtocol, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a single packet, optionally recording to the pipeline database. + When a database is attached, uses ``CachedFunctionPod`` for result + caching (tag+packet level) and separately records a pipeline + provenance entry. + Args: tag: The tag associated with the packet. packet: The input packet to process. - skip_cache_lookup: If True, bypass DB lookup for existing result. - skip_cache_insert: If True, skip writing result to DB. Returns: A ``(tag, output_packet)`` tuple; output_packet is ``None`` if the function filters the packet out. """ - if self._pipeline_database is not None: - # Persistent mode: use CachedPacketFunction + pipeline record - output_packet = self._packet_function.call( - packet, - skip_cache_lookup=skip_cache_lookup, - skip_cache_insert=skip_cache_insert, - ) + if self._cached_function_pod is not None: + # Persistent mode: CachedFunctionPod for result caching + pipeline record + tag, output_packet = self._cached_function_pod.process_packet(tag, packet) if output_packet is not None: result_computed = bool( output_packet.get_meta_value( - self._packet_function.RESULT_COMPUTED_FLAG, False + self._cached_function_pod.RESULT_COMPUTED_FLAG, False ) ) self.add_pipeline_record( @@ -511,26 +509,22 @@ async def async_process_packet( self, tag: TagProtocol, packet: PacketProtocol, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``. - Uses the CachedPacketFunction's async_call for computation + result - caching when a database is attached. Pipeline record storage is - synchronous (DB protocol is sync). + Uses ``CachedFunctionPod.async_process_packet`` (sync DB caching + + async computation) when a database is attached. Pipeline record + storage is synchronous (DB protocol is sync). """ - if self._pipeline_database is not None: - output_packet = await self._packet_function.async_call( - packet, - skip_cache_lookup=skip_cache_lookup, - skip_cache_insert=skip_cache_insert, + if self._cached_function_pod is not None: + tag, output_packet = await self._cached_function_pod.async_process_packet( + tag, packet ) if output_packet is not None: result_computed = bool( output_packet.get_meta_value( - self._packet_function.RESULT_COMPUTED_FLAG, False + self._cached_function_pod.RESULT_COMPUTED_FLAG, False ) ) self.add_pipeline_record( @@ -552,19 +546,25 @@ def add_pipeline_record( computed: bool, skip_cache_lookup: bool = False, ) -> None: - """Add a pipeline record to the database for a processed packet.""" - # combine TagProtocol with packet content hash to compute entry hash - # TODO: add system tag columns - # TODO: consider using bytes instead of string representation + """Add a pipeline record to the database for a processed packet. + + The pipeline record stores: + - Tag columns (including system tags) + - All source columns of the input packet (provenance, not data) + - Output packet record ID (for joining with result records) + - Input packet data context key + - Whether the result was freshly computed or cached + """ + # Compute entry hash from tag + system tags + input packet hash tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( constants.INPUT_PACKET_HASH_COL, pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), ) - # unique entry ID is determined by the combination of tags, system_tags, and input_packet hash + # Unique entry ID: combination of tags, system_tags, and input_packet hash entry_id = self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() - # check presence of an existing entry with the same entry_id + # Check for existing entry existing_record = None if not skip_cache_lookup: existing_record = self._pipeline_database.get_record_by_id( @@ -573,35 +573,40 @@ def add_pipeline_record( ) if existing_record is not None: - # if the record already exists, then skip adding logger.debug( f"Record with entry_id {entry_id} already exists. Skipping addition." ) return - # rename all keys to avoid potential collision with result columns - renamed_input_packet = input_packet.rename( - {k: f"_input_{k}" for k in input_packet.keys()} - ) - input_packet_info = ( - renamed_input_packet.as_table(columns={"source": True}) - .append_column( - constants.PACKET_RECORD_ID, # record ID for the packet function output packet - pa.array([packet_record_id], type=pa.large_string()), - ) - .append_column( - f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}", # data context key for the input packet - pa.array([input_packet.data_context_key], type=pa.large_string()), - ) - .append_column( - f"{constants.META_PREFIX}computed", - pa.array([computed], type=pa.bool_()), - ) - .drop_columns(list(renamed_input_packet.keys())) + # Extract source columns only (no data columns) from the input packet + input_table_with_source = input_packet.as_table(columns={"source": True}) + source_col_names = [ + c + for c in input_table_with_source.column_names + if c.startswith(constants.SOURCE_PREFIX) + ] + input_source_table = input_table_with_source.select(source_col_names) + + # Build the meta columns table + meta_table = pa.table( + { + constants.PACKET_RECORD_ID: pa.array( + [packet_record_id], type=pa.large_string() + ), + f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}": pa.array( + [input_packet.data_context_key], type=pa.large_string() + ), + f"{constants.META_PREFIX}computed": pa.array( + [computed], type=pa.bool_() + ), + } ) + # Combine: tag (with system tags) + input source columns + meta columns combined_record = arrow_utils.hstack_tables( - tag.as_table(columns={"system_tags": True}), input_packet_info + tag.as_table(columns={"system_tags": True}), + input_source_table, + meta_table, ) self._pipeline_database.add_record( @@ -630,11 +635,11 @@ def get_all_records( A PyArrow table of joined results, or ``None`` if no database is attached or no records exist. """ - if self._pipeline_database is None: + if self._cached_function_pod is None: return None - results = self._packet_function._result_database.get_all_records( - self._packet_function.record_path, + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, record_id_column=constants.PACKET_RECORD_ID, ) taginfo = self._pipeline_database.get_all_records(self.pipeline_path) @@ -700,21 +705,21 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: self.clear_cache() self._ensure_iterator() - if self._pipeline_database is not None: + if self._cached_function_pod is not None: # Two-phase iteration with DB backing if self._cached_input_iterator is not None: input_iter = self._cached_input_iterator # --- Phase 1: yield already-computed results from the databases --- existing = self.get_all_records(columns={"meta": True}) - computed_hashes: set[str] = set() + computed_entry_hashes: set[str] = set() if existing is not None and existing.num_rows > 0: tag_keys = self._input_stream.keys()[0] - # Strip the meta column before handing to ArrowTableStream so it only - # sees tag + output-packet columns. - hash_col = constants.INPUT_PACKET_HASH_COL - hash_values = cast(list[str], existing.column(hash_col).to_pylist()) - computed_hashes = set(hash_values) - data_table = existing.drop([hash_col]) + entry_hash_col = CachedFunctionPod.CACHE_ENTRY_HASH_COL + computed_entry_hashes = set( + cast(list[str], existing.column(entry_hash_col).to_pylist()) + ) + # Strip the entry hash column before yielding as stream + data_table = existing.drop([entry_hash_col]) existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) for i, (tag, packet) in enumerate(existing_stream.iter_packets()): self._cached_output_packets[i] = (tag, packet) @@ -723,8 +728,10 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: # --- Phase 2: process only missing input packets --- next_idx = len(self._cached_output_packets) for tag, packet in input_iter: - input_hash = packet.content_hash().to_string() - if input_hash in computed_hashes: + entry_hash = self._cached_function_pod._compute_entry_hash( + tag, packet + ) + if entry_hash in computed_entry_hashes: continue tag, output_packet = self.process_packet(tag, packet) self._cached_output_packets[next_idx] = (tag, output_packet) @@ -945,18 +952,21 @@ async def async_execute( node_config = getattr(self._function_pod, "node_config", NodeConfig()) max_concurrency = resolve_concurrency(node_config, pipeline_config) - if self._pipeline_database is not None: + if self._cached_function_pod is not None: # Two-phase async execution with DB backing # Phase 1: emit existing results from DB existing = self.get_all_records(columns={"meta": True}) - computed_hashes: set[str] = set() + computed_entry_hashes: set[str] = set() if existing is not None and existing.num_rows > 0: tag_keys = self._input_stream.keys()[0] - hash_col = constants.INPUT_PACKET_HASH_COL - computed_hashes = set( - cast(list[str], existing.column(hash_col).to_pylist()) + entry_hash_col = CachedFunctionPod.CACHE_ENTRY_HASH_COL + computed_entry_hashes = set( + cast( + list[str], + existing.column(entry_hash_col).to_pylist(), + ) ) - data_table = existing.drop([hash_col]) + data_table = existing.drop([entry_hash_col]) existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) for tag, packet in existing_stream.iter_packets(): await output.send((tag, packet)) @@ -983,8 +993,10 @@ async def process_one_db( async with asyncio.TaskGroup() as tg: async for tag, packet in inputs[0]: - input_hash = packet.content_hash().to_string() - if input_hash in computed_hashes: + entry_hash = self._cached_function_pod._compute_entry_hash( + tag, packet + ) + if entry_hash in computed_entry_hashes: continue if sem is not None: await sem.acquire() diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index 6251074a..6ee216d0 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -595,7 +595,7 @@ def _build_function_descriptor(self, node: "FunctionNode") -> dict[str, Any]: return { "function_pod": node._function_pod.to_config(), "pipeline_path": list(node.pipeline_path), - "result_record_path": list(node._packet_function.record_path), + "result_record_path": list(node._cached_function_pod.record_path), } def _build_operator_descriptor(self, node: OperatorNode) -> dict[str, Any]: diff --git a/tests/test_core/function_pod/test_function_node_attach_db.py b/tests/test_core/function_pod/test_function_node_attach_db.py index 2cd6a955..cef00ef1 100644 --- a/tests/test_core/function_pod/test_function_node_attach_db.py +++ b/tests/test_core/function_pod/test_function_node_attach_db.py @@ -59,13 +59,13 @@ def test_attach_databases_sets_pipeline_db(self): node.attach_databases(pipeline_database=db, result_database=db) assert node._pipeline_database is db - def test_attach_databases_wraps_packet_function(self): - from orcapod.core.packet_function import CachedPacketFunction + def test_attach_databases_creates_cached_function_pod(self): + from orcapod.core.cached_function_pod import CachedFunctionPod node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) db = InMemoryArrowDatabase() node.attach_databases(pipeline_database=db, result_database=db) - assert isinstance(node._packet_function, CachedPacketFunction) + assert isinstance(node._cached_function_pod, CachedFunctionPod) def test_attach_databases_clears_caches(self): node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) @@ -83,17 +83,17 @@ def test_attach_databases_computes_pipeline_path(self): assert len(node.pipeline_path) > 0 def test_double_attach_does_not_double_wrap(self): - from orcapod.core.packet_function import CachedPacketFunction + from orcapod.core.cached_function_pod import CachedFunctionPod node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) db = InMemoryArrowDatabase() node.attach_databases(pipeline_database=db, result_database=db) - assert isinstance(node._packet_function, CachedPacketFunction) - # Second attach should not double-wrap + assert isinstance(node._cached_function_pod, CachedFunctionPod) + # Second attach wraps the original function_pod, not the cached one node.attach_databases(pipeline_database=db, result_database=db) - assert isinstance(node._packet_function, CachedPacketFunction) + assert isinstance(node._cached_function_pod, CachedFunctionPod) assert not isinstance( - node._packet_function._packet_function, CachedPacketFunction + node._cached_function_pod._function_pod, CachedFunctionPod ) def test_iter_packets_after_attach_works(self): diff --git a/tests/test_core/function_pod/test_function_pod_node.py b/tests/test_core/function_pod/test_function_pod_node.py index 00460bba..15e76c77 100644 --- a/tests/test_core/function_pod/test_function_pod_node.py +++ b/tests/test_core/function_pod/test_function_pod_node.py @@ -455,10 +455,12 @@ def test_meta_true_includes_packet_record_id(self, filled_node): assert result is not None assert constants.PACKET_RECORD_ID in result.column_names - def test_meta_true_includes_input_packet_hash(self, filled_node): + def test_meta_true_includes_cache_entry_hash(self, filled_node): + from orcapod.core.cached_function_pod import CachedFunctionPod + result = filled_node.get_all_records(columns={"meta": True}) assert result is not None - assert constants.INPUT_PACKET_HASH_COL in result.column_names + assert CachedFunctionPod.CACHE_ENTRY_HASH_COL in result.column_names def test_meta_true_still_has_data_columns(self, filled_node): result = filled_node.get_all_records(columns={"meta": True}) @@ -466,10 +468,12 @@ def test_meta_true_still_has_data_columns(self, filled_node): assert "id" in result.column_names assert "result" in result.column_names - def test_input_packet_hash_values_are_non_empty_strings(self, filled_node): + def test_cache_entry_hash_values_are_non_empty_strings(self, filled_node): + from orcapod.core.cached_function_pod import CachedFunctionPod + result = filled_node.get_all_records(columns={"meta": True}) assert result is not None - hashes = result.column(constants.INPUT_PACKET_HASH_COL).to_pylist() + hashes = result.column(CachedFunctionPod.CACHE_ENTRY_HASH_COL).to_pylist() assert all(isinstance(h, str) and len(h) > 0 for h in hashes) def test_packet_record_id_values_are_non_empty_strings(self, filled_node): @@ -665,7 +669,7 @@ def test_result_records_stored_under_result_suffix_path(self, double_pf): node.process_packet(tag, packet) db.flush() - result_path = node._packet_function.record_path + result_path = node._cached_function_pod.record_path assert result_path[-1] == "_result" or any( "_result" in part for part in result_path ) diff --git a/tests/test_pipeline/test_node_descriptors.py b/tests/test_pipeline/test_node_descriptors.py index bd2fb770..94d5c7a0 100644 --- a/tests/test_pipeline/test_node_descriptors.py +++ b/tests/test_pipeline/test_node_descriptors.py @@ -153,7 +153,7 @@ def _make_function_node_descriptor(self): }, "function_pod": pod.to_config(), "pipeline_path": list(node.pipeline_path), - "result_record_path": list(node._packet_function.record_path), + "result_record_path": list(node._cached_function_pod.record_path), } return node, descriptor, db diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 7c33dcd8..de9dc226 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -264,9 +264,9 @@ def test_function_database_none_uses_results_subfolder(self, pipeline_db): node = pipeline.compiled_nodes["adder"] assert isinstance(node, FunctionNode) - # The CachedPacketFunction's record_path should start with + # The CachedFunctionPod's record_path should start with # (pipeline_name, "_results", ...) - record_path = node._packet_function.record_path + record_path = node._cached_function_pod.record_path assert record_path[0] == "my_pipe" assert record_path[1] == "_results" @@ -289,8 +289,8 @@ def test_separate_function_database(self, pipeline_db, function_db): node = pipeline.compiled_nodes["adder"] assert isinstance(node, FunctionNode) - # The CachedPacketFunction should use function_db - assert node._packet_function._result_database is function_db + # The CachedFunctionPod should use function_db + assert node._cached_function_pod._result_database is function_db # --------------------------------------------------------------------------- From cb4550d2d5407513647f9503638ffc18130af951 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 06:48:56 +0000 Subject: [PATCH 4/8] fix(cached-function-pod): cache by packet hash only, not tag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CachedFunctionPod now caches by input packet hash only — the function output depends solely on the packet, not the tag. Tag-level uniqueness (tag + system tags + input packet hash) is handled by FunctionNode's pipeline record (add_pipeline_record / compute_pipeline_entry_id). iter_packets Phase 2 skip check now uses pipeline entry_ids (which include tag + system tags + packet hash) retrieved from the pipeline database, ensuring correct deduplication when the same packet appears with different tags/system_tags. Also: - Extracted compute_pipeline_entry_id() as a reusable method - Updated DESIGN_ISSUES CFP1: shared ResultCache refactor suggestion - Added TODO notes for match tier support (aligned with P6) Co-Authored-By: Claude Opus 4.6 (1M context) --- DESIGN_ISSUES.md | 27 +-- src/orcapod/core/cached_function_pod.py | 92 ++++------ src/orcapod/core/nodes/function_node.py | 162 +++++++++++++----- .../function_pod/test_cached_function_pod.py | 27 +-- .../function_pod/test_function_pod_node.py | 12 +- 5 files changed, 186 insertions(+), 134 deletions(-) diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 8cf66bff..e1acf4f9 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -315,22 +315,25 @@ which column groups (meta, source, system_tags) are returned. --- -## `src/orcapod/core/cached_function_pod.py` +## `src/orcapod/core/cached_function_pod.py` / `src/orcapod/core/packet_function.py` -### CFP1 — Consider single-column lookup optimization for pipeline record entry_id +### CFP1 — Extract shared result caching logic from CachedPacketFunction and CachedFunctionPod **Status:** open -**Severity:** low +**Severity:** medium + +`CachedPacketFunction` and `CachedFunctionPod` implement nearly identical result caching +logic: DB lookup by `INPUT_PACKET_HASH_COL`, conflict resolution by most-recent timestamp, +record storage with variation/execution/timestamp columns, and a `RESULT_COMPUTED_FLAG` +meta column. The match tier / matching policy design (P6) will also need to apply to both. -`CachedFunctionPod` and `FunctionNode.add_pipeline_record` both compute a combined hash -from `tag.as_table(columns={"system_tags": True}) + input_packet_hash`. The -`CachedFunctionPod` uses this as a single-column DB lookup key (`CACHE_ENTRY_HASH_COL`), -while the pipeline record stores it as `entry_id`. +This duplication means any future changes to caching behavior (e.g. implementing match +tiers, adding new stored columns, changing conflict resolution) must be applied in two +places. -Currently `FunctionNode.add_pipeline_record` recomputes this hash independently. Consider -benchmarking whether passing the already-computed `entry_hash` from `CachedFunctionPod` -through to `add_pipeline_record` would yield a meaningful speedup. The hash computation -involves Arrow table construction + semantic hashing, so avoiding the second computation -could be worthwhile for large pipelines. +**Suggested refactor:** Extract a `ResultCache` (or similar) class that owns the DB, +record path, lookup, store, and conflict resolution logic. Both `CachedPacketFunction` +and `CachedFunctionPod` would delegate to a shared `ResultCache` instance. The match tier +strategy (P6) would then be implemented once on `ResultCache`. --- diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py index 365fa0bb..2462ba05 100644 --- a/src/orcapod/core/cached_function_pod.py +++ b/src/orcapod/core/cached_function_pod.py @@ -28,25 +28,18 @@ class CachedFunctionPod(WrappedFunctionPod): """Pod-level caching wrapper that intercepts ``process_packet()``. - Unlike ``CachedPacketFunction`` (which caches at the ``call(packet)`` - level using only the packet content hash as the cache key), this - wrapper operates at the ``process_packet(tag, packet)`` level and - incorporates the tag (including system tags) and the packet content - hash into a single cache entry hash. + Caches at the ``process_packet(tag, packet)`` level using only the + **input packet content hash** as the cache key — the output of a + packet function depends solely on the packet, not the tag. - The cache entry hash is computed the same way as the pipeline record's - ``entry_id``: ``arrow_hasher.hash_table(tag_with_system_tags + input_packet_hash)``. - This ensures two rows with identical user tags but different system - tags (reflecting different source entries) are cached separately. + Tag-level provenance tracking (tag + system tags + packet hash) is + handled separately by ``FunctionNode.add_pipeline_record``. Storage format aligns with ``CachedPacketFunction``: each cached - record includes function variation data, execution data, the cache - entry hash, and a timestamp. + record includes function variation data, execution data, input packet + hash, and a timestamp. """ - # Column storing the combined hash of tag + system tags + input packet hash - CACHE_ENTRY_HASH_COL = f"{constants.META_PREFIX}cache_entry_hash" - # Meta column indicating whether the result was freshly computed RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" @@ -68,32 +61,11 @@ def record_path(self) -> tuple[str, ...]: """Return the path to the cached records in the result store.""" return self._record_path_prefix + self.uri - def _compute_entry_hash(self, tag: TagProtocol, packet: PacketProtocol) -> str: - """Compute a cache entry hash from tag (with system tags) and packet. - - The hash includes user-facing tag columns, system tag columns, and - the input packet content hash — matching the pipeline record's - entry_id computation. + def _lookup(self, input_packet: PacketProtocol) -> PacketProtocol | None: + """Look up a cached output packet by input packet content hash. Args: - tag: The tag associated with the packet. - packet: The input packet. - - Returns: - A hash string uniquely identifying this (tag, system_tags, packet) - combination. - """ - tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( - constants.INPUT_PACKET_HASH_COL, - pa.array([packet.content_hash().to_string()], type=pa.large_string()), - ) - return self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() - - def _lookup(self, entry_hash: str) -> PacketProtocol | None: - """Look up a cached output packet by cache entry hash. - - Args: - entry_hash: The combined tag+system_tags+packet hash. + input_packet: The input packet whose content hash is used for lookup. Returns: The cached output packet, or ``None`` if not found. @@ -102,9 +74,15 @@ def _lookup(self, entry_hash: str) -> PacketProtocol | None: RECORD_ID_COL = "_record_id" + # TODO: add match based on match_tier if specified + # TODO: implement matching policy/strategy (see DESIGN_ISSUES P6) + constraints = { + constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string(), + } + result_table = self._result_database.get_records_with_column_value( self.record_path, - {self.CACHE_ENTRY_HASH_COL: entry_hash}, + constraints, record_id_column=RECORD_ID_COL, ) @@ -113,7 +91,8 @@ def _lookup(self, entry_hash: str) -> PacketProtocol | None: if result_table.num_rows > 1: logger.info( - "Pod-level cache: multiple records for entry hash, taking most recent" + "Pod-level cache: multiple records for input packet hash, " + "taking most recent" ) result_table = result_table.sort_by( [(constants.POD_TIMESTAMP, "descending")] @@ -121,7 +100,7 @@ def _lookup(self, entry_hash: str) -> PacketProtocol | None: record_id = result_table.to_pylist()[0][RECORD_ID_COL] result_table = result_table.drop_columns( - [RECORD_ID_COL, self.CACHE_ENTRY_HASH_COL] + [RECORD_ID_COL, constants.INPUT_PACKET_HASH_COL] ) return Packet( @@ -132,18 +111,16 @@ def _lookup(self, entry_hash: str) -> PacketProtocol | None: def _store( self, - entry_hash: str, input_packet: PacketProtocol, output_packet: PacketProtocol, ) -> None: """Store an output packet in the cache. Stores the output packet data alongside function variation data, - execution data, the cache entry hash, and a timestamp — matching - the column structure of ``CachedPacketFunction``. + execution data, input packet hash, and timestamp — matching the + column structure of ``CachedPacketFunction``. Args: - entry_hash: The combined tag+system_tags+packet hash. input_packet: The input packet (used for its content hash). output_packet: The computed output packet to store. """ @@ -170,11 +147,11 @@ def _store( ) i += 1 - # Add cache entry hash (position 0) + # Add input packet hash (position 0, same as CachedPacketFunction) data_table = data_table.add_column( 0, - self.CACHE_ENTRY_HASH_COL, - pa.array([entry_hash], type=pa.large_string()), + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), ) # Append timestamp @@ -199,11 +176,10 @@ def process_packet( ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a packet with pod-level caching. - The cache entry hash incorporates tag columns, system tag columns, - and the input packet content hash. On a cache hit, the stored - output packet is returned without invoking the inner pod's - computation. The output packet carries a ``RESULT_COMPUTED_FLAG`` - meta value: ``True`` if freshly computed, ``False`` if from cache. + The cache key is the input packet content hash only — the function + output depends solely on the packet, not the tag. The output + packet carries a ``RESULT_COMPUTED_FLAG`` meta value: ``True`` if + freshly computed, ``False`` if retrieved from cache. Args: tag: The tag associated with the packet. @@ -213,15 +189,14 @@ def process_packet( A ``(tag, output_packet)`` tuple; output_packet is ``None`` if the inner function filters the packet out. """ - entry_hash = self._compute_entry_hash(tag, packet) - cached = self._lookup(entry_hash) + cached = self._lookup(packet) if cached is not None: logger.info("Pod-level cache hit") return tag, cached tag, output = self._function_pod.process_packet(tag, packet) if output is not None: - self._store(entry_hash, packet, output) + self._store(packet, output) output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output @@ -234,15 +209,14 @@ async def async_process_packet( actual computation uses the inner pod's ``async_process_packet`` for true async execution. """ - entry_hash = self._compute_entry_hash(tag, packet) - cached = self._lookup(entry_hash) + cached = self._lookup(packet) if cached is not None: logger.info("Pod-level cache hit") return tag, cached tag, output = await self._function_pod.async_process_packet(tag, packet) if output is not None: - self._store(entry_hash, packet, output) + self._store(packet, output) output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index a385787e..f451a59f 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -538,6 +538,27 @@ async def async_process_packet( else: return await self._function_pod.async_process_packet(tag, packet) + def compute_pipeline_entry_id( + self, tag: TagProtocol, input_packet: PacketProtocol + ) -> str: + """Compute a unique pipeline entry ID from tag + system tags + input packet hash. + + This ID uniquely identifies a (tag, system_tags, input_packet) combination + and is used as the record ID in the pipeline database. + + Args: + tag: The tag (including system tags). + input_packet: The input packet. + + Returns: + A hash string uniquely identifying this combination. + """ + tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), + ) + return self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() + def add_pipeline_record( self, tag: TagProtocol, @@ -555,14 +576,7 @@ def add_pipeline_record( - Input packet data context key - Whether the result was freshly computed or cached """ - # Compute entry hash from tag + system tags + input packet hash - tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( - constants.INPUT_PACKET_HASH_COL, - pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), - ) - - # Unique entry ID: combination of tags, system_tags, and input_packet hash - entry_id = self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() + entry_id = self.compute_pipeline_entry_id(tag, input_packet) # Check for existing entry existing_record = None @@ -710,28 +724,65 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: if self._cached_input_iterator is not None: input_iter = self._cached_input_iterator # --- Phase 1: yield already-computed results from the databases --- - existing = self.get_all_records(columns={"meta": True}) - computed_entry_hashes: set[str] = set() - if existing is not None and existing.num_rows > 0: - tag_keys = self._input_stream.keys()[0] - entry_hash_col = CachedFunctionPod.CACHE_ENTRY_HASH_COL - computed_entry_hashes = set( - cast(list[str], existing.column(entry_hash_col).to_pylist()) + # Retrieve pipeline records with their entry_ids (record IDs) + # and join with result records to reconstruct (tag, output_packet). + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + existing_entry_ids: set[str] = set() + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() ) - # Strip the entry hash column before yielding as stream - data_table = existing.drop([entry_hash_col]) - existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) - for i, (tag, packet) in enumerate(existing_stream.iter_packets()): - self._cached_output_packets[i] = (tag, packet) - yield tag, packet + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + # Collect pipeline entry_ids for Phase 2 skip check + existing_entry_ids = set( + cast( + list[str], + joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), + ) + ) + # Drop internal columns before yielding as stream + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for i, (tag, packet) in enumerate( + existing_stream.iter_packets() + ): + self._cached_output_packets[i] = (tag, packet) + yield tag, packet # --- Phase 2: process only missing input packets --- + # Skip inputs whose pipeline entry_id (tag+system_tags+packet_hash) + # already exists in the pipeline database. next_idx = len(self._cached_output_packets) for tag, packet in input_iter: - entry_hash = self._cached_function_pod._compute_entry_hash( - tag, packet - ) - if entry_hash in computed_entry_hashes: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in existing_entry_ids: continue tag, output_packet = self.process_packet(tag, packet) self._cached_output_packets[next_idx] = (tag, output_packet) @@ -955,21 +1006,50 @@ async def async_execute( if self._cached_function_pod is not None: # Two-phase async execution with DB backing # Phase 1: emit existing results from DB - existing = self.get_all_records(columns={"meta": True}) - computed_entry_hashes: set[str] = set() - if existing is not None and existing.num_rows > 0: - tag_keys = self._input_stream.keys()[0] - entry_hash_col = CachedFunctionPod.CACHE_ENTRY_HASH_COL - computed_entry_hashes = set( - cast( - list[str], - existing.column(entry_hash_col).to_pylist(), + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + existing_entry_ids: set[str] = set() + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", ) + .to_arrow() ) - data_table = existing.drop([entry_hash_col]) - existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - await output.send((tag, packet)) + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + existing_entry_ids = set( + cast( + list[str], + joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), + ) + ) + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for tag, packet in existing_stream.iter_packets(): + await output.send((tag, packet)) # Phase 2: process new packets concurrently sem = ( @@ -993,10 +1073,8 @@ async def process_one_db( async with asyncio.TaskGroup() as tg: async for tag, packet in inputs[0]: - entry_hash = self._cached_function_pod._compute_entry_hash( - tag, packet - ) - if entry_hash in computed_entry_hashes: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in existing_entry_ids: continue if sem is not None: await sem.acquire() diff --git a/tests/test_core/function_pod/test_cached_function_pod.py b/tests/test_core/function_pod/test_cached_function_pod.py index ab984611..620a63ca 100644 --- a/tests/test_core/function_pod/test_cached_function_pod.py +++ b/tests/test_core/function_pod/test_cached_function_pod.py @@ -147,26 +147,28 @@ def test_second_call_does_not_add_new_records(self, cached_pod, cache_db): # --------------------------------------------------------------------------- -class TestTagAwareCaching: - def test_different_tags_same_packet_cached_separately(self, double_pod, cache_db): - """Same packet data with different tag values should be cached separately.""" +class TestCacheKeySemantics: + def test_same_packet_different_tags_is_cache_hit(self, double_pod, cache_db): + """Same packet data with different tags is a cache hit — the function + output depends only on the packet, not the tag.""" cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) - # Stream with tag=0, x=10 stream1 = _make_stream([{"id": 0, "x": 10}]) list(cached_pod.process(stream1).iter_packets()) - # Stream with tag=1, x=10 (same packet data, different tag) + # Same packet data, different tag — should be cache hit stream2 = _make_stream([{"id": 1, "x": 10}]) - list(cached_pod.process(stream2).iter_packets()) + results = list(cached_pod.process(stream2).iter_packets()) records = cache_db.get_all_records(cached_pod.record_path) assert records is not None - # Should have 2 records since tags differ - assert records.num_rows == 2 + # Only 1 record since the cache key is input packet hash only + assert records.num_rows == 1 + # But result is still correct + assert results[0][1].as_dict()["result"] == 20 - def test_same_tag_different_packet_cached_separately(self, double_pod, cache_db): - """Same tag value with different packet data should produce separate entries.""" + def test_different_packet_data_cached_separately(self, double_pod, cache_db): + """Different packet data should produce separate cache entries.""" cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) stream1 = _make_stream([{"id": 0, "x": 10}]) @@ -179,14 +181,13 @@ def test_same_tag_different_packet_cached_separately(self, double_pod, cache_db) assert records is not None assert records.num_rows == 2 - def test_same_tag_same_packet_is_cache_hit(self, double_pod, cache_db): - """Exact same tag + packet should be a cache hit (no new record).""" + def test_identical_input_is_cache_hit(self, double_pod, cache_db): + """Exact same input is a cache hit (no new record).""" cached_pod = CachedFunctionPod(double_pod, result_database=cache_db) stream1 = _make_stream([{"id": 0, "x": 10}]) list(cached_pod.process(stream1).iter_packets()) - # Identical stream stream2 = _make_stream([{"id": 0, "x": 10}]) list(cached_pod.process(stream2).iter_packets()) diff --git a/tests/test_core/function_pod/test_function_pod_node.py b/tests/test_core/function_pod/test_function_pod_node.py index 15e76c77..29bd5514 100644 --- a/tests/test_core/function_pod/test_function_pod_node.py +++ b/tests/test_core/function_pod/test_function_pod_node.py @@ -455,12 +455,10 @@ def test_meta_true_includes_packet_record_id(self, filled_node): assert result is not None assert constants.PACKET_RECORD_ID in result.column_names - def test_meta_true_includes_cache_entry_hash(self, filled_node): - from orcapod.core.cached_function_pod import CachedFunctionPod - + def test_meta_true_includes_input_packet_hash(self, filled_node): result = filled_node.get_all_records(columns={"meta": True}) assert result is not None - assert CachedFunctionPod.CACHE_ENTRY_HASH_COL in result.column_names + assert constants.INPUT_PACKET_HASH_COL in result.column_names def test_meta_true_still_has_data_columns(self, filled_node): result = filled_node.get_all_records(columns={"meta": True}) @@ -468,12 +466,10 @@ def test_meta_true_still_has_data_columns(self, filled_node): assert "id" in result.column_names assert "result" in result.column_names - def test_cache_entry_hash_values_are_non_empty_strings(self, filled_node): - from orcapod.core.cached_function_pod import CachedFunctionPod - + def test_input_packet_hash_values_are_non_empty_strings(self, filled_node): result = filled_node.get_all_records(columns={"meta": True}) assert result is not None - hashes = result.column(CachedFunctionPod.CACHE_ENTRY_HASH_COL).to_pylist() + hashes = result.column(constants.INPUT_PACKET_HASH_COL).to_pylist() assert all(isinstance(h, str) and len(h) > 0 for h in hashes) def test_packet_record_id_values_are_non_empty_strings(self, filled_node): From 94999e5c84ee11e481ce35dcd3e0655e39a430a7 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 06:54:30 +0000 Subject: [PATCH 5/8] test(function-node): add comprehensive pipeline + result cache tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New test file covering the dual-database caching architecture: - compute_pipeline_entry_id: determinism, tag/packet sensitivity - System tag awareness: identical user tags with different system tags produce different pipeline entry_ids - Result DB vs pipeline DB record counts: same packet/different tags → 1 result record, N pipeline records - Phase 1/2 with pipeline entry_ids: Phase 1 yields existing, Phase 2 skips matching entry_ids, processes only novel combinations - Same packet + new tag triggers Phase 2 (novel entry_id) even though CachedFunctionPod has a result cache hit - Pipeline records reference same result UUID for identical packets - Pipeline records include source columns but not data columns of input Co-Authored-By: Claude Opus 4.6 (1M context) --- .../test_function_node_caching.py | 389 ++++++++++++++++++ 1 file changed, 389 insertions(+) create mode 100644 tests/test_core/function_pod/test_function_node_caching.py diff --git a/tests/test_core/function_pod/test_function_node_caching.py b/tests/test_core/function_pod/test_function_node_caching.py new file mode 100644 index 00000000..d261669e --- /dev/null +++ b/tests/test_core/function_pod/test_function_node_caching.py @@ -0,0 +1,389 @@ +"""Tests for FunctionNode caching: pipeline DB vs result DB interaction. + +Covers: +- compute_pipeline_entry_id behavior +- Pipeline entry_id based Phase 2 skip (tag + system_tags + packet_hash) +- CachedFunctionPod result cache hit with novel pipeline entry_id +- Same packet data, different tags → 1 result record, N pipeline records +- System tag awareness in pipeline entry_id computation +- Phase 1 yields existing records, Phase 2 processes only novel entry_ids +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.cached_function_pod import CachedFunctionPod +from orcapod.core.datagrams import Packet, Tag +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase +from orcapod.system_constants import constants + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def double(x: int) -> int: + return x * 2 + + +def _make_pod(): + pf = PythonPacketFunction(double, output_keys="result") + return FunctionPod(pf) + + +def _make_stream(rows: list[dict], tag_columns: list[str] = None) -> ArrowTableStream: + if tag_columns is None: + tag_columns = ["id"] + table = pa.table( + {k: pa.array([r[k] for r in rows], type=pa.int64()) for k in rows[0]} + ) + return ArrowTableStream(table, tag_columns=tag_columns) + + +def _make_source_stream( + rows: list[dict], tag_columns: list[str] = None, source_id: str = "src_a" +) -> ArrowTableStream: + """Create a stream from an ArrowTableSource so it has system tag columns.""" + if tag_columns is None: + tag_columns = ["id"] + table = pa.table( + {k: pa.array([r[k] for r in rows], type=pa.int64()) for k in rows[0]} + ) + source = ArrowTableSource(table, tag_columns=tag_columns, source_id=source_id) + return source + + +def _make_node(stream, db=None): + pod = _make_pod() + if db is None: + db = InMemoryArrowDatabase() + return FunctionNode( + function_pod=pod, + input_stream=stream, + pipeline_database=db, + result_database=db, + ), db + + +# --------------------------------------------------------------------------- +# compute_pipeline_entry_id +# --------------------------------------------------------------------------- + + +class TestComputePipelineEntryId: + def test_returns_non_empty_string(self): + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + tag = Tag({"id": 0}) + packet = Packet({"x": 10}) + entry_id = node.compute_pipeline_entry_id(tag, packet) + assert isinstance(entry_id, str) + assert len(entry_id) > 0 + + def test_same_inputs_produce_same_id(self): + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + tag = Tag({"id": 0}) + packet = Packet({"x": 10}) + id1 = node.compute_pipeline_entry_id(tag, packet) + id2 = node.compute_pipeline_entry_id(tag, packet) + assert id1 == id2 + + def test_different_tags_produce_different_ids(self): + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + packet = Packet({"x": 10}) + id_tag0 = node.compute_pipeline_entry_id(Tag({"id": 0}), packet) + id_tag1 = node.compute_pipeline_entry_id(Tag({"id": 1}), packet) + assert id_tag0 != id_tag1 + + def test_different_packets_produce_different_ids(self): + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + tag = Tag({"id": 0}) + id_x10 = node.compute_pipeline_entry_id(tag, Packet({"x": 10})) + id_x99 = node.compute_pipeline_entry_id(tag, Packet({"x": 99})) + assert id_x10 != id_x99 + + +# --------------------------------------------------------------------------- +# System tag awareness in entry_id +# --------------------------------------------------------------------------- + + +class TestSystemTagAwareness: + def test_same_tag_values_different_system_tags_produce_different_ids(self): + """Two tags with identical user values but different system tags + must produce different pipeline entry_ids.""" + stream = _make_stream([{"id": 0, "x": 10}]) + node, _ = _make_node(stream) + packet = Packet({"x": 10}) + + # Tags with same user value but different system tag columns + tag_a = Tag( + {"id": 0}, + system_tags={f"{constants.SYSTEM_TAG_PREFIX}source:abc": "row0"}, + ) + tag_b = Tag( + {"id": 0}, + system_tags={f"{constants.SYSTEM_TAG_PREFIX}source:xyz": "row0"}, + ) + + id_a = node.compute_pipeline_entry_id(tag_a, packet) + id_b = node.compute_pipeline_entry_id(tag_b, packet) + assert id_a != id_b + + +# --------------------------------------------------------------------------- +# Result DB vs Pipeline DB record counts +# --------------------------------------------------------------------------- + + +class TestResultVsPipelineRecordCounts: + def test_same_packet_different_tags_one_result_two_pipeline_records(self): + """Same packet data with different tags should produce: + - 1 result record (CachedFunctionPod caches by packet hash only) + - 2 pipeline records (different tag → different entry_id) + """ + rows = [{"id": 0, "x": 10}, {"id": 1, "x": 10}] + stream = _make_stream(rows) + node, db = _make_node(stream) + node.run() + + # Result DB: 1 record (same packet hash, second is cache hit) + result_records = node._cached_function_pod._result_database.get_all_records( + node._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 1 + + # Pipeline DB: 2 records (different tags → different entry_ids) + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + def test_different_packets_same_tag_two_result_two_pipeline_records(self): + """Different packet data with same tag should produce: + - 2 result records (different packet hashes) + - 2 pipeline records (different packet hash → different entry_id) + """ + rows = [{"id": 0, "x": 10}, {"id": 0, "x": 20}] + stream = _make_stream(rows) + node, db = _make_node(stream) + node.run() + + result_records = node._cached_function_pod._result_database.get_all_records( + node._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 2 + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + def test_identical_rows_one_result_one_pipeline_record(self): + """Identical (tag, packet) → 1 result record, 1 pipeline record.""" + # A single row — process once + stream = _make_stream([{"id": 0, "x": 10}]) + node, db = _make_node(stream) + node.run() + + result_records = node._cached_function_pod._result_database.get_all_records( + node._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 1 + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 1 + + +# --------------------------------------------------------------------------- +# Phase 1/2 interaction with pipeline entry_ids +# --------------------------------------------------------------------------- + + +class TestPhase1Phase2PipelineEntryId: + def test_phase1_yields_existing_results(self): + """Phase 1 should yield all previously computed results from DB.""" + db = InMemoryArrowDatabase() + stream1 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Create a second node with the same stream and shared DB + stream2 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # All results should come from Phase 1 (DB), not recomputed + assert len(results) == 2 + + def test_phase2_processes_novel_entry_ids_only(self): + """Phase 2 should only process inputs whose pipeline entry_id + is not yet in the pipeline DB.""" + db = InMemoryArrowDatabase() + + # First run: process 2 rows + stream1 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Second run: 3 rows, 2 existing + 1 new + stream2 = _make_stream( + [{"id": 0, "x": 10}, {"id": 1, "x": 20}, {"id": 2, "x": 30}] + ) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # Should yield 3 total: 2 from Phase 1 + 1 from Phase 2 + assert len(results) == 3 + # The new result should have the correct value + result_values = sorted(p.as_dict()["result"] for _, p in results) + assert result_values == [20, 40, 60] + + def test_same_packet_new_tag_triggers_phase2(self): + """Same packet data but new tag should trigger Phase 2 processing + because the pipeline entry_id (tag+system_tags+packet) is novel, + even though CachedFunctionPod has a cache hit for the packet.""" + db = InMemoryArrowDatabase() + + # First run: tag=0, x=10 + stream1 = _make_stream([{"id": 0, "x": 10}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + pipeline_count_after_first = db.get_all_records(node1.pipeline_path).num_rows + assert pipeline_count_after_first == 1 + + # Second run: tag=1, x=10 (same packet, different tag) + stream2 = _make_stream([{"id": 1, "x": 10}]) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # Should yield 1 result from Phase 1 (tag=0) + 1 from Phase 2 (tag=1) + assert len(results) == 2 + + # Pipeline DB should now have 2 records + pipeline_records = db.get_all_records(node2.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + # Result DB should still have only 1 record (same packet hash) + result_records = node2._cached_function_pod._result_database.get_all_records( + node2._cached_function_pod.record_path + ) + assert result_records is not None + assert result_records.num_rows == 1 + + def test_all_existing_entry_ids_skipped_in_phase2(self): + """When all inputs already have pipeline records, Phase 2 + should not call process_packet at all.""" + db = InMemoryArrowDatabase() + + stream1 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Re-run with identical stream + stream2 = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 20}]) + node2, _ = _make_node(stream2, db=db) + + pipeline_count_before = db.get_all_records(node2.pipeline_path).num_rows + + results = list(node2.iter_packets()) + assert len(results) == 2 + + # No new pipeline records should be added + pipeline_count_after = db.get_all_records(node2.pipeline_path).num_rows + assert pipeline_count_after == pipeline_count_before + + +# --------------------------------------------------------------------------- +# CachedFunctionPod cache hit + novel pipeline entry +# --------------------------------------------------------------------------- + + +class TestResultCacheHitPipelineNovel: + def test_cached_result_reused_for_new_tag(self): + """When CachedFunctionPod has a cache hit (same packet hash) but + the pipeline entry_id is novel (different tag), the cached result + should be reused and a new pipeline record created.""" + db = InMemoryArrowDatabase() + + # Process tag=0, x=10 + stream1 = _make_stream([{"id": 0, "x": 10}]) + node1, _ = _make_node(stream1, db=db) + node1.run() + + # Process tag=1, x=10 — same packet, different tag + stream2 = _make_stream([{"id": 1, "x": 10}]) + node2, _ = _make_node(stream2, db=db) + results = list(node2.iter_packets()) + + # Both tags should produce the same result value + result_values = [p.as_dict()["result"] for _, p in results] + assert all(v == 20 for v in result_values) + + def test_pipeline_records_reference_same_result_uuid(self): + """Two pipeline records for the same packet (different tags) + should reference the same output packet UUID in the result DB.""" + db = InMemoryArrowDatabase() + + stream = _make_stream([{"id": 0, "x": 10}, {"id": 1, "x": 10}]) + node, _ = _make_node(stream, db=db) + node.run() + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + assert pipeline_records.num_rows == 2 + + # Both pipeline records should reference the same PACKET_RECORD_ID + record_ids = pipeline_records.column(constants.PACKET_RECORD_ID).to_pylist() + assert record_ids[0] == record_ids[1] + + +# --------------------------------------------------------------------------- +# Source columns in pipeline records +# --------------------------------------------------------------------------- + + +class TestPipelineRecordSourceColumns: + def test_pipeline_record_contains_source_columns(self): + """Pipeline records should include source columns of the input packet.""" + stream = _make_source_stream([{"id": 0, "x": 10}], source_id="my_source") + node, db = _make_node(stream) + node.run() + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + + source_cols = [ + c + for c in pipeline_records.column_names + if c.startswith(constants.SOURCE_PREFIX) + ] + assert len(source_cols) > 0 + + def test_pipeline_record_excludes_data_columns_of_input(self): + """Pipeline records should NOT include data columns of the input packet.""" + stream = _make_source_stream([{"id": 0, "x": 10}]) + node, db = _make_node(stream) + node.run() + + pipeline_records = db.get_all_records(node.pipeline_path) + assert pipeline_records is not None + + # "x" is the input packet data column — should not appear + assert "x" not in pipeline_records.column_names + assert "_input_x" not in pipeline_records.column_names From 5edb463b7288425a823a4aa0eaf6066f92d06374 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 07:08:40 +0000 Subject: [PATCH 6/8] refactor(caching): extract shared ResultCache from CachedPacketFunction and CachedFunctionPod MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New ResultCache class owns lookup, store, conflict resolution, and auto-flush logic. Both CachedPacketFunction and CachedFunctionPod delegate to a ResultCache instance. ResultCache.lookup accepts additional_constraints dict — the hook for future match tier support (P6). Default lookup matches on INPUT_PACKET_HASH_COL only; additional constraints can narrow the match (e.g. by function variation hash). Resolves DESIGN_ISSUES CFP1. Co-Authored-By: Claude Opus 4.6 (1M context) --- DESIGN_ISSUES.md | 15 +- src/orcapod/core/cached_function_pod.py | 179 +++----------- src/orcapod/core/packet_function.py | 144 +++-------- src/orcapod/core/result_cache.py | 226 ++++++++++++++++++ .../test_cached_packet_function.py | 4 +- 5 files changed, 306 insertions(+), 262 deletions(-) create mode 100644 src/orcapod/core/result_cache.py diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index e1acf4f9..9dcab19b 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -318,7 +318,7 @@ which column groups (meta, source, system_tags) are returned. ## `src/orcapod/core/cached_function_pod.py` / `src/orcapod/core/packet_function.py` ### CFP1 — Extract shared result caching logic from CachedPacketFunction and CachedFunctionPod -**Status:** open +**Status:** resolved **Severity:** medium `CachedPacketFunction` and `CachedFunctionPod` implement nearly identical result caching @@ -326,14 +326,11 @@ logic: DB lookup by `INPUT_PACKET_HASH_COL`, conflict resolution by most-recent record storage with variation/execution/timestamp columns, and a `RESULT_COMPUTED_FLAG` meta column. The match tier / matching policy design (P6) will also need to apply to both. -This duplication means any future changes to caching behavior (e.g. implementing match -tiers, adding new stored columns, changing conflict resolution) must be applied in two -places. - -**Suggested refactor:** Extract a `ResultCache` (or similar) class that owns the DB, -record path, lookup, store, and conflict resolution logic. Both `CachedPacketFunction` -and `CachedFunctionPod` would delegate to a shared `ResultCache` instance. The match tier -strategy (P6) would then be implemented once on `ResultCache`. +**Fix:** Extracted `ResultCache` class (`src/orcapod/core/result_cache.py`) that owns the DB, +record path, lookup (with `additional_constraints` for future match tiers), store, conflict +resolution, and auto-flush logic. Both `CachedPacketFunction` and `CachedFunctionPod` now +delegate to a `ResultCache` instance. The match tier strategy (P6) can be implemented once +on `ResultCache.lookup` via `additional_constraints`. --- diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py index 2462ba05..76d2a27d 100644 --- a/src/orcapod/core/cached_function_pod.py +++ b/src/orcapod/core/cached_function_pod.py @@ -3,10 +3,10 @@ from __future__ import annotations import logging -from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from orcapod.core.function_pod import WrappedFunctionPod +from orcapod.core.result_cache import ResultCache from orcapod.protocols.core_protocols import ( FunctionPodProtocol, PacketProtocol, @@ -14,13 +14,9 @@ TagProtocol, ) from orcapod.protocols.database_protocols import ArrowDatabaseProtocol -from orcapod.system_constants import constants -from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa -else: - pa = LazyModule("pyarrow") logger = logging.getLogger(__name__) @@ -35,13 +31,12 @@ class CachedFunctionPod(WrappedFunctionPod): Tag-level provenance tracking (tag + system tags + packet hash) is handled separately by ``FunctionNode.add_pipeline_record``. - Storage format aligns with ``CachedPacketFunction``: each cached - record includes function variation data, execution data, input packet - hash, and a timestamp. + Uses a shared ``ResultCache`` for lookup/store/conflict-resolution + logic (same mechanism as ``CachedPacketFunction``). """ - # Meta column indicating whether the result was freshly computed - RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + # Expose RESULT_COMPUTED_FLAG from the shared ResultCache + RESULT_COMPUTED_FLAG = ResultCache.RESULT_COMPUTED_FLAG def __init__( self, @@ -52,124 +47,22 @@ def __init__( **kwargs, ) -> None: super().__init__(function_pod, **kwargs) - self._result_database = result_database self._record_path_prefix = record_path_prefix - self._auto_flush = auto_flush + self._cache = ResultCache( + result_database=result_database, + record_path=record_path_prefix + self.uri, + auto_flush=auto_flush, + ) + + @property + def _result_database(self) -> ArrowDatabaseProtocol: + """The underlying result database (for FunctionNode access).""" + return self._cache.result_database @property def record_path(self) -> tuple[str, ...]: """Return the path to the cached records in the result store.""" - return self._record_path_prefix + self.uri - - def _lookup(self, input_packet: PacketProtocol) -> PacketProtocol | None: - """Look up a cached output packet by input packet content hash. - - Args: - input_packet: The input packet whose content hash is used for lookup. - - Returns: - The cached output packet, or ``None`` if not found. - """ - from orcapod.core.datagrams import Packet - - RECORD_ID_COL = "_record_id" - - # TODO: add match based on match_tier if specified - # TODO: implement matching policy/strategy (see DESIGN_ISSUES P6) - constraints = { - constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string(), - } - - result_table = self._result_database.get_records_with_column_value( - self.record_path, - constraints, - record_id_column=RECORD_ID_COL, - ) - - if result_table is None or result_table.num_rows == 0: - return None - - if result_table.num_rows > 1: - logger.info( - "Pod-level cache: multiple records for input packet hash, " - "taking most recent" - ) - result_table = result_table.sort_by( - [(constants.POD_TIMESTAMP, "descending")] - ).take([0]) - - record_id = result_table.to_pylist()[0][RECORD_ID_COL] - result_table = result_table.drop_columns( - [RECORD_ID_COL, constants.INPUT_PACKET_HASH_COL] - ) - - return Packet( - result_table, - record_id=record_id, - meta_info={self.RESULT_COMPUTED_FLAG: False}, - ) - - def _store( - self, - input_packet: PacketProtocol, - output_packet: PacketProtocol, - ) -> None: - """Store an output packet in the cache. - - Stores the output packet data alongside function variation data, - execution data, input packet hash, and timestamp — matching the - column structure of ``CachedPacketFunction``. - - Args: - input_packet: The input packet (used for its content hash). - output_packet: The computed output packet to store. - """ - data_table = output_packet.as_table(columns={"source": True, "context": True}) - - pf = self._function_pod.packet_function - - # Add function variation data columns - i = 0 - for k, v in pf.get_function_variation_data().items(): - data_table = data_table.add_column( - i, - f"{constants.PF_VARIATION_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - i += 1 - - # Add execution data columns - for k, v in pf.get_execution_data().items(): - data_table = data_table.add_column( - i, - f"{constants.PF_EXECUTION_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - i += 1 - - # Add input packet hash (position 0, same as CachedPacketFunction) - data_table = data_table.add_column( - 0, - constants.INPUT_PACKET_HASH_COL, - pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), - ) - - # Append timestamp - timestamp = datetime.now(timezone.utc) - data_table = data_table.append_column( - constants.POD_TIMESTAMP, - pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), - ) - - self._result_database.add_record( - self.record_path, - output_packet.datagram_id, - data_table, - skip_duplicates=False, - ) - - if self._auto_flush: - self._result_database.flush() + return self._cache.record_path def process_packet( self, tag: TagProtocol, packet: PacketProtocol @@ -189,14 +82,20 @@ def process_packet( A ``(tag, output_packet)`` tuple; output_packet is ``None`` if the inner function filters the packet out. """ - cached = self._lookup(packet) + cached = self._cache.lookup(packet) if cached is not None: logger.info("Pod-level cache hit") return tag, cached tag, output = self._function_pod.process_packet(tag, packet) if output is not None: - self._store(packet, output) + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output @@ -209,38 +108,30 @@ async def async_process_packet( actual computation uses the inner pod's ``async_process_packet`` for true async execution. """ - cached = self._lookup(packet) + cached = self._cache.lookup(packet) if cached is not None: logger.info("Pod-level cache hit") return tag, cached tag, output = await self._function_pod.async_process_packet(tag, packet) if output is not None: - self._store(packet, output) + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output def get_all_cached_outputs( self, include_system_columns: bool = False ) -> "pa.Table | None": - """Return all cached records from the result store for this pod. - - Args: - include_system_columns: If True, include system columns - (e.g. record_id) in the result. - - Returns: - A PyArrow table of cached results, or ``None`` if empty. - """ - record_id_column = ( - constants.PACKET_RECORD_ID if include_system_columns else None - ) - result_table = self._result_database.get_all_records( - self.record_path, record_id_column=record_id_column + """Return all cached records from the result store for this pod.""" + return self._cache.get_all_records( + include_system_columns=include_system_columns ) - if result_table is None or result_table.num_rows == 0: - return None - return result_table def process( self, *streams: StreamProtocol, label: str | None = None diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index 851a66b7..f12e6ecc 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -20,6 +20,7 @@ get_function_components, get_function_signature, ) +from orcapod.core.result_cache import ResultCache from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol from orcapod.protocols.core_protocols.executor import ( PacketFunctionExecutorProtocol, @@ -724,10 +725,14 @@ async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | No class CachedPacketFunction(PacketFunctionWrapper): - """Wrapper around a PacketFunctionProtocol that caches results for identical input packets.""" + """Wrapper around a PacketFunctionProtocol that caches results for identical input packets. - # cloumn name containing indication of whether the result was computed - RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + Uses a shared ``ResultCache`` for lookup/store/conflict-resolution + logic (same mechanism as ``CachedFunctionPod``). + """ + + # Expose RESULT_COMPUTED_FLAG from the shared ResultCache + RESULT_COMPUTED_FLAG = ResultCache.RESULT_COMPUTED_FLAG def __init__( self, @@ -739,16 +744,20 @@ def __init__( super().__init__(packet_function, **kwargs) self._result_database = result_database self._record_path_prefix = record_path_prefix - self._auto_flush = True + self._cache = ResultCache( + result_database=result_database, + record_path=record_path_prefix + self.uri, + auto_flush=True, + ) def set_auto_flush(self, on: bool = True) -> None: """Set auto-flush behavior. If True, the database flushes after each record.""" - self._auto_flush = on + self._cache.set_auto_flush(on) @property def record_path(self) -> tuple[str, ...]: """Return the path to the record in the result store.""" - return self._record_path_prefix + self.uri + return self._cache.record_path def call( self, @@ -757,20 +766,22 @@ def call( skip_cache_lookup: bool = False, skip_cache_insert: bool = False, ) -> PacketProtocol | None: - # execution_engine_hash = execution_engine.name if execution_engine else "default" output_packet = None if not skip_cache_lookup: logger.info("Checking for cache...") - # lookup stored result for the input packet - output_packet = self.get_cached_output_for_packet(packet) + output_packet = self._cache.lookup(packet) if output_packet is not None: logger.info(f"Cache hit for {packet}!") if output_packet is None: output_packet = self._packet_function.call(packet) if output_packet is not None: if not skip_cache_insert: - self.record_packet(packet, output_packet) - # add meta column to indicate that this was computed + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + ) output_packet = output_packet.with_meta_columns( **{self.RESULT_COMPUTED_FLAG: True} ) @@ -788,14 +799,19 @@ async def async_call( output_packet = None if not skip_cache_lookup: logger.info("Checking for cache...") - output_packet = self.get_cached_output_for_packet(packet) + output_packet = self._cache.lookup(packet) if output_packet is not None: logger.info(f"Cache hit for {packet}!") if output_packet is None: output_packet = await self._packet_function.async_call(packet) if output_packet is not None: if not skip_cache_insert: - self.record_packet(packet, output_packet) + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + ) output_packet = output_packet.with_meta_columns( **{self.RESULT_COMPUTED_FLAG: True} ) @@ -811,45 +827,7 @@ def get_cached_output_for_packet( Returns: The cached output packet, or ``None`` if no entry was found. """ - - # get all records with matching the input packet hash - # TODO: add match based on match_tier if specified - - # TODO: implement matching policy/strategy - constraints = { - constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string() - } - - RECORD_ID_COLUMN = "_record_id" - result_table = self._result_database.get_records_with_column_value( - self.record_path, - constraints, - record_id_column=RECORD_ID_COLUMN, - ) - - if result_table is None or result_table.num_rows == 0: - return None - - if result_table.num_rows > 1: - logger.info( - f"Performing conflict resolution for multiple records for {input_packet.content_hash().display_name()}" - ) - result_table = result_table.sort_by( - [(constants.POD_TIMESTAMP, "descending")] - ).take([0]) - - # extract the record_id column - record_id = result_table.to_pylist()[0][RECORD_ID_COLUMN] - result_table = result_table.drop_columns( - [RECORD_ID_COLUMN, constants.INPUT_PACKET_HASH_COL] - ) - - # note that data context will be loaded from the result store - return Packet( - result_table, - record_id=record_id, - meta_info={self.RESULT_COMPUTED_FLAG: False}, - ) + return self._cache.lookup(input_packet) def record_packet( self, @@ -858,54 +836,13 @@ def record_packet( skip_duplicates: bool = False, ) -> PacketProtocol: """Record the output packet against the input packet in the result store.""" - - # TODO: consider incorporating execution_engine_opts into the record - data_table = output_packet.as_table(columns={"source": True, "context": True}) - - i = 0 - for k, v in self.get_function_variation_data().items(): - # add the tiered pod ID to the data table - data_table = data_table.add_column( - i, - f"{constants.PF_VARIATION_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - i += 1 - - for k, v in self.get_execution_data().items(): - # add the tiered pod ID to the data table - data_table = data_table.add_column( - i, - f"{constants.PF_EXECUTION_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - i += 1 - - # add the input packet hash as a column - data_table = data_table.add_column( - 0, - constants.INPUT_PACKET_HASH_COL, - pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), - ) - - # add computation timestamp - timestamp = datetime.now(timezone.utc) - data_table = data_table.append_column( - constants.POD_TIMESTAMP, - pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), - ) - - self._result_database.add_record( - self.record_path, - output_packet.datagram_id, # output packet datagram ID (uuid) is used as a unique identification - data_table, + self._cache.store( + input_packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), skip_duplicates=skip_duplicates, ) - - if self._auto_flush: - self._result_database.flush() - - # TODO: make store return retrieved table return output_packet def get_all_cached_outputs( @@ -920,13 +857,6 @@ def get_all_cached_outputs( Returns: A PyArrow table of cached results, or ``None`` if empty. """ - record_id_column = ( - constants.PACKET_RECORD_ID if include_system_columns else None + return self._cache.get_all_records( + include_system_columns=include_system_columns ) - result_table = self._result_database.get_all_records( - self.record_path, record_id_column=record_id_column - ) - if result_table is None or result_table.num_rows == 0: - return None - - return result_table diff --git a/src/orcapod/core/result_cache.py b/src/orcapod/core/result_cache.py new file mode 100644 index 00000000..8afa8b8f --- /dev/null +++ b/src/orcapod/core/result_cache.py @@ -0,0 +1,226 @@ +"""ResultCache — shared result caching logic for CachedPacketFunction and CachedFunctionPod. + +Owns the database, record path, lookup (with match strategy), store, +conflict resolution, and auto-flush behavior. Both ``CachedPacketFunction`` +and ``CachedFunctionPod`` delegate to a ``ResultCache`` instance. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from orcapod.protocols.core_protocols import PacketProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +class ResultCache: + """Shared result caching backed by an ``ArrowDatabaseProtocol``. + + Provides lookup (by input packet hash + optional additional constraints), + store (output data + variation/execution metadata + timestamp), conflict + resolution (most-recent-timestamp wins), and auto-flush. + + The match strategy is extensible: the default lookup matches on + ``INPUT_PACKET_HASH_COL`` only, but callers can supply additional + constraints (e.g. function variation columns) to narrow the match. + This is the hook for future match-tier support (see DESIGN_ISSUES P6). + + Args: + result_database: The database to store/retrieve cached results. + record_path: The record path tuple for scoping records in the database. + auto_flush: If True, flush the database after each store operation. + """ + + # Meta column indicating whether the result was freshly computed + RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + + def __init__( + self, + result_database: ArrowDatabaseProtocol, + record_path: tuple[str, ...], + auto_flush: bool = True, + ) -> None: + self._result_database = result_database + self._record_path = record_path + self._auto_flush = auto_flush + + @property + def result_database(self) -> ArrowDatabaseProtocol: + """The underlying database.""" + return self._result_database + + @property + def record_path(self) -> tuple[str, ...]: + """The record path for scoping records in the database.""" + return self._record_path + + def set_auto_flush(self, on: bool = True) -> None: + """Set auto-flush behavior.""" + self._auto_flush = on + + def lookup( + self, + input_packet: PacketProtocol, + additional_constraints: dict[str, str] | None = None, + ) -> PacketProtocol | None: + """Look up a cached output packet for *input_packet*. + + The default match is by ``INPUT_PACKET_HASH_COL`` only. + *additional_constraints* can narrow the match further (e.g. by + function variation hash for stricter cache invalidation). + + If multiple records match, the most recent (by timestamp) wins. + + Args: + input_packet: The input packet whose content hash is the + primary lookup key. + additional_constraints: Optional extra column-value pairs to + include in the lookup query. + + Returns: + The cached output packet with ``RESULT_COMPUTED_FLAG: False`` + in its meta, or ``None`` if no match was found. + """ + from orcapod.core.datagrams import Packet + + RECORD_ID_COL = "_record_id" + + constraints: dict[str, str] = { + constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string(), + } + if additional_constraints: + constraints.update(additional_constraints) + + result_table = self._result_database.get_records_with_column_value( + self._record_path, + constraints, + record_id_column=RECORD_ID_COL, + ) + + if result_table is None or result_table.num_rows == 0: + return None + + if result_table.num_rows > 1: + logger.info( + "Cache conflict resolution: %d records for constraints %s, " + "taking most recent", + result_table.num_rows, + list(constraints.keys()), + ) + result_table = result_table.sort_by( + [(constants.POD_TIMESTAMP, "descending")] + ).take([0]) + + record_id = result_table.to_pylist()[0][RECORD_ID_COL] + # Drop lookup columns from the returned packet + drop_cols = [RECORD_ID_COL] + [ + c for c in constraints if c in result_table.column_names + ] + result_table = result_table.drop_columns(drop_cols) + + return Packet( + result_table, + record_id=record_id, + meta_info={self.RESULT_COMPUTED_FLAG: False}, + ) + + def store( + self, + input_packet: PacketProtocol, + output_packet: PacketProtocol, + variation_data: dict[str, Any], + execution_data: dict[str, Any], + skip_duplicates: bool = False, + ) -> None: + """Store an output packet in the cache. + + Stores the output packet data alongside function variation data, + execution data, input packet hash, and a timestamp. + + Args: + input_packet: The input packet (used for its content hash). + output_packet: The computed output packet to store. + variation_data: Function variation metadata (e.g. function name, + signature hash, content hash, git hash). + execution_data: Execution environment metadata (e.g. python + version, execution context). + skip_duplicates: If True, silently skip if a record with the + same ID already exists. + """ + data_table = output_packet.as_table(columns={"source": True, "context": True}) + + # Add function variation data columns + i = 0 + for k, v in variation_data.items(): + data_table = data_table.add_column( + i, + f"{constants.PF_VARIATION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add execution data columns + for k, v in execution_data.items(): + data_table = data_table.add_column( + i, + f"{constants.PF_EXECUTION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add input packet hash (position 0) + data_table = data_table.add_column( + 0, + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), + ) + + # Append timestamp + timestamp = datetime.now(timezone.utc) + data_table = data_table.append_column( + constants.POD_TIMESTAMP, + pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), + ) + + self._result_database.add_record( + self._record_path, + output_packet.datagram_id, + data_table, + skip_duplicates=skip_duplicates, + ) + + if self._auto_flush: + self._result_database.flush() + + def get_all_records( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store. + + Args: + include_system_columns: If True, include system columns + (e.g. record_id) in the result. + + Returns: + A PyArrow table of cached results, or ``None`` if empty. + """ + record_id_column = ( + constants.PACKET_RECORD_ID if include_system_columns else None + ) + result_table = self._result_database.get_all_records( + self._record_path, record_id_column=record_id_column + ) + if result_table is None or result_table.num_rows == 0: + return None + return result_table diff --git a/tests/test_core/packet_function/test_cached_packet_function.py b/tests/test_core/packet_function/test_cached_packet_function.py index ba968dd0..9a4218a1 100644 --- a/tests/test_core/packet_function/test_cached_packet_function.py +++ b/tests/test_core/packet_function/test_cached_packet_function.py @@ -96,11 +96,11 @@ def test_record_path_prefix_prepended(self, inner_pf, db): assert cpf.record_path == ("org", "project") + inner_pf.uri def test_auto_flush_true_by_default(self, cached_pf): - assert cached_pf._auto_flush is True + assert cached_pf._cache._auto_flush is True def test_set_auto_flush_false(self, cached_pf): cached_pf.set_auto_flush(False) - assert cached_pf._auto_flush is False + assert cached_pf._cache._auto_flush is False # --------------------------------------------------------------------------- From 7776cff58eb78539cdb0cbd31c570670d423930c Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 07:11:04 +0000 Subject: [PATCH 7/8] test(result-cache): add direct unit tests for ResultCache - Lookup: miss on empty DB, miss on different packet, hit returns correct result with RESULT_COMPUTED_FLAG=False, different record paths are isolated - Conflict resolution: most recent timestamp wins - Additional constraints: non-matching constraint filters out, matching constraint (e.g. function_name) returns result - Store: input_packet_hash, variation, execution, timestamp, output data columns all present - Auto flush: default true, set_auto_flush, constructor param - get_all_records: empty returns None, includes/excludes system columns Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_core/test_result_cache.py | 306 +++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 tests/test_core/test_result_cache.py diff --git a/tests/test_core/test_result_cache.py b/tests/test_core/test_result_cache.py new file mode 100644 index 00000000..cda78270 --- /dev/null +++ b/tests/test_core/test_result_cache.py @@ -0,0 +1,306 @@ +"""Tests for ResultCache — shared result caching logic. + +Covers: +- Lookup: cache miss, cache hit, conflict resolution (most recent wins) +- Store: variation/execution columns, input packet hash, timestamp +- additional_constraints for narrowing match +- auto_flush behavior +- get_all_records +""" + +from __future__ import annotations + +from typing import Any + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams import Packet +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.result_cache import ResultCache +from orcapod.databases import InMemoryArrowDatabase +from orcapod.system_constants import constants + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def double(x: int) -> int: + return x * 2 + + +def _make_pf() -> PythonPacketFunction: + return PythonPacketFunction(double, output_keys="result") + + +def _make_cache( + db=None, record_path=("test",) +) -> tuple[ResultCache, InMemoryArrowDatabase]: + if db is None: + db = InMemoryArrowDatabase() + cache = ResultCache(result_database=db, record_path=record_path) + return cache, db + + +def _compute_and_store( + cache: ResultCache, pf: PythonPacketFunction, input_packet: Packet +): + """Helper: compute output and store in cache.""" + output = pf.direct_call(input_packet) + assert output is not None + cache.store( + input_packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + return output + + +# --------------------------------------------------------------------------- +# Lookup +# --------------------------------------------------------------------------- + + +class TestLookupMiss: + def test_empty_db_returns_none(self): + cache, _ = _make_cache() + packet = Packet({"x": 10}) + assert cache.lookup(packet) is None + + def test_different_packet_returns_none(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + assert cache.lookup(Packet({"x": 99})) is None + + +class TestLookupHit: + def test_returns_cached_result(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + cached = cache.lookup(input_pkt) + assert cached is not None + assert cached.as_dict()["result"] == 20 + + def test_sets_computed_flag_false(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + cached = cache.lookup(input_pkt) + assert cached is not None + assert cached.get_meta_value(ResultCache.RESULT_COMPUTED_FLAG) is False + + def test_same_packet_different_record_path_is_miss(self): + db = InMemoryArrowDatabase() + cache_a = ResultCache(result_database=db, record_path=("path_a",)) + cache_b = ResultCache(result_database=db, record_path=("path_b",)) + pf = _make_pf() + input_pkt = Packet({"x": 10}) + + output = pf.direct_call(input_pkt) + cache_a.store( + input_pkt, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + + assert cache_a.lookup(input_pkt) is not None + assert cache_b.lookup(input_pkt) is None + + +class TestConflictResolution: + def test_most_recent_wins(self): + import time + + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + + # Store first result + _compute_and_store(cache, pf, input_pkt) + time.sleep(0.01) # ensure different timestamp + + # Store a second result for the same input (simulating recomputation) + output2 = pf.direct_call(input_pkt) + cache.store( + input_pkt, + output2, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + + # Lookup should return the most recent + cached = cache.lookup(input_pkt) + assert cached is not None + assert cached.datagram_id == output2.datagram_id + + +# --------------------------------------------------------------------------- +# Additional constraints +# --------------------------------------------------------------------------- + + +class TestAdditionalConstraints: + def test_narrower_match_filters_results(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + # Lookup with a constraint that doesn't match any stored column value + result = cache.lookup( + input_pkt, + additional_constraints={ + f"{constants.PF_VARIATION_PREFIX}function_name": "nonexistent" + }, + ) + assert result is None + + def test_matching_constraint_returns_result(self): + cache, _ = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + # Lookup with a constraint that matches + result = cache.lookup( + input_pkt, + additional_constraints={ + f"{constants.PF_VARIATION_PREFIX}function_name": "double" + }, + ) + assert result is not None + assert result.as_dict()["result"] == 20 + + +# --------------------------------------------------------------------------- +# Store +# --------------------------------------------------------------------------- + + +class TestStore: + def test_stores_input_packet_hash_column(self): + cache, db = _make_cache() + pf = _make_pf() + input_pkt = Packet({"x": 10}) + _compute_and_store(cache, pf, input_pkt) + + records = db.get_all_records(cache.record_path) + assert records is not None + assert constants.INPUT_PACKET_HASH_COL in records.column_names + + def test_stores_variation_columns(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + variation_cols = [ + c + for c in records.column_names + if c.startswith(constants.PF_VARIATION_PREFIX) + ] + assert len(variation_cols) > 0 + + def test_stores_execution_columns(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + exec_cols = [ + c + for c in records.column_names + if c.startswith(constants.PF_EXECUTION_PREFIX) + ] + assert len(exec_cols) > 0 + + def test_stores_timestamp(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + assert constants.POD_TIMESTAMP in records.column_names + + def test_stores_output_data(self): + cache, db = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = db.get_all_records(cache.record_path) + assert records is not None + assert "result" in records.column_names + assert records.column("result").to_pylist() == [20] + + +# --------------------------------------------------------------------------- +# Auto flush +# --------------------------------------------------------------------------- + + +class TestAutoFlush: + def test_auto_flush_true_by_default(self): + cache, _ = _make_cache() + assert cache._auto_flush is True + + def test_set_auto_flush_false(self): + cache, _ = _make_cache() + cache.set_auto_flush(False) + assert cache._auto_flush is False + + def test_auto_flush_false_in_constructor(self): + db = InMemoryArrowDatabase() + cache = ResultCache(result_database=db, record_path=("t",), auto_flush=False) + assert cache._auto_flush is False + + +# --------------------------------------------------------------------------- +# get_all_records +# --------------------------------------------------------------------------- + + +class TestGetAllRecords: + def test_empty_returns_none(self): + cache, _ = _make_cache() + assert cache.get_all_records() is None + + def test_returns_stored_records(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + _compute_and_store(cache, pf, Packet({"x": 20})) + + records = cache.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_include_system_columns_adds_record_id(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = cache.get_all_records(include_system_columns=True) + assert records is not None + assert constants.PACKET_RECORD_ID in records.column_names + + def test_exclude_system_columns_by_default(self): + cache, _ = _make_cache() + pf = _make_pf() + _compute_and_store(cache, pf, Packet({"x": 10})) + + records = cache.get_all_records(include_system_columns=False) + assert records is not None + assert constants.PACKET_RECORD_ID not in records.column_names From f7fe3e0adfd68fde0fa117de28e2f357ad975369 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 14 Mar 2026 07:29:26 +0000 Subject: [PATCH 8/8] fix: address Copilot review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - PythonFunctionExecutorProtocol now extends PacketFunctionExecutorProtocol (fixes type bound mismatch and ensures supports()/supported_function_type_ids are always available) - PacketFunctionBase._executor typed as E | None (type-safe access to execute_callable on PythonPacketFunction without casts) - LocalExecutor.execute_callable handles nested event loops (mirrors PythonPacketFunction._call_async_function_sync pattern) - Pipeline._apply_execution_engine always calls with_options() per node — executor decides what to copy vs share - Fixed stale docstring (pod_cache_database: "input packet content hash" not "tag+packet hash") - Fixed type annotations: list[str] | None in test helpers, list[Any] for mock executor call lists - Updated pipeline tests to check node's executor (not original mock) Co-Authored-By: Claude Opus 4.6 (1M context) --- src/orcapod/core/executors/local.py | 15 ++++++++++- src/orcapod/core/function_pod.py | 2 +- src/orcapod/core/packet_function.py | 6 ++--- src/orcapod/pipeline/graph.py | 13 +++++---- .../protocols/core_protocols/executor.py | 27 +++---------------- .../test_function_node_caching.py | 6 +++-- tests/test_pipeline/test_pipeline.py | 26 ++++++++++-------- 7 files changed, 47 insertions(+), 48 deletions(-) diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py index f56242e4..fc97e477 100644 --- a/src/orcapod/core/executors/local.py +++ b/src/orcapod/core/executors/local.py @@ -48,9 +48,22 @@ def execute_callable( executor_options: dict[str, Any] | None = None, ) -> Any: if inspect.iscoroutinefunction(fn): - return asyncio.run(fn(**kwargs)) + return self._run_async_sync(fn, kwargs) return fn(**kwargs) + @staticmethod + def _run_async_sync(fn: Callable[..., Any], kwargs: dict[str, Any]) -> Any: + """Run an async function synchronously, handling nested event loops.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(fn(**kwargs)) + else: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(1) as pool: + return pool.submit(lambda: asyncio.run(fn(**kwargs))).result() + async def async_execute_callable( self, fn: Callable[..., Any], diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 5c96a05c..5a1db639 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -667,7 +667,7 @@ def function_pod( (wraps the packet function in ``CachedPacketFunction``). pod_cache_database: Optional database for pod-level caching (wraps the pod in ``CachedFunctionPod``, which caches at the - ``process_packet(tag, packet)`` level using tag+packet hash). + ``process_packet`` level using input packet content hash). executor: Optional executor for running the packet function. **kwargs: Forwarded to ``PythonPacketFunction``. diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index f12e6ecc..bd26ea11 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -141,7 +141,7 @@ def __init__( super().__init__(label=label, data_context=data_context, config=config) self._active = True self._version = version - self._executor: PacketFunctionExecutorProtocol | None = None + self._executor: E | None = None # Parse version string to extract major and minor versions # 0.5.2 -> 0 and 5.2, 1.3rc -> 1 and 3rc @@ -244,12 +244,12 @@ def get_execution_data(self) -> dict[str, Any]: # ==================== Executor ==================== @property - def executor(self) -> PacketFunctionExecutorProtocol | None: + def executor(self) -> E | None: """Return the executor used to run this packet function, or ``None`` for direct execution.""" return self._executor @executor.setter - def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + def executor(self, executor: E | None) -> None: """Set or clear the executor for this packet function. Delegates to ``set_executor`` for validation. diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index 6ee216d0..a800b7ed 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -425,9 +425,11 @@ def _apply_execution_engine( ) -> None: """Apply *execution_engine* to every ``FunctionNode`` in the pipeline. - If *execution_engine_opts* is non-empty, ``engine.with_options`` - is called to produce a configured executor; otherwise the engine - instance is used directly. + Each node receives its own executor instance via + ``engine.with_options(**opts)`` — even when *opts* is empty. + The executor's ``with_options`` implementation decides which + components to copy vs share (e.g. connection handles may be + shared while per-node state is copied). Args: execution_engine: Executor to apply (must implement @@ -441,14 +443,11 @@ def _apply_execution_engine( ) opts = execution_engine_opts or {} - configured_executor = ( - execution_engine.with_options(**opts) if opts else execution_engine - ) for node in self._node_graph.nodes: if not isinstance(node, FunctionNode): continue - node.executor = configured_executor + node.executor = execution_engine.with_options(**opts) logger.debug( "Applied execution engine %r to node %r (opts=%r)", type(execution_engine).__name__, diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py index f3055805..2260898a 100644 --- a/src/orcapod/protocols/core_protocols/executor.py +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -84,26 +84,15 @@ def get_execution_data(self) -> dict[str, Any]: @runtime_checkable -class PythonFunctionExecutorProtocol(Protocol): +class PythonFunctionExecutorProtocol(PacketFunctionExecutorProtocol, Protocol): """Executor protocol for Python callable-based packet functions. - Unlike ``PacketFunctionExecutorProtocol`` which operates on - (packet_function, packet) pairs, this protocol operates on raw - Python callables — the executor receives the function and its - keyword arguments directly. The packet function handles + Extends ``PacketFunctionExecutorProtocol`` with callable-level + execution methods. The executor receives the raw Python function + and its keyword arguments directly — the packet function handles packet construction/deconstruction around the executor call. """ - @property - def executor_type_id(self) -> str: - """Unique identifier for this executor type.""" - ... - - @property - def supports_concurrent_execution(self) -> bool: - """Whether this executor can run multiple calls concurrently.""" - ... - def execute_callable( self, fn: Callable[..., Any], @@ -140,11 +129,3 @@ async def async_execute_callable( The raw return value of *fn*. """ ... - - def with_options(self, **opts: Any) -> Self: - """Return a **new** executor instance with the given options merged in.""" - ... - - def get_execution_data(self) -> dict[str, Any]: - """Return metadata describing the execution environment.""" - ... diff --git a/tests/test_core/function_pod/test_function_node_caching.py b/tests/test_core/function_pod/test_function_node_caching.py index d261669e..0a0403f3 100644 --- a/tests/test_core/function_pod/test_function_node_caching.py +++ b/tests/test_core/function_pod/test_function_node_caching.py @@ -39,7 +39,9 @@ def _make_pod(): return FunctionPod(pf) -def _make_stream(rows: list[dict], tag_columns: list[str] = None) -> ArrowTableStream: +def _make_stream( + rows: list[dict], tag_columns: list[str] | None = None +) -> ArrowTableStream: if tag_columns is None: tag_columns = ["id"] table = pa.table( @@ -49,7 +51,7 @@ def _make_stream(rows: list[dict], tag_columns: list[str] = None) -> ArrowTableS def _make_source_stream( - rows: list[dict], tag_columns: list[str] = None, source_id: str = "src_a" + rows: list[dict], tag_columns: list[str] | None = None, source_id: str = "src_a" ) -> ArrowTableStream: """Create a stream from an ArrowTableSource so it has system tag columns.""" if tag_columns is None: diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index de9dc226..99ffd956 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1146,8 +1146,8 @@ class _MockExecutor(PacketFunctionExecutorBase): def __init__(self, opts: dict[str, Any] | None = None) -> None: self.opts: dict[str, Any] = opts or {} - self.sync_calls: list[PacketProtocol] = [] - self.async_calls: list[PacketProtocol] = [] + self.sync_calls: list[Any] = [] + self.async_calls: list[Any] = [] @property def executor_type_id(self) -> str: @@ -1204,7 +1204,7 @@ def test_engine_is_applied_to_all_function_nodes(self, pipeline_db): pipeline.run(execution_engine=mock) - assert pipeline.doubler.executor is mock + assert isinstance(pipeline.doubler.executor, _MockExecutor) def test_engine_without_config_triggers_async_mode(self, pipeline_db): """No config + execution_engine → async channels mode by default.""" @@ -1219,8 +1219,10 @@ def test_engine_without_config_triggers_async_mode(self, pipeline_db): pipeline.run(execution_engine=mock) - assert len(mock.async_calls) > 0 - assert len(mock.sync_calls) == 0 + # Each node gets its own copy; check the node's executor + node_executor = pipeline.doubler.executor + assert len(node_executor.async_calls) > 0 + assert len(node_executor.sync_calls) == 0 def test_explicit_sync_config_overrides_async_default(self, pipeline_db): """Explicit config=PipelineConfig(executor=SYNCHRONOUS) takes priority @@ -1241,8 +1243,9 @@ def test_explicit_sync_config_overrides_async_default(self, pipeline_db): config=PipelineConfig(executor=ExecutorType.SYNCHRONOUS), ) - assert len(mock.sync_calls) > 0 - assert len(mock.async_calls) == 0 + node_executor = pipeline.doubler.executor + assert len(node_executor.sync_calls) > 0 + assert len(node_executor.async_calls) == 0 def test_pipeline_opts_applied_via_with_options(self, pipeline_db): """Pipeline-level execution_engine_opts are applied via with_options.""" @@ -1263,8 +1266,8 @@ def test_pipeline_opts_applied_via_with_options(self, pipeline_db): # Executor should have been created with the pipeline opts assert pipeline.doubler.executor.opts.get("num_cpus") == 4 - def test_no_opts_uses_engine_directly(self, pipeline_db): - """Without execution_engine_opts, the engine itself is assigned (no with_options).""" + def test_no_opts_produces_per_node_copy(self, pipeline_db): + """Without execution_engine_opts, each node gets its own executor copy.""" src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) pf = PythonPacketFunction(double_value, output_keys="result") pod = FunctionPod(packet_function=pf) @@ -1276,8 +1279,9 @@ def test_no_opts_uses_engine_directly(self, pipeline_db): pipeline.run(execution_engine=mock) - # Without opts, the original mock executor is assigned directly - assert pipeline.doubler.executor is mock + # Each node gets a copy via with_options(), not the original + assert pipeline.doubler.executor is not mock + assert isinstance(pipeline.doubler.executor, _MockExecutor) assert pipeline.doubler.executor.opts == {}