From 40d28e5e6e15317b4318ee841a223998fe0e5fa2 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Mar 2026 23:02:45 +0000 Subject: [PATCH 1/7] feat(pipeline): add SourceCacheMode with OFF mode for source nodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add SourceCacheMode enum (FULL/OFF) to control whether source nodes persist their output to the cache database during pipeline execution. - FULL (default): current behavior — materializes tags + packets to DB - OFF: pass-through — source data flows directly to downstream nodes without any database interaction Pipeline accepts source_cache_mode parameter and passes it through compile() to PersistentSourceNode. When OFF, run() is a no-op, iter_packets()/as_table() delegate directly to the wrapped stream, and get_all_records() returns None. https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- src/orcapod/pipeline/graph.py | 5 +- src/orcapod/pipeline/nodes.py | 54 ++++++++++---- src/orcapod/types.py | 15 ++++ tests/test_pipeline/test_pipeline.py | 107 +++++++++++++++++++++++++++ 4 files changed, 167 insertions(+), 14 deletions(-) diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index c6538c01..286ff4b2 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -10,7 +10,7 @@ from orcapod.pipeline.nodes import PersistentSourceNode from orcapod.protocols import core_protocols as cp from orcapod.protocols import database_protocols as dbp -from orcapod.types import PipelineConfig +from orcapod.types import PipelineConfig, SourceCacheMode from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: @@ -87,12 +87,14 @@ def __init__( function_database: dbp.ArrowDatabaseProtocol | None = None, tracker_manager: cp.TrackerManagerProtocol | None = None, auto_compile: bool = True, + source_cache_mode: SourceCacheMode = SourceCacheMode.FULL, ) -> None: super().__init__(tracker_manager=tracker_manager) self._name = (name,) if isinstance(name, str) else tuple(name) self._pipeline_database = pipeline_database self._function_database = function_database self._pipeline_path_prefix = self._name + self._source_cache_mode = source_cache_mode self._nodes: dict[str, GraphNode] = {} self._persistent_node_map: dict[str, GraphNode] = {} self._node_graph: "nx.DiGraph | None" = None @@ -181,6 +183,7 @@ def compile(self) -> None: stream=stream, cache_database=self._pipeline_database, cache_path_prefix=self._pipeline_path_prefix, + source_cache_mode=self._source_cache_mode, ) persistent_node_map[node_hash] = persistent_node else: diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index a8e7d505..a46d5ae3 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -11,7 +11,7 @@ from orcapod.core.streams.arrow_table_stream import ArrowTableStream from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol from orcapod.protocols.database_protocols import ArrowDatabaseProtocol -from orcapod.types import ColumnConfig +from orcapod.types import ColumnConfig, SourceCacheMode from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: @@ -23,8 +23,7 @@ class PersistentSourceNode(SourceNode): - """ - DB-backed wrapper around any stream, used by ``Pipeline.compile()`` + """DB-backed wrapper around any stream, used by ``Pipeline.compile()`` to cache leaf stream data. Extends ``SourceNode`` (which delegates identity/schema to the wrapped @@ -37,6 +36,9 @@ class PersistentSourceNode(SourceNode): Cache path structure:: cache_path_prefix / source / node:{content_hash} + + When ``source_cache_mode`` is ``OFF``, the node delegates directly to + the wrapped stream without any database interaction. """ HASH_COLUMN_NAME = "_record_hash" @@ -49,6 +51,7 @@ def __init__( label: str | None = None, data_context: str | contexts.DataContext | None = None, config: Config | None = None, + source_cache_mode: SourceCacheMode = SourceCacheMode.FULL, ) -> None: super().__init__( stream=stream, @@ -58,8 +61,13 @@ def __init__( ) self._cache_database = cache_database self._cache_path_prefix = cache_path_prefix + self._source_cache_mode = source_cache_mode self._cached_stream: ArrowTableStream | None = None + @property + def source_cache_mode(self) -> SourceCacheMode: + return self._source_cache_mode + # ------------------------------------------------------------------------- # Cache path # ------------------------------------------------------------------------- @@ -77,8 +85,7 @@ def cache_path(self) -> tuple[str, ...]: # ------------------------------------------------------------------------- def _build_cached_stream(self) -> ArrowTableStream: - """ - Materialize the wrapped stream, store rows in the cache DB + """Materialize the wrapped stream, store rows in the cache DB (deduped by per-row hash), and return the cached table as an ``ArrowTableStream``. """ @@ -117,7 +124,9 @@ def _build_cached_stream(self) -> ArrowTableStream: return ArrowTableStream(all_records, tag_columns=tag_keys) def _ensure_stream(self) -> None: - """Build the cached stream on first access.""" + """Build the cached stream on first access (FULL mode only).""" + if self._source_cache_mode == SourceCacheMode.OFF: + return if self._cached_stream is None: self._cached_stream = self._build_cached_stream() self._update_modified_time() @@ -127,10 +136,15 @@ def _ensure_stream(self) -> None: # ------------------------------------------------------------------------- def run(self) -> None: - """Eagerly populate the cache with live stream data.""" + """Eagerly populate the cache with live stream data. + + In ``OFF`` mode this is a no-op — data flows through without caching. + """ self._ensure_stream() def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + if self._source_cache_mode == SourceCacheMode.OFF: + return self.stream.iter_packets() self._ensure_stream() assert self._cached_stream is not None return self._cached_stream.iter_packets() @@ -141,6 +155,8 @@ def as_table( columns: ColumnConfig | dict[str, Any] | None = None, all_info: bool = False, ) -> "pa.Table": + if self._source_cache_mode == SourceCacheMode.OFF: + return self.stream.as_table(columns=columns, all_info=all_info) self._ensure_stream() assert self._cached_stream is not None return self._cached_stream.as_table(columns=columns, all_info=all_info) @@ -150,15 +166,27 @@ async def async_execute( inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], ) -> None: - """Materialize to cache DB, then push cached rows to the output channel.""" + """Materialize to cache DB, then push cached rows to the output channel. + + In ``OFF`` mode, pushes live stream rows directly. + """ try: - self._ensure_stream() - assert self._cached_stream is not None - for tag, packet in self._cached_stream.iter_packets(): - await output.send((tag, packet)) + if self._source_cache_mode == SourceCacheMode.OFF: + for tag, packet in self.stream.iter_packets(): + await output.send((tag, packet)) + else: + self._ensure_stream() + assert self._cached_stream is not None + for tag, packet in self._cached_stream.iter_packets(): + await output.send((tag, packet)) finally: await output.close() def get_all_records(self) -> "pa.Table | None": - """Retrieve all stored records from the cache database.""" + """Retrieve all stored records from the cache database. + + Returns ``None`` in ``OFF`` mode (no records are stored). + """ + if self._source_cache_mode == SourceCacheMode.OFF: + return None return self._cache_database.get_all_records(self.cache_path) diff --git a/src/orcapod/types.py b/src/orcapod/types.py index 54627506..6d10e30a 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -339,6 +339,21 @@ class CacheMode(Enum): REPLAY = "replay" +class SourceCacheMode(Enum): + """Controls source node caching behaviour in a compiled pipeline. + + Attributes: + FULL: Materialize the entire stream (tags + packets + source info) + into the cache database. Supports deduplication and append-only + accumulation across runs. Default. + OFF: No cache writes. The source streams data directly to + downstream nodes without any database interaction. + """ + + FULL = "full" + OFF = "off" + + @dataclass(frozen=True, slots=True) class ColumnConfig: """ diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 38d1f4bb..da165113 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -24,6 +24,7 @@ from orcapod.databases import InMemoryArrowDatabase from orcapod.pipeline import PersistentSourceNode, Pipeline from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol +from orcapod.types import SourceCacheMode # --------------------------------------------------------------------------- # Helpers @@ -1235,3 +1236,109 @@ def test_per_node_opts_override_pipeline_opts(self, pipeline_db): # Node opts win: executor should have been created with num_cpus=2 assert pipeline.doubler.executor.opts.get("num_cpus") == 2 + + +# --------------------------------------------------------------------------- +# Tests: SourceCacheMode +# --------------------------------------------------------------------------- + + +class TestSourceCacheModeOff: + """Verify that source_cache_mode=OFF bypasses database interaction.""" + + def test_compile_passes_cache_mode_to_source_nodes(self, pipeline_db): + src_a, src_b = _make_two_sources() + pipeline = Pipeline( + name="test_pipe", + pipeline_database=pipeline_db, + source_cache_mode=SourceCacheMode.OFF, + ) + with pipeline: + _ = Join()(src_a, src_b) + + source_nodes = [ + n + for n in pipeline._node_graph.nodes() + if isinstance(n, PersistentSourceNode) + ] + assert len(source_nodes) == 2 + for sn in source_nodes: + assert sn.source_cache_mode == SourceCacheMode.OFF + + def test_off_mode_no_db_writes(self, pipeline_db): + """In OFF mode, run() should not write anything to the cache DB.""" + src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + pipeline = Pipeline( + name="test_pipe", + pipeline_database=pipeline_db, + source_cache_mode=SourceCacheMode.OFF, + ) + with pipeline: + MapPackets({"value": "val"})(src, label="mapper") + + pipeline.run() + + # Find the PersistentSourceNode + source_nodes = [ + n + for n in pipeline._node_graph.nodes() + if isinstance(n, PersistentSourceNode) + ] + assert len(source_nodes) == 1 + sn = source_nodes[0] + + # No records should be in the DB + assert sn.get_all_records() is None + + def test_off_mode_data_flows_through(self, pipeline_db): + """In OFF mode, downstream nodes should still receive data.""" + src_a, src_b = _make_two_sources() + pf = PythonPacketFunction(add_values, output_keys="total") + pod = FunctionPod(packet_function=pf) + + pipeline = Pipeline( + name="test_pipe", + pipeline_database=pipeline_db, + source_cache_mode=SourceCacheMode.OFF, + ) + with pipeline: + joined = Join()(src_a, src_b) + pod(joined, label="adder") + + pipeline.run() + + # Function node should have computed results + records = pipeline.adder.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_full_mode_writes_to_db(self, pipeline_db): + """Sanity check: FULL mode (default) does write to DB.""" + src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + pipeline = Pipeline( + name="test_pipe", + pipeline_database=pipeline_db, + source_cache_mode=SourceCacheMode.FULL, + ) + with pipeline: + MapPackets({"value": "val"})(src, label="mapper") + + pipeline.run() + + source_nodes = [ + n + for n in pipeline._node_graph.nodes() + if isinstance(n, PersistentSourceNode) + ] + assert len(source_nodes) == 1 + sn = source_nodes[0] + + # Records should be in the DB + records = sn.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_default_is_full(self, pipeline_db): + """Pipeline defaults to FULL source cache mode.""" + pipeline = Pipeline(name="test", pipeline_database=pipeline_db) + assert pipeline._source_cache_mode == SourceCacheMode.FULL From 5d5fa72c6c17a06701923ace0af10303a0e4c2c0 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Mar 2026 23:29:00 +0000 Subject: [PATCH 2/7] refactor(sources): decouple source caching from pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move source caching from pipeline-level (PersistentSourceNode) to source-level (CachedSource). Pipeline.compile() now wraps leaf streams in plain SourceNode (thin graph vertex wrapper) instead of PersistentSourceNode, eliminating double-caching when composing pipelines. Key changes: - Remove PersistentSourceNode and SourceCacheMode from pipeline - Pipeline.compile() wraps leaf streams in SourceNode (no caching) - Rename PersistentSource → CachedSource for clarity - Add RootSource.cached() convenience method - Update all tests and demo Source caching is now a source-level concern: cached = source.cached(cache_database=db) # or cached = CachedSource(source, cache_database=db) https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- demo_pipeline.py | 17 +- src/orcapod/core/sources/__init__.py | 4 +- src/orcapod/core/sources/base.py | 29 +++ ...{persistent_source.py => cached_source.py} | 34 +-- src/orcapod/pipeline/__init__.py | 2 - src/orcapod/pipeline/graph.py | 26 +-- src/orcapod/pipeline/nodes.py | 195 +----------------- src/orcapod/types.py | 14 -- ...istent_source.py => test_cached_source.py} | 120 ++++++----- tests/test_core/test_caching_integration.py | 68 +++--- tests/test_pipeline/test_pipeline.py | 184 +++-------------- 11 files changed, 208 insertions(+), 485 deletions(-) rename src/orcapod/core/sources/{persistent_source.py => cached_source.py} (88%) rename tests/test_core/sources/{test_persistent_source.py => test_cached_source.py} (76%) diff --git a/demo_pipeline.py b/demo_pipeline.py index c365530b..9b16f22d 100644 --- a/demo_pipeline.py +++ b/demo_pipeline.py @@ -23,7 +23,8 @@ from orcapod.core.packet_function import PythonPacketFunction from orcapod.core.sources import ArrowTableSource from orcapod.databases import DeltaTableDatabase, InMemoryArrowDatabase -from orcapod.pipeline import Pipeline, PersistentSourceNode +from orcapod.core.nodes import SourceNode +from orcapod.pipeline import Pipeline # --------------------------------------------------------------------------- @@ -96,10 +97,10 @@ def categorize(risk: float) -> str: for name, node in pipeline.compiled_nodes.items(): print(f" {name}: {type(node).__name__}") -print("\n Source nodes (PersistentSourceNode):") +print("\n Source nodes (SourceNode):") for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - print(f" cache_path: {n.cache_path}") + if isinstance(n, SourceNode): + print(f" label: {n.label}, stream: {type(n.stream).__name__}") # --- Access nodes by label --- print(f"\n pipeline.join_data -> {type(pipeline.join_data).__name__}") @@ -127,12 +128,6 @@ def categorize(risk: float) -> str: print(f" {cat_table.to_pandas()[['patient_id', 'category']].to_string(index=False)}") # --- Show what's in the database --- -print(f"\n Source records in DB:") -for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - records = n.get_all_records() - print(f" {n.cache_path[-1]}: {records.num_rows} rows") - fn_records = pipeline.compute_risk.get_all_records() print(f" Function records (compute_risk): {fn_records.num_rows} rows") @@ -320,7 +315,7 @@ def categorize(risk: float) -> str: print("=" * 70) print(""" Pipeline wraps ALL nodes as persistent variants automatically: - - Leaf streams -> PersistentSourceNode (DB-backed cache) + - Leaf streams -> SourceNode (graph vertex wrapper, no caching) - Operator calls -> PersistentOperatorNode (DB-backed cache) - Function pod calls -> PersistentFunctionNode (DB-backed cache) diff --git a/src/orcapod/core/sources/__init__.py b/src/orcapod/core/sources/__init__.py index 0cb810e0..ccb0e782 100644 --- a/src/orcapod/core/sources/__init__.py +++ b/src/orcapod/core/sources/__init__.py @@ -1,24 +1,24 @@ from .base import RootSource from .arrow_table_source import ArrowTableSource +from .cached_source import CachedSource from .csv_source import CSVSource from .data_frame_source import DataFrameSource from .delta_table_source import DeltaTableSource from .derived_source import DerivedSource from .dict_source import DictSource from .list_source import ListSource -from .persistent_source import PersistentSource from .source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry __all__ = [ "RootSource", "ArrowTableSource", + "CachedSource", "CSVSource", "DataFrameSource", "DeltaTableSource", "DerivedSource", "DictSource", "ListSource", - "PersistentSource", "SourceRegistry", "GLOBAL_SOURCE_REGISTRY", ] diff --git a/src/orcapod/core/sources/base.py b/src/orcapod/core/sources/base.py index d1f213a2..60689a9e 100644 --- a/src/orcapod/core/sources/base.py +++ b/src/orcapod/core/sources/base.py @@ -137,3 +137,32 @@ def producer(self) -> None: def upstreams(self) -> tuple[StreamProtocol, ...]: """Sources have no upstream dependencies.""" return () + + # ------------------------------------------------------------------------- + # Convenience — caching + # ------------------------------------------------------------------------- + + def cached( + self, + cache_database: Any, + cache_path_prefix: tuple[str, ...] = (), + **kwargs: Any, + ) -> "RootSource": + """Return a ``CachedSource`` wrapping this source. + + Args: + cache_database: Database to store cached records in. + cache_path_prefix: Path prefix for the cache table. + **kwargs: Additional keyword arguments passed to ``CachedSource``. + + Returns: + A ``CachedSource`` that caches this source's output. + """ + from orcapod.core.sources.cached_source import CachedSource + + return CachedSource( + source=self, + cache_database=cache_database, + cache_path_prefix=cache_path_prefix, + **kwargs, + ) diff --git a/src/orcapod/core/sources/persistent_source.py b/src/orcapod/core/sources/cached_source.py similarity index 88% rename from src/orcapod/core/sources/persistent_source.py rename to src/orcapod/core/sources/cached_source.py index e53d3218..1e635ed1 100644 --- a/src/orcapod/core/sources/persistent_source.py +++ b/src/orcapod/core/sources/cached_source.py @@ -21,25 +21,29 @@ logger = logging.getLogger(__name__) -class PersistentSource(RootSource): - """ - DB-backed wrapper around a RootSource that caches every packet. +class CachedSource(RootSource): + """DB-backed wrapper around a ``RootSource`` that caches every packet. - Implements StreamProtocol transparently so downstream consumers + Implements ``StreamProtocol`` transparently so downstream consumers are unaware of caching. Cache table is scoped to the source's ``content_hash()`` — each unique source gets its own table. - Behavior - -------- - - Cache is **always on**. - - On first access, live source data is stored in the cache table - (deduped by per-row content hash). - - Returns the union of all cached data (cumulative across runs). - - Semantic guarantee - ------------------ - The cache is a correct cumulative record. The union of cache + live - packets is the full set of data ever available from that source. + Behavior: + - Cache is **always on**. + - On first access, live source data is stored in the cache table + (deduped by per-row content hash). + - Returns the union of all cached data (cumulative across runs). + + Semantic guarantee: + The cache is a correct cumulative record. The union of cache + live + packets is the full set of data ever available from that source. + + Example:: + + source = ArrowTableSource(table, tag_columns=["id"]) + cached = CachedSource(source, cache_database=db) + # or equivalently: + cached = source.cached(cache_database=db) """ HASH_COLUMN_NAME = "_record_hash" diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 7495be39..278acf33 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -1,9 +1,7 @@ from .graph import Pipeline -from .nodes import PersistentSourceNode from .orchestrator import AsyncPipelineOrchestrator __all__ = [ "AsyncPipelineOrchestrator", "Pipeline", - "PersistentSourceNode", ] diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index 286ff4b2..feda71d9 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -5,12 +5,11 @@ import tempfile from typing import TYPE_CHECKING, Any -from orcapod.core.nodes import GraphNode +from orcapod.core.nodes import GraphNode, SourceNode from orcapod.core.tracker import GraphTracker -from orcapod.pipeline.nodes import PersistentSourceNode from orcapod.protocols import core_protocols as cp from orcapod.protocols import database_protocols as dbp -from orcapod.types import PipelineConfig, SourceCacheMode +from orcapod.types import PipelineConfig from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: @@ -57,10 +56,14 @@ class Pipeline(GraphTracker): recorded as non-persistent nodes (same as ``GraphTracker``). On context exit, ``compile()`` replaces every node with its persistent variant: - - Leaf streams → ``PersistentSourceNode`` + - Leaf streams → ``SourceNode`` (thin wrapper for graph vertex) - Function pod invocations → ``PersistentFunctionNode`` - Operator invocations → ``PersistentOperatorNode`` + Source caching is not a pipeline concern — sources that need caching + should be wrapped in a ``CachedSource`` before being used in the + pipeline. + All persistent nodes share the same ``pipeline_database`` and use ``pipeline_name`` as path prefix, scoping their cache tables. @@ -87,14 +90,12 @@ def __init__( function_database: dbp.ArrowDatabaseProtocol | None = None, tracker_manager: cp.TrackerManagerProtocol | None = None, auto_compile: bool = True, - source_cache_mode: SourceCacheMode = SourceCacheMode.FULL, ) -> None: super().__init__(tracker_manager=tracker_manager) self._name = (name,) if isinstance(name, str) else tuple(name) self._pipeline_database = pipeline_database self._function_database = function_database self._pipeline_path_prefix = self._name - self._source_cache_mode = source_cache_mode self._nodes: dict[str, GraphNode] = {} self._persistent_node_map: dict[str, GraphNode] = {} self._node_graph: "nx.DiGraph | None" = None @@ -141,7 +142,7 @@ def compile(self) -> None: Walks the graph in topological order and creates: - - ``PersistentSourceNode`` for every leaf stream + - ``SourceNode`` for every leaf stream - ``PersistentFunctionNode`` for every function pod invocation - ``PersistentOperatorNode`` for every operator invocation @@ -177,14 +178,9 @@ def compile(self) -> None: continue if node_hash not in self._node_lut: - # -- Leaf stream: wrap in PersistentSourceNode -- + # -- Leaf stream: wrap in SourceNode -- stream = self._upstreams[node_hash] - persistent_node = PersistentSourceNode( - stream=stream, - cache_database=self._pipeline_database, - cache_path_prefix=self._pipeline_path_prefix, - source_cache_mode=self._source_cache_mode, - ) + persistent_node = SourceNode(stream=stream) persistent_node_map[node_hash] = persistent_node else: node = self._node_lut[node_hash] @@ -263,7 +259,7 @@ def compile(self) -> None: continue attrs = self._hash_graph.nodes[node_hash] if not attrs.get("node_type"): - if isinstance(node, PersistentSourceNode): + if isinstance(node, SourceNode): attrs["node_type"] = "source" elif isinstance(node, PersistentFunctionNode): attrs["node_type"] = "function" diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index a46d5ae3..5abc288c 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -1,192 +1,3 @@ -from __future__ import annotations - -import logging -from collections.abc import Iterator, Sequence -from typing import TYPE_CHECKING, Any - -from orcapod import contexts -from orcapod.channels import ReadableChannel, WritableChannel -from orcapod.config import Config -from orcapod.core.nodes import SourceNode -from orcapod.core.streams.arrow_table_stream import ArrowTableStream -from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol -from orcapod.protocols.database_protocols import ArrowDatabaseProtocol -from orcapod.types import ColumnConfig, SourceCacheMode -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -logger = logging.getLogger(__name__) - - -class PersistentSourceNode(SourceNode): - """DB-backed wrapper around any stream, used by ``Pipeline.compile()`` - to cache leaf stream data. - - Extends ``SourceNode`` (which delegates identity/schema to the wrapped - stream) and adds: - - - Materialization of the stream's output into a cache database - - Per-row deduplication via content hash - - Cached ``ArrowTableStream`` for downstream consumption - - Cache path structure:: - - cache_path_prefix / source / node:{content_hash} - - When ``source_cache_mode`` is ``OFF``, the node delegates directly to - the wrapped stream without any database interaction. - """ - - HASH_COLUMN_NAME = "_record_hash" - - def __init__( - self, - stream: StreamProtocol, - cache_database: ArrowDatabaseProtocol, - cache_path_prefix: tuple[str, ...] = (), - label: str | None = None, - data_context: str | contexts.DataContext | None = None, - config: Config | None = None, - source_cache_mode: SourceCacheMode = SourceCacheMode.FULL, - ) -> None: - super().__init__( - stream=stream, - label=label, - data_context=data_context, - config=config, - ) - self._cache_database = cache_database - self._cache_path_prefix = cache_path_prefix - self._source_cache_mode = source_cache_mode - self._cached_stream: ArrowTableStream | None = None - - @property - def source_cache_mode(self) -> SourceCacheMode: - return self._source_cache_mode - - # ------------------------------------------------------------------------- - # Cache path - # ------------------------------------------------------------------------- - - @property - def cache_path(self) -> tuple[str, ...]: - """Cache table path, scoped to the wrapped stream's content hash.""" - return self._cache_path_prefix + ( - "source", - f"node:{self.stream.content_hash().to_string()}", - ) - - # ------------------------------------------------------------------------- - # Caching logic - # ------------------------------------------------------------------------- - - def _build_cached_stream(self) -> ArrowTableStream: - """Materialize the wrapped stream, store rows in the cache DB - (deduped by per-row hash), and return the cached table as an - ``ArrowTableStream``. - """ - live_table = self.stream.as_table(columns={"source": True, "system_tags": True}) - - # Per-row content hashes for dedup - arrow_hasher = self.data_context.arrow_hasher - record_hashes: list[str] = [] - for batch in live_table.to_batches(): - for i in range(len(batch)): - record_hashes.append( - arrow_hasher.hash_table(batch.slice(i, 1)).to_hex() - ) - - live_with_hash = live_table.add_column( - 0, - self.HASH_COLUMN_NAME, - pa.array(record_hashes, type=pa.large_string()), - ) - - self._cache_database.add_records( - self.cache_path, - live_with_hash, - record_id_column=self.HASH_COLUMN_NAME, - skip_duplicates=True, - ) - self._cache_database.flush() - - # Load all cached records (union of current + prior runs) - all_records = self._cache_database.get_all_records(self.cache_path) - assert all_records is not None, ( - "Cache should contain records after storing live data." - ) - - tag_keys = self.stream.keys()[0] - return ArrowTableStream(all_records, tag_columns=tag_keys) - - def _ensure_stream(self) -> None: - """Build the cached stream on first access (FULL mode only).""" - if self._source_cache_mode == SourceCacheMode.OFF: - return - if self._cached_stream is None: - self._cached_stream = self._build_cached_stream() - self._update_modified_time() - - # ------------------------------------------------------------------------- - # Stream interface overrides - # ------------------------------------------------------------------------- - - def run(self) -> None: - """Eagerly populate the cache with live stream data. - - In ``OFF`` mode this is a no-op — data flows through without caching. - """ - self._ensure_stream() - - def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: - if self._source_cache_mode == SourceCacheMode.OFF: - return self.stream.iter_packets() - self._ensure_stream() - assert self._cached_stream is not None - return self._cached_stream.iter_packets() - - def as_table( - self, - *, - columns: ColumnConfig | dict[str, Any] | None = None, - all_info: bool = False, - ) -> "pa.Table": - if self._source_cache_mode == SourceCacheMode.OFF: - return self.stream.as_table(columns=columns, all_info=all_info) - self._ensure_stream() - assert self._cached_stream is not None - return self._cached_stream.as_table(columns=columns, all_info=all_info) - - async def async_execute( - self, - inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], - output: WritableChannel[tuple[TagProtocol, PacketProtocol]], - ) -> None: - """Materialize to cache DB, then push cached rows to the output channel. - - In ``OFF`` mode, pushes live stream rows directly. - """ - try: - if self._source_cache_mode == SourceCacheMode.OFF: - for tag, packet in self.stream.iter_packets(): - await output.send((tag, packet)) - else: - self._ensure_stream() - assert self._cached_stream is not None - for tag, packet in self._cached_stream.iter_packets(): - await output.send((tag, packet)) - finally: - await output.close() - - def get_all_records(self) -> "pa.Table | None": - """Retrieve all stored records from the cache database. - - Returns ``None`` in ``OFF`` mode (no records are stored). - """ - if self._source_cache_mode == SourceCacheMode.OFF: - return None - return self._cache_database.get_all_records(self.cache_path) +# Pipeline-specific node types are defined here. +# PersistentFunctionNode and PersistentOperatorNode live in core/nodes/. +# SourceNode (thin graph vertex wrapper) lives in core/nodes/source_node.py. diff --git a/src/orcapod/types.py b/src/orcapod/types.py index 6d10e30a..82d066c3 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -339,20 +339,6 @@ class CacheMode(Enum): REPLAY = "replay" -class SourceCacheMode(Enum): - """Controls source node caching behaviour in a compiled pipeline. - - Attributes: - FULL: Materialize the entire stream (tags + packets + source info) - into the cache database. Supports deduplication and append-only - accumulation across runs. Default. - OFF: No cache writes. The source streams data directly to - downstream nodes without any database interaction. - """ - - FULL = "full" - OFF = "off" - @dataclass(frozen=True, slots=True) class ColumnConfig: diff --git a/tests/test_core/sources/test_persistent_source.py b/tests/test_core/sources/test_cached_source.py similarity index 76% rename from tests/test_core/sources/test_persistent_source.py rename to tests/test_core/sources/test_cached_source.py index af0f22dc..ae2d2dd2 100644 --- a/tests/test_core/sources/test_persistent_source.py +++ b/tests/test_core/sources/test_cached_source.py @@ -1,5 +1,5 @@ """ -Tests for PersistentSource covering: +Tests for CachedSource covering: - Construction and transparent StreamProtocol implementation - Cache path scoped to source's content_hash - Cumulative caching: data from prior runs is preserved @@ -17,7 +17,7 @@ import pyarrow as pa import pytest -from orcapod.core.sources import ArrowTableSource, PersistentSource +from orcapod.core.sources import ArrowTableSource, CachedSource from orcapod.core.streams import ArrowTableStream from orcapod.databases import InMemoryArrowDatabase from orcapod.protocols.core_protocols import StreamProtocol @@ -55,29 +55,29 @@ def db(): # --------------------------------------------------------------------------- -class TestPersistentSourceConstruction: +class TestCachedSourceConstruction: def test_source_id_delegated(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.source_id == simple_source.source_id def test_stream_protocol_conformance(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert isinstance(ps, StreamProtocol) def test_pipeline_element_conformance(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert isinstance(ps, PipelineElementProtocol) def test_identity_delegated(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.identity_structure() == simple_source.identity_structure() def test_content_hash_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.content_hash() == simple_source.content_hash() def test_pipeline_hash_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.pipeline_hash() == simple_source.pipeline_hash() @@ -86,16 +86,16 @@ def test_pipeline_hash_matches_source(self, simple_source, db): # --------------------------------------------------------------------------- -class TestPersistentSourceCachePath: +class TestCachedSourceCachePath: def test_cache_path_contains_content_hash(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) path = ps.cache_path content_hash_str = simple_source.content_hash().to_string() assert any(content_hash_str in segment for segment in path) def test_cache_path_prefix(self, simple_source, db): prefix = ("my_project", "v1") - ps = PersistentSource( + ps = CachedSource( simple_source, cache_database=db, cache_path_prefix=prefix ) assert ps.cache_path[:2] == prefix @@ -104,8 +104,8 @@ def test_same_source_same_cache_path(self, simple_table, db): """Identical sources produce the same cache path.""" s1 = ArrowTableSource(simple_table, tag_columns=["name"], source_id="src") s2 = ArrowTableSource(simple_table, tag_columns=["name"], source_id="src") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path == ps2.cache_path def test_same_name_same_schema_same_cache_path(self, db): @@ -114,8 +114,8 @@ def test_same_name_same_schema_same_cache_path(self, db): t2 = pa.table({"k": ["b"], "v": [2]}) s1 = ArrowTableSource(t1, tag_columns=["k"], source_id="s") s2 = ArrowTableSource(t2, tag_columns=["k"], source_id="s") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path == ps2.cache_path def test_different_name_different_cache_path(self, db): @@ -123,8 +123,8 @@ def test_different_name_different_cache_path(self, db): t1 = pa.table({"k": ["a"], "v": [1]}) s1 = ArrowTableSource(t1, tag_columns=["k"], source_id="src_a") s2 = ArrowTableSource(t1, tag_columns=["k"], source_id="src_b") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path != ps2.cache_path def test_unnamed_different_data_different_cache_path(self, db): @@ -133,8 +133,8 @@ def test_unnamed_different_data_different_cache_path(self, db): t2 = pa.table({"k": ["b"], "v": [2]}) s1 = ArrowTableSource(t1, tag_columns=["k"]) s2 = ArrowTableSource(t2, tag_columns=["k"]) - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path != ps2.cache_path @@ -143,19 +143,19 @@ def test_unnamed_different_data_different_cache_path(self, db): # --------------------------------------------------------------------------- -class TestPersistentSourceSchema: +class TestCachedSourceSchema: def test_output_schema_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.output_schema() == simple_source.output_schema() def test_output_schema_with_system_tags(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.output_schema( columns={"system_tags": True} ) == simple_source.output_schema(columns={"system_tags": True}) def test_keys_match_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.keys() == simple_source.keys() @@ -164,28 +164,28 @@ def test_keys_match_source(self, simple_source, db): # --------------------------------------------------------------------------- -class TestPersistentSourceStreaming: +class TestCachedSourceStreaming: def test_as_table_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) ps_table = ps.as_table() src_table = simple_source.as_table() assert ps_table.num_rows == src_table.num_rows assert set(ps_table.column_names) == set(src_table.column_names) def test_iter_packets_count(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) packets = list(ps.iter_packets()) assert len(packets) == 3 def test_iter_packets_tags_and_packets(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) for tag, packet in ps.iter_packets(): assert "name" in tag.keys() assert "age" in packet.keys() def test_system_tags_preserved(self, simple_source, db): """System tags flow through the cache correctly.""" - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) table = ps.as_table(columns={"system_tags": True}) sys_tag_cols = [ c for c in table.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX) @@ -206,7 +206,7 @@ def test_system_tags_preserved(self, simple_source, db): def test_source_info_preserved(self, simple_source, db): """Source info columns flow through the cache correctly.""" - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) table = ps.as_table(columns={"source": True}) source_cols = [ c for c in table.column_names if c.startswith(constants.SOURCE_PREFIX) @@ -219,19 +219,19 @@ def test_source_info_preserved(self, simple_source, db): # --------------------------------------------------------------------------- -class TestPersistentSourceCumulative: +class TestCachedSourceCumulative: def test_dedup_on_same_data(self, simple_source, db): """Running twice with the same data produces no duplicates.""" - ps1 = PersistentSource(simple_source, cache_database=db) + ps1 = CachedSource(simple_source, cache_database=db) ps1.run() - ps2 = PersistentSource(simple_source, cache_database=db) + ps2 = CachedSource(simple_source, cache_database=db) ps2.run() table = ps2.as_table() assert table.num_rows == 3 # no duplicates def test_clear_cache_rebuilds(self, simple_source, db): """clear_cache forces a fresh merge from DB on next access.""" - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) t1 = ps.as_table() ps.clear_cache() t2 = ps.as_table() @@ -248,18 +248,18 @@ def test_cumulative_across_runs(self, db): # Different data → different content_hash → different cache_paths # So cumulative within the SAME cache_path requires same content_hash - ps1 = PersistentSource(s1, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) ps1.run() assert ps1.as_table().num_rows == 2 # Same data source: should dedup s1_again = ArrowTableSource(t1, tag_columns=["k"], source_id="shared") - ps1_again = PersistentSource(s1_again, cache_database=db) + ps1_again = CachedSource(s1_again, cache_database=db) ps1_again.run() assert ps1_again.as_table().num_rows == 2 # Different source (s2) has different cache_path - ps2 = PersistentSource(s2, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) ps2.run() assert ps2.as_table().num_rows == 3 @@ -269,9 +269,9 @@ def test_cumulative_across_runs(self, db): # --------------------------------------------------------------------------- -class TestPersistentSourceFieldResolution: +class TestCachedSourceFieldResolution: def test_resolve_field_delegates(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) value = ps.resolve_field("row_0", "age") expected = simple_source.resolve_field("row_0", "age") assert value == expected @@ -286,7 +286,7 @@ def test_resolve_field_with_record_id_column(self, db): source = ArrowTableSource( table, tag_columns=["user_id"], record_id_column="user_id", source_id="test" ) - ps = PersistentSource(source, cache_database=db) + ps = CachedSource(source, cache_database=db) assert ps.resolve_field("user_id=u1", "score") == 100 @@ -295,9 +295,9 @@ def test_resolve_field_with_record_id_column(self, db): # --------------------------------------------------------------------------- -class TestPersistentSourceIntegration: - def test_join_with_persistent_source(self, db): - """PersistentSource can be joined with another stream.""" +class TestCachedSourceIntegration: + def test_join_with_cached_source(self, db): + """CachedSource can be joined with another stream.""" from orcapod.core.operators import Join t1 = pa.table({"id": [1, 2, 3], "val_a": [10, 20, 30]}) @@ -305,8 +305,8 @@ def test_join_with_persistent_source(self, db): s1 = ArrowTableSource(t1, tag_columns=["id"], source_id="a") s2 = ArrowTableSource(t2, tag_columns=["id"], source_id="b") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) joined = Join()(ps1, ps2) table = joined.as_table() @@ -314,8 +314,8 @@ def test_join_with_persistent_source(self, db): assert "val_a" in table.column_names assert "val_b" in table.column_names - def test_function_pod_with_persistent_source(self, db): - """PersistentSource works as input to a FunctionPod.""" + def test_function_pod_with_cached_source(self, db): + """CachedSource works as input to a FunctionPod.""" from orcapod.core.function_pod import FunctionPod from orcapod.core.packet_function import PythonPacketFunction @@ -327,10 +327,34 @@ def double_age(age: int) -> int: table = pa.table({"name": ["Alice", "Bob"], "age": [30, 25]}) source = ArrowTableSource(table, tag_columns=["name"], source_id="test") - ps = PersistentSource(source, cache_database=db) + ps = CachedSource(source, cache_database=db) result = pod(ps) packets = list(result.iter_packets()) assert len(packets) == 2 ages = [p.as_dict()["doubled_age"] for _, p in packets] assert sorted(ages) == [50, 60] + + +class TestCachedConvenienceMethod: + """Test the ``RootSource.cached()`` convenience method.""" + + def test_cached_returns_cached_source(self, simple_source, db): + cached = simple_source.cached(cache_database=db) + assert isinstance(cached, CachedSource) + + def test_cached_with_path_prefix(self, simple_source, db): + cached = simple_source.cached( + cache_database=db, + cache_path_prefix=("my", "prefix"), + ) + assert cached.cache_path[:2] == ("my", "prefix") + + def test_cached_data_matches_source(self, simple_source, db): + cached = simple_source.cached(cache_database=db) + original_table = simple_source.as_table() + cached_table = cached.as_table() + + # Same column names and row count + assert set(original_table.column_names) == set(cached_table.column_names) + assert original_table.num_rows == cached_table.num_rows diff --git a/tests/test_core/test_caching_integration.py b/tests/test_core/test_caching_integration.py index 0b103b60..343e37bc 100644 --- a/tests/test_core/test_caching_integration.py +++ b/tests/test_core/test_caching_integration.py @@ -2,7 +2,7 @@ Integration tests: all three pod caching strategies working end-to-end. Covers: -1. PersistentSource — always-on cache scoped to content_hash() +1. CachedSource — always-on cache scoped to content_hash() - DeltaTableSource with canonical source_id (defaults to dir name) - Named sources: same name + same schema = same identity (data-independent) - Unnamed sources: identity determined by table hash (data-dependent) @@ -24,7 +24,7 @@ from orcapod.core.nodes import PersistentFunctionNode, PersistentOperatorNode from orcapod.core.operators import Join from orcapod.core.packet_function import PythonPacketFunction -from orcapod.core.sources import ArrowTableSource, DeltaTableSource, PersistentSource +from orcapod.core.sources import ArrowTableSource, DeltaTableSource, CachedSource from orcapod.databases import InMemoryArrowDatabase from orcapod.types import CacheMode @@ -121,7 +121,7 @@ def pod(): # --------------------------------------------------------------------------- -# 1. PersistentSource — source pod caching +# 1. CachedSource — source pod caching # --------------------------------------------------------------------------- @@ -135,11 +135,11 @@ def test_delta_source_id_defaults_to_dir_name(self, clinic_a): def test_different_sources_get_different_cache_paths(self, clinic_a, source_db): patients_path, labs_path = clinic_a - patients = PersistentSource( + patients = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -147,7 +147,7 @@ def test_different_sources_get_different_cache_paths(self, clinic_a, source_db): def test_cache_populates_on_run(self, clinic_a, source_db): patients_path, _ = clinic_a - ps = PersistentSource( + ps = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -158,12 +158,12 @@ def test_cache_populates_on_run(self, clinic_a, source_db): def test_dedup_on_rerun(self, clinic_a, source_db): patients_path, _ = clinic_a - ps1 = PersistentSource( + ps1 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) ps1.run() - ps2 = PersistentSource( + ps2 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -176,7 +176,7 @@ def test_named_source_same_name_same_schema_same_identity( """Same dir name + same schema = same content_hash regardless of data.""" patients_path, _ = clinic_a src1 = DeltaTableSource(patients_path, tag_columns=["patient_id"]) - ps1 = PersistentSource(src1, cache_database=source_db) + ps1 = CachedSource(src1, cache_database=source_db) # Overwrite with different data, same schema _write_delta( @@ -192,7 +192,7 @@ def test_named_source_same_name_same_schema_same_identity( mode="overwrite", ) src2 = DeltaTableSource(patients_path, tag_columns=["patient_id"]) - ps2 = PersistentSource(src2, cache_database=source_db) + ps2 = CachedSource(src2, cache_database=source_db) assert src1.source_id == src2.source_id assert ps1.content_hash() == ps2.content_hash() @@ -201,7 +201,7 @@ def test_named_source_same_name_same_schema_same_identity( def test_cumulative_caching_across_data_updates(self, clinic_a, source_db): """New rows from updated data accumulate in the same cache table.""" patients_path, _ = clinic_a - ps1 = PersistentSource( + ps1 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -221,7 +221,7 @@ def test_cumulative_caching_across_data_updates(self, clinic_a, source_db): ), mode="overwrite", ) - ps2 = PersistentSource( + ps2 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -270,11 +270,11 @@ def test_function_node_stores_records( self, clinic_a, source_db, pipeline_db, result_db, pod ): patients_path, labs_path = clinic_a - patients = PersistentSource( + patients = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -300,11 +300,11 @@ def test_cross_source_sharing_same_pipeline_path( patients_b, labs_b = clinic_b # Pipeline A - pa_src = PersistentSource( + pa_src = CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ) - la_src = PersistentSource( + la_src = CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -316,11 +316,11 @@ def test_cross_source_sharing_same_pipeline_path( ) # Pipeline B - pb_src = PersistentSource( + pb_src = CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ) - lb_src = PersistentSource( + lb_src = CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -344,11 +344,11 @@ def test_cross_source_records_accumulate_in_shared_table( fn_a = PersistentFunctionNode( function_pod=pod, input_stream=Join()( - PersistentSource( + CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ), - PersistentSource( + CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ), @@ -363,11 +363,11 @@ def test_cross_source_records_accumulate_in_shared_table( fn_b = PersistentFunctionNode( function_pod=pod, input_stream=Join()( - PersistentSource( + CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ), - PersistentSource( + CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ), @@ -388,11 +388,11 @@ def test_cross_source_records_accumulate_in_shared_table( class TestOperatorPodCaching: def _make_joined_streams(self, clinic_a, source_db): patients_path, labs_path = clinic_a - patients = PersistentSource( + patients = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -469,19 +469,19 @@ def test_content_hash_scoping_isolates_source_combinations( patients_a, labs_a = clinic_a patients_b, labs_b = clinic_b - pa_src = PersistentSource( + pa_src = CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ) - la_src = PersistentSource( + la_src = CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ) - pb_src = PersistentSource( + pb_src = CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ) - lb_src = PersistentSource( + lb_src = CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -511,17 +511,17 @@ def test_full_pipeline_source_to_function_to_operator( self, clinic_a, clinic_b, source_db, pipeline_db, result_db, operator_db, pod ): """ - Full pipeline: DeltaTableSource → PersistentSource → Join → + Full pipeline: DeltaTableSource → CachedSource → Join → PersistentFunctionNode → PersistentOperatorNode (LOG + REPLAY). """ patients_a, labs_a = clinic_a - # Step 1: PersistentSource - patients = PersistentSource( + # Step 1: CachedSource + patients = CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -563,11 +563,11 @@ def test_full_pipeline_source_to_function_to_operator( fn_node_b = PersistentFunctionNode( function_pod=pod, input_stream=Join()( - PersistentSource( + CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ), - PersistentSource( + CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ), diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index da165113..54a40444 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -2,8 +2,8 @@ Tests for the Pipeline class. Verifies that Pipeline (a GraphTracker subclass) correctly wraps all nodes -as persistent variants during compile(): -- Leaf streams → PersistentSourceNode +during compile(): +- Leaf streams → SourceNode - Function pod invocations → PersistentFunctionNode - Operator invocations → PersistentOperatorNode """ @@ -17,14 +17,17 @@ from orcapod.core.executors import PacketFunctionExecutorBase from orcapod.core.function_pod import FunctionPod -from orcapod.core.nodes import PersistentFunctionNode, PersistentOperatorNode +from orcapod.core.nodes import ( + PersistentFunctionNode, + PersistentOperatorNode, + SourceNode, +) from orcapod.core.operators import Join, MapPackets from orcapod.core.packet_function import PythonPacketFunction from orcapod.core.sources import ArrowTableSource from orcapod.databases import InMemoryArrowDatabase -from orcapod.pipeline import PersistentSourceNode, Pipeline +from orcapod.pipeline import Pipeline from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol -from orcapod.types import SourceCacheMode # --------------------------------------------------------------------------- # Helpers @@ -71,7 +74,7 @@ def function_db(): # --------------------------------------------------------------------------- -# Tests: compile wraps leaf streams as PersistentSourceNode +# Tests: compile wraps leaf streams as SourceNode # --------------------------------------------------------------------------- @@ -89,37 +92,14 @@ def test_compile_wraps_leaf_streams_as_persistent_source_node(self, pipeline_db) assert len(pipeline.compiled_nodes) > 0 assert pipeline._node_graph is not None - # The node graph should contain PersistentSourceNode instances + # The node graph should contain SourceNode instances source_nodes = [ n for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) ] assert len(source_nodes) == 2 - def test_persistent_source_node_cache_path_prefix(self, pipeline_db): - src_a, _ = _make_two_sources() - pipeline = Pipeline(name="my_pipeline", pipeline_database=pipeline_db) - - with pipeline: - # Use a simple unary operator to trigger a recording - MapPackets({"value": "val"})(src_a, label="mapper") - - # Find the PersistentSourceNode - assert pipeline._node_graph is not None - source_nodes = [ - n - for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) - ] - assert len(source_nodes) == 1 - sn = source_nodes[0] - - # cache_path should start with pipeline name prefix - assert sn.cache_path[:1] == ("my_pipeline",) - assert sn.cache_path[1] == "source" - assert sn.cache_path[2].startswith("node:") - # --------------------------------------------------------------------------- # Tests: compile creates PersistentFunctionNode @@ -361,11 +341,6 @@ def test_pipeline_path_prefix_scoping(self, pipeline_db): # Check function node adder = pipeline.adder assert adder.pipeline_path[0] == "scoped" - assert pipeline._node_graph is not None - # Check source nodes - for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - assert n.cache_path[0] == "scoped" # --------------------------------------------------------------------------- @@ -412,14 +387,6 @@ def test_end_to_end_source_join_function(self, pipeline_db): # Run the pipeline pipeline.run() - assert pipeline._node_graph is not None - # Source nodes should have cached data - for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - records = n.get_all_records() - assert records is not None - assert records.num_rows == 2 - # Function node should have results fn_records = pipeline.adder.get_all_records() assert fn_records is not None @@ -549,10 +516,10 @@ def test_second_pipeline_from_first_pipeline_node(self, pipeline_db): 220, ] - # pipe_b's source nodes wrap pipe_a.adder as a PersistentSourceNode + # pipe_b's source nodes wrap pipe_a.adder as a SourceNode assert pipe_b._node_graph is not None source_nodes = [ - n for n in pipe_b._node_graph.nodes() if isinstance(n, PersistentSourceNode) + n for n in pipe_b._node_graph.nodes() if isinstance(n, SourceNode) ] assert len(source_nodes) == 1 @@ -1010,7 +977,7 @@ def test_recompile_preserves_existing_node_objects(self, pipeline_db): assert "renamer" in pipeline.compiled_nodes def test_recompile_preserves_existing_source_nodes(self, pipeline_db): - """PersistentSourceNode objects from first compile survive second compile.""" + """SourceNode objects from first compile survive second compile.""" src_a, src_b = _make_two_sources() pipeline = Pipeline(name="incr_src", pipeline_database=pipeline_db) @@ -1022,7 +989,7 @@ def test_recompile_preserves_existing_source_nodes(self, pipeline_db): first_source_nodes = { id(n) for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) } # Extend with another operation @@ -1032,7 +999,7 @@ def test_recompile_preserves_existing_source_nodes(self, pipeline_db): second_source_nodes = { id(n) for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) } # All original source nodes should be preserved (same object ids) @@ -1077,7 +1044,7 @@ class TestCompileDoesNotTriggerExecution: triggering upstream iter_packets / run / as_table materialisation.""" def test_compile_does_not_trigger_source_materialization(self, pipeline_db): - """Operators followed by a function pod: compile should not execute anything.""" + """Compile should not trigger any computation or database writes.""" src_a, src_b = _make_two_sources() pf = PythonPacketFunction(add_values, output_keys="total") pod = FunctionPod(pf) @@ -1087,23 +1054,9 @@ def test_compile_does_not_trigger_source_materialization(self, pipeline_db): joined = Join()(src_a, src_b) pod(joined, label="adder") - # After compile, persistent source nodes should NOT have materialised - # their cache yet (i.e. _cached_stream is still None). - assert pipeline._node_graph is not None - source_nodes = [ - n - for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) - ] - for sn in source_nodes: - assert sn._cached_stream is None, ( - "PersistentSourceNode should not have materialised during compile()" - ) - - # The pipeline and function databases should be empty — nothing has - # been written yet because no upstream was triggered. - assert pipeline_db.get_all_records(source_nodes[0].cache_path) is None - assert pipeline_db.get_all_records(source_nodes[1].cache_path) is None + # After compile, the pipeline database should be empty — compile() + # only builds the graph, it doesn't execute any nodes. + assert pipeline.adder.get_all_records() is None # Running the pipeline should still work correctly after lazy compile. pipeline.run() @@ -1238,107 +1191,34 @@ def test_per_node_opts_override_pipeline_opts(self, pipeline_db): assert pipeline.doubler.executor.opts.get("num_cpus") == 2 -# --------------------------------------------------------------------------- -# Tests: SourceCacheMode -# --------------------------------------------------------------------------- - - -class TestSourceCacheModeOff: - """Verify that source_cache_mode=OFF bypasses database interaction.""" +class TestSourceNodeNoCaching: + """Verify that SourceNode does not cache — caching is a source-level concern.""" - def test_compile_passes_cache_mode_to_source_nodes(self, pipeline_db): - src_a, src_b = _make_two_sources() - pipeline = Pipeline( - name="test_pipe", - pipeline_database=pipeline_db, - source_cache_mode=SourceCacheMode.OFF, - ) - with pipeline: - _ = Join()(src_a, src_b) - - source_nodes = [ - n - for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) - ] - assert len(source_nodes) == 2 - for sn in source_nodes: - assert sn.source_cache_mode == SourceCacheMode.OFF - - def test_off_mode_no_db_writes(self, pipeline_db): - """In OFF mode, run() should not write anything to the cache DB.""" - src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) - pipeline = Pipeline( - name="test_pipe", - pipeline_database=pipeline_db, - source_cache_mode=SourceCacheMode.OFF, - ) - with pipeline: - MapPackets({"value": "val"})(src, label="mapper") - - pipeline.run() - - # Find the PersistentSourceNode - source_nodes = [ - n - for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) - ] - assert len(source_nodes) == 1 - sn = source_nodes[0] - - # No records should be in the DB - assert sn.get_all_records() is None - - def test_off_mode_data_flows_through(self, pipeline_db): - """In OFF mode, downstream nodes should still receive data.""" + def test_source_nodes_do_not_write_to_db(self, pipeline_db): + """Source nodes should not write anything to the pipeline DB.""" src_a, src_b = _make_two_sources() pf = PythonPacketFunction(add_values, output_keys="total") pod = FunctionPod(packet_function=pf) - pipeline = Pipeline( - name="test_pipe", - pipeline_database=pipeline_db, - source_cache_mode=SourceCacheMode.OFF, - ) + pipeline = Pipeline(name="test_pipe", pipeline_database=pipeline_db) with pipeline: joined = Join()(src_a, src_b) pod(joined, label="adder") pipeline.run() - # Function node should have computed results + # Function node should have computed results (pipeline works) records = pipeline.adder.get_all_records() assert records is not None assert records.num_rows == 2 - def test_full_mode_writes_to_db(self, pipeline_db): - """Sanity check: FULL mode (default) does write to DB.""" - src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) - pipeline = Pipeline( - name="test_pipe", - pipeline_database=pipeline_db, - source_cache_mode=SourceCacheMode.FULL, - ) - with pipeline: - MapPackets({"value": "val"})(src, label="mapper") - - pipeline.run() - + # Source nodes are plain SourceNode — no caching, no DB writes source_nodes = [ n for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) ] - assert len(source_nodes) == 1 - sn = source_nodes[0] - - # Records should be in the DB - records = sn.get_all_records() - assert records is not None - assert records.num_rows == 2 - - def test_default_is_full(self, pipeline_db): - """Pipeline defaults to FULL source cache mode.""" - pipeline = Pipeline(name="test", pipeline_database=pipeline_db) - assert pipeline._source_cache_mode == SourceCacheMode.FULL + assert len(source_nodes) == 2 + for sn in source_nodes: + assert not hasattr(sn, "cache_path") + assert not hasattr(sn, "get_all_records") From 926341d3468a9190984c70acbf9521572f94ba1a Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Mar 2026 00:50:46 +0000 Subject: [PATCH 3/7] test(source-node): add tests for SourceNode delegation and CachedSource pipeline integration - SourceNode.as_table() delegates to wrapped stream - SourceNode.iter_packets() delegates to wrapped stream - SourceNode.run() is a no-op - Pipeline with CachedSource input works end-to-end (source caching in source_db, pipeline execution in pipeline_db) https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- tests/test_core/test_tracker.py | 22 ++++++++++++++++++++++ tests/test_pipeline/test_pipeline.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/tests/test_core/test_tracker.py b/tests/test_core/test_tracker.py index 142c54a6..b3ad8ab6 100644 --- a/tests/test_core/test_tracker.py +++ b/tests/test_core/test_tracker.py @@ -143,6 +143,28 @@ def test_delegates_keys(self): node = SourceNode(stream=stream) assert node.keys() == stream.keys() + def test_delegates_as_table(self): + stream = _make_stream() + node = SourceNode(stream=stream) + node_table = node.as_table() + stream_table = stream.as_table() + assert node_table.equals(stream_table) + + def test_delegates_iter_packets(self): + stream = _make_stream() + node = SourceNode(stream=stream) + node_packets = list(node.iter_packets()) + stream_packets = list(stream.iter_packets()) + assert len(node_packets) == len(stream_packets) + + def test_run_is_noop(self): + stream = _make_stream() + node = SourceNode(stream=stream) + # run() should succeed without side effects + node.run() + # Data is still accessible after run() + assert node.as_table().num_rows == stream.as_table().num_rows + # --------------------------------------------------------------------------- # BasicTrackerManager diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 54a40444..2e3bc490 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1222,3 +1222,31 @@ def test_source_nodes_do_not_write_to_db(self, pipeline_db): for sn in source_nodes: assert not hasattr(sn, "cache_path") assert not hasattr(sn, "get_all_records") + + def test_pipeline_with_cached_source_input(self, pipeline_db): + """CachedSource as pipeline input: source caching + pipeline execution.""" + from orcapod.core.sources import CachedSource + + src_a, src_b = _make_two_sources() + source_db = InMemoryArrowDatabase() + cached_a = CachedSource(src_a, cache_database=source_db) + cached_b = CachedSource(src_b, cache_database=source_db) + + pf = PythonPacketFunction(add_values, output_keys="total") + pod = FunctionPod(packet_function=pf) + + pipeline = Pipeline(name="cached_src", pipeline_database=pipeline_db) + with pipeline: + joined = Join()(cached_a, cached_b) + pod(joined, label="adder") + + pipeline.run() + + # Function node computed results + records = pipeline.adder.get_all_records() + assert records is not None + assert records.num_rows == 2 + + # Source data was cached in source_db (not pipeline_db) + assert cached_a.get_all_records() is not None + assert cached_b.get_all_records() is not None From 51a7cca0a9eff063f58a6fe95796b93eb0a11fe2 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Mar 2026 04:41:07 +0000 Subject: [PATCH 4/7] refactor(cached-source): remove run() method, use flow() for eager caching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CachedSource.run() was not part of the StreamProtocol or SourceProtocol interfaces. Replace with flow() (inherited from StreamBase) which triggers the same caching via iter_packets() → _ensure_stream(). Updated all tests. https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- src/orcapod/core/sources/cached_source.py | 4 ---- tests/test_core/sources/test_cached_source.py | 10 +++++----- tests/test_core/test_caching_integration.py | 12 ++++++------ 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/orcapod/core/sources/cached_source.py b/src/orcapod/core/sources/cached_source.py index 1e635ed1..f73fdee1 100644 --- a/src/orcapod/core/sources/cached_source.py +++ b/src/orcapod/core/sources/cached_source.py @@ -167,10 +167,6 @@ def clear_cache(self) -> None: """Discard in-memory cached stream (forces rebuild on next access).""" self._cached_stream = None - def run(self) -> None: - """Eagerly populate the cache with live source data.""" - self._ensure_stream() - def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: self._ensure_stream() assert self._cached_stream is not None diff --git a/tests/test_core/sources/test_cached_source.py b/tests/test_core/sources/test_cached_source.py index ae2d2dd2..581a98c8 100644 --- a/tests/test_core/sources/test_cached_source.py +++ b/tests/test_core/sources/test_cached_source.py @@ -223,9 +223,9 @@ class TestCachedSourceCumulative: def test_dedup_on_same_data(self, simple_source, db): """Running twice with the same data produces no duplicates.""" ps1 = CachedSource(simple_source, cache_database=db) - ps1.run() + ps1.flow() ps2 = CachedSource(simple_source, cache_database=db) - ps2.run() + ps2.flow() table = ps2.as_table() assert table.num_rows == 3 # no duplicates @@ -249,18 +249,18 @@ def test_cumulative_across_runs(self, db): # Different data → different content_hash → different cache_paths # So cumulative within the SAME cache_path requires same content_hash ps1 = CachedSource(s1, cache_database=db) - ps1.run() + ps1.flow() assert ps1.as_table().num_rows == 2 # Same data source: should dedup s1_again = ArrowTableSource(t1, tag_columns=["k"], source_id="shared") ps1_again = CachedSource(s1_again, cache_database=db) - ps1_again.run() + ps1_again.flow() assert ps1_again.as_table().num_rows == 2 # Different source (s2) has different cache_path ps2 = CachedSource(s2, cache_database=db) - ps2.run() + ps2.flow() assert ps2.as_table().num_rows == 3 diff --git a/tests/test_core/test_caching_integration.py b/tests/test_core/test_caching_integration.py index 343e37bc..e1b0ab16 100644 --- a/tests/test_core/test_caching_integration.py +++ b/tests/test_core/test_caching_integration.py @@ -145,13 +145,13 @@ def test_different_sources_get_different_cache_paths(self, clinic_a, source_db): ) assert patients.cache_path != labs.cache_path - def test_cache_populates_on_run(self, clinic_a, source_db): + def test_cache_populates_on_flow(self, clinic_a, source_db): patients_path, _ = clinic_a ps = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps.run() + ps.flow() records = ps.get_all_records() assert records is not None assert records.num_rows == 3 @@ -162,12 +162,12 @@ def test_dedup_on_rerun(self, clinic_a, source_db): DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps1.run() + ps1.flow() ps2 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps2.run() + ps2.flow() assert ps2.get_all_records().num_rows == 3 def test_named_source_same_name_same_schema_same_identity( @@ -205,7 +205,7 @@ def test_cumulative_caching_across_data_updates(self, clinic_a, source_db): DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps1.run() + ps1.flow() assert ps1.get_all_records().num_rows == 3 # Update Delta table: add p4 @@ -225,7 +225,7 @@ def test_cumulative_caching_across_data_updates(self, clinic_a, source_db): DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps2.run() + ps2.flow() # 3 original + 1 new, existing rows deduped assert ps2.get_all_records().num_rows == 4 From a4452ced77fe57c8485b4c5f81522b234760e89c Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Mar 2026 04:45:39 +0000 Subject: [PATCH 5/7] feat(cached-source): add staleness detection against wrapped source CachedSource._ensure_stream() now checks if the wrapped source's last_modified is newer than the cached stream's timestamp. If stale, the in-memory cache is discarded and rebuilt from the DB + live data. Adds test_source_modified_time_triggers_rebuild verifying that updating the source's modified time causes CachedSource to rebuild on next access. https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- src/orcapod/core/sources/cached_source.py | 12 +++++++++++- tests/test_core/sources/test_cached_source.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/orcapod/core/sources/cached_source.py b/src/orcapod/core/sources/cached_source.py index f73fdee1..35c9dad0 100644 --- a/src/orcapod/core/sources/cached_source.py +++ b/src/orcapod/core/sources/cached_source.py @@ -157,8 +157,18 @@ def _build_merged_stream(self) -> ArrowTableStream: tag_keys = self._source.keys()[0] return ArrowTableStream(all_records, tag_columns=tag_keys) + def _is_source_stale(self) -> bool: + """True if the wrapped source has been modified since the last build.""" + own_time = self.last_modified + if own_time is None: + return True + src_time = self._source.last_modified + return src_time is None or src_time > own_time + def _ensure_stream(self) -> None: - """Build the merged stream on first access.""" + """Build the merged stream on first access, or rebuild if source is stale.""" + if self._cached_stream is not None and self._is_source_stale(): + self._cached_stream = None if self._cached_stream is None: self._cached_stream = self._build_merged_stream() self._update_modified_time() diff --git a/tests/test_core/sources/test_cached_source.py b/tests/test_core/sources/test_cached_source.py index 581a98c8..f2a7ac2c 100644 --- a/tests/test_core/sources/test_cached_source.py +++ b/tests/test_core/sources/test_cached_source.py @@ -237,6 +237,22 @@ def test_clear_cache_rebuilds(self, simple_source, db): t2 = ps.as_table() assert t1.num_rows == t2.num_rows + def test_source_modified_time_triggers_rebuild(self, simple_source, db): + """Updating the wrapped source's modified time triggers a cache rebuild.""" + ps = CachedSource(simple_source, cache_database=db) + # First access: build and cache + t1 = ps.as_table() + assert t1.num_rows == 3 + + # Simulate the source being updated (e.g. new data loaded) + simple_source._update_modified_time() + + # Next access should detect staleness and rebuild + t2 = ps.as_table() + assert t2.num_rows == 3 + # Verify CachedSource's own modified time was updated past the source's + assert not ps._is_source_stale() + def test_cumulative_across_runs(self, db): """Data from different runs accumulates in the cache.""" # Use a single source_id to make them share the same canonical identity From 0497230045d8641ef83d55296edb60ffe16efef0 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Mar 2026 04:47:21 +0000 Subject: [PATCH 6/7] refactor(cached-source): override is_stale property instead of private method Replace _is_source_stale() with a proper is_stale property override on StreamBase. CachedSource is a RootSource (no upstreams/producer) but still depends on the wrapped source's modification time. https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- src/orcapod/core/sources/cached_source.py | 11 ++++++++--- tests/test_core/sources/test_cached_source.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/orcapod/core/sources/cached_source.py b/src/orcapod/core/sources/cached_source.py index 35c9dad0..e96a7a4b 100644 --- a/src/orcapod/core/sources/cached_source.py +++ b/src/orcapod/core/sources/cached_source.py @@ -157,8 +157,13 @@ def _build_merged_stream(self) -> ArrowTableStream: tag_keys = self._source.keys()[0] return ArrowTableStream(all_records, tag_columns=tag_keys) - def _is_source_stale(self) -> bool: - """True if the wrapped source has been modified since the last build.""" + @property + def is_stale(self) -> bool: + """True if the wrapped source has been modified since the last build. + + Overrides ``StreamBase.is_stale`` because CachedSource is a RootSource + (no upstreams/producer) yet still depends on the wrapped ``_source``. + """ own_time = self.last_modified if own_time is None: return True @@ -167,7 +172,7 @@ def _is_source_stale(self) -> bool: def _ensure_stream(self) -> None: """Build the merged stream on first access, or rebuild if source is stale.""" - if self._cached_stream is not None and self._is_source_stale(): + if self._cached_stream is not None and self.is_stale: self._cached_stream = None if self._cached_stream is None: self._cached_stream = self._build_merged_stream() diff --git a/tests/test_core/sources/test_cached_source.py b/tests/test_core/sources/test_cached_source.py index f2a7ac2c..5f27075c 100644 --- a/tests/test_core/sources/test_cached_source.py +++ b/tests/test_core/sources/test_cached_source.py @@ -251,7 +251,7 @@ def test_source_modified_time_triggers_rebuild(self, simple_source, db): t2 = ps.as_table() assert t2.num_rows == 3 # Verify CachedSource's own modified time was updated past the source's - assert not ps._is_source_stale() + assert not ps.is_stale def test_cumulative_across_runs(self, db): """Data from different runs accumulates in the cache.""" From 71766583f5b9d909abe81a25f26f128f7ca6cba7 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Mar 2026 05:48:43 +0000 Subject: [PATCH 7/7] refactor(nodes): delegate data_context to wrapped entity, address Copilot review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Nodes are now transparent wrappers for data context — they always delegate to their primary wrapped entity instead of holding their own: - SourceNode → wrapped stream - FunctionNode → function pod - OperatorNode → operator pod This ensures consistent hashing (content_hash/pipeline_hash use the same semantic hasher as the wrapped entity) and eliminates the risk of context mismatch when Pipeline.compile() constructs nodes. Also addresses Copilot PR review comments: - Fix misleading Pipeline docstring ("persistent variant" → "execution-ready nodes") - Tighten RootSource.cached() type annotations (Any → ArrowDatabaseProtocol, RootSource → CachedSource) - Add DESIGN_ISSUES.md entry T2 for config/context delegation chain review https://claude.ai/code/session_016x3vkNoCTPW6GdzVNRZeAZ --- DESIGN_ISSUES.md | 23 ++++++++++ src/orcapod/core/nodes/function_node.py | 17 ++++---- src/orcapod/core/nodes/operator_node.py | 17 ++++---- src/orcapod/core/nodes/source_node.py | 15 ++++--- src/orcapod/core/sources/base.py | 10 +++-- src/orcapod/pipeline/graph.py | 2 +- tests/test_core/test_tracker.py | 56 +++++++++++++++++++++++++ 7 files changed, 114 insertions(+), 26 deletions(-) diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 72ac856a..ed5055fe 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -873,3 +873,26 @@ For derived sources (e.g., `DerivedSource`), the stream may not have a meaningfu `identity_structure()`. Needs an isinstance check or protocol-based dispatch. --- + +## `src/orcapod/core/nodes/` — Config and context delegation chain + +### T2 — `orcapod_config` not on any protocol; delegation chain needs review +**Status:** open +**Severity:** medium + +Nodes (`SourceNode`, `FunctionNode`, `OperatorNode`) now delegate `data_context` to their +wrapped entity via property overrides, ensuring transparent context pass-through. However, +`orcapod_config` is only on `DataContextMixin` (concrete base), not on any protocol +(`StreamProtocol`, `PodProtocol`, etc.). + +Open questions: +1. Should `orcapod_config` be added to a protocol (e.g. `TraceableProtocol`)? + Adding it couples protocol consumers to the `Config` type. Leaving it off means + nodes can't transparently delegate config the same way they delegate context. +2. Should `Pipeline` optionally hold `data_context` and/or `config` to allow + pipeline-level overrides that propagate to all nodes during `compile()`? +3. The current chain is: Pipeline (no context) → Node (delegates to wrapped entity) → + wrapped entity (owns context). Should there be a way for Pipeline to inject a + context override? + +--- diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index b6e645d4..1c5b6b4f 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -69,7 +69,6 @@ def __init__( input_stream: StreamProtocol, tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): if tracker_manager is None: @@ -79,11 +78,7 @@ def __init__( # FunctionPod used for the `producer` property and pipeline identity self._function_pod = function_pod - super().__init__( - label=label, - data_context=data_context, - config=config, - ) + super().__init__(label=label, config=config) # validate the input stream _, incoming_packet_types = input_stream.output_schema() @@ -118,6 +113,14 @@ def __init__( def producer(self) -> FunctionPodProtocol: return self._function_pod + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self._function_pod.data_context_key) + + @property + def data_context_key(self) -> str: + return self._function_pod.data_context_key + @property def executor(self) -> PacketFunctionExecutorProtocol | None: """The executor set on the underlying packet function.""" @@ -432,7 +435,6 @@ def __init__( pipeline_path_prefix: tuple[str, ...] = (), tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): super().__init__( @@ -440,7 +442,6 @@ def __init__( input_stream=input_stream, tracker_manager=tracker_manager, label=label, - data_context=data_context, config=config, ) diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py index 217309e7..dcf3f5f2 100644 --- a/src/orcapod/core/nodes/operator_node.py +++ b/src/orcapod/core/nodes/operator_node.py @@ -50,7 +50,6 @@ def __init__( input_streams: tuple[StreamProtocol, ...] | list[StreamProtocol], tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): if tracker_manager is None: @@ -60,11 +59,7 @@ def __init__( self._operator = operator self._input_streams = tuple(input_streams) - super().__init__( - label=label, - data_context=data_context, - config=config, - ) + super().__init__(label=label, config=config) # Validate inputs eagerly self._operator.validate_inputs(*self._input_streams) @@ -92,6 +87,14 @@ def pipeline_identity_structure(self) -> Any: def producer(self) -> OperatorPodProtocol: return self._operator + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self._operator.data_context_key) + + @property + def data_context_key(self) -> str: + return self._operator.data_context_key + @property def upstreams(self) -> tuple[StreamProtocol, ...]: return self._input_streams @@ -222,7 +225,6 @@ def __init__( pipeline_path_prefix: tuple[str, ...] = (), tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): super().__init__( @@ -230,7 +232,6 @@ def __init__( input_streams=input_streams, tracker_manager=tracker_manager, label=label, - data_context=data_context, config=config, ) diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py index 34a66483..4ffa63ea 100644 --- a/src/orcapod/core/nodes/source_node.py +++ b/src/orcapod/core/nodes/source_node.py @@ -25,16 +25,19 @@ def __init__( self, stream: cp.StreamProtocol, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): - super().__init__( - label=label, - data_context=data_context, - config=config, - ) + super().__init__(label=label, config=config) self.stream = stream + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self.stream.data_context_key) + + @property + def data_context_key(self) -> str: + return self.stream.data_context_key + def computed_label(self) -> str | None: return self.stream.label diff --git a/src/orcapod/core/sources/base.py b/src/orcapod/core/sources/base.py index 60689a9e..dd7093fa 100644 --- a/src/orcapod/core/sources/base.py +++ b/src/orcapod/core/sources/base.py @@ -1,11 +1,15 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from orcapod.core.streams.base import StreamBase from orcapod.errors import FieldNotResolvableError from orcapod.protocols.core_protocols import StreamProtocol +if TYPE_CHECKING: + from orcapod.core.sources.cached_source import CachedSource + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + class RootSource(StreamBase): """ @@ -144,10 +148,10 @@ def upstreams(self) -> tuple[StreamProtocol, ...]: def cached( self, - cache_database: Any, + cache_database: ArrowDatabaseProtocol, cache_path_prefix: tuple[str, ...] = (), **kwargs: Any, - ) -> "RootSource": + ) -> CachedSource: """Return a ``CachedSource`` wrapping this source. Args: diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index feda71d9..f8ece817 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -54,7 +54,7 @@ class Pipeline(GraphTracker): During the ``with`` block, operator and function pod invocations are recorded as non-persistent nodes (same as ``GraphTracker``). On context - exit, ``compile()`` replaces every node with its persistent variant: + exit, ``compile()`` rewires the graph into execution-ready nodes: - Leaf streams → ``SourceNode`` (thin wrapper for graph vertex) - Function pod invocations → ``PersistentFunctionNode`` diff --git a/tests/test_core/test_tracker.py b/tests/test_core/test_tracker.py index b3ad8ab6..d4d902bc 100644 --- a/tests/test_core/test_tracker.py +++ b/tests/test_core/test_tracker.py @@ -165,6 +165,62 @@ def test_run_is_noop(self): # Data is still accessible after run() assert node.as_table().num_rows == stream.as_table().num_rows + def test_delegates_data_context_key(self): + stream = _make_stream() + node = SourceNode(stream=stream) + assert node.data_context_key == stream.data_context_key + + def test_delegates_data_context(self): + stream = _make_stream() + node = SourceNode(stream=stream) + assert node.data_context.context_key == stream.data_context_key + + +# --------------------------------------------------------------------------- +# Node context delegation +# --------------------------------------------------------------------------- + + +class TestNodeContextDelegation: + """All node types delegate data_context to their wrapped entity.""" + + def test_source_node_context_matches_stream(self): + stream = _make_stream() + node = SourceNode(stream=stream) + assert node.data_context_key == stream.data_context_key + assert node.data_context.context_key == stream.data_context_key + + def test_function_node_context_matches_pod(self): + stream = _make_stream() + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + node = FunctionNode(function_pod=pod, input_stream=stream) + assert node.data_context_key == pod.data_context_key + assert node.data_context.context_key == pod.data_context_key + + def test_operator_node_context_matches_operator(self): + stream = _make_two_col_stream() + op = SelectTagColumns("id") + node = OperatorNode(operator=op, input_streams=[stream]) + assert node.data_context_key == op.data_context_key + assert node.data_context.context_key == op.data_context_key + + def test_source_node_hash_consistent_with_stream(self): + stream = _make_stream() + node = SourceNode(stream=stream) + # Both should use the same hasher (from the same data context) + assert node.content_hash() == stream.content_hash() + assert node.pipeline_hash() == stream.pipeline_hash() + + def test_function_node_hash_uses_pod_context(self): + stream = _make_stream() + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + node = FunctionNode(function_pod=pod, input_stream=stream) + # Node should produce stable hashes without error + assert node.content_hash() is not None + assert node.pipeline_hash() is not None + # --------------------------------------------------------------------------- # BasicTrackerManager