diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 72ac856a..ed5055fe 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -873,3 +873,26 @@ For derived sources (e.g., `DerivedSource`), the stream may not have a meaningfu `identity_structure()`. Needs an isinstance check or protocol-based dispatch. --- + +## `src/orcapod/core/nodes/` — Config and context delegation chain + +### T2 — `orcapod_config` not on any protocol; delegation chain needs review +**Status:** open +**Severity:** medium + +Nodes (`SourceNode`, `FunctionNode`, `OperatorNode`) now delegate `data_context` to their +wrapped entity via property overrides, ensuring transparent context pass-through. However, +`orcapod_config` is only on `DataContextMixin` (concrete base), not on any protocol +(`StreamProtocol`, `PodProtocol`, etc.). + +Open questions: +1. Should `orcapod_config` be added to a protocol (e.g. `TraceableProtocol`)? + Adding it couples protocol consumers to the `Config` type. Leaving it off means + nodes can't transparently delegate config the same way they delegate context. +2. Should `Pipeline` optionally hold `data_context` and/or `config` to allow + pipeline-level overrides that propagate to all nodes during `compile()`? +3. The current chain is: Pipeline (no context) → Node (delegates to wrapped entity) → + wrapped entity (owns context). Should there be a way for Pipeline to inject a + context override? + +--- diff --git a/demo_pipeline.py b/demo_pipeline.py index c365530b..9b16f22d 100644 --- a/demo_pipeline.py +++ b/demo_pipeline.py @@ -23,7 +23,8 @@ from orcapod.core.packet_function import PythonPacketFunction from orcapod.core.sources import ArrowTableSource from orcapod.databases import DeltaTableDatabase, InMemoryArrowDatabase -from orcapod.pipeline import Pipeline, PersistentSourceNode +from orcapod.core.nodes import SourceNode +from orcapod.pipeline import Pipeline # --------------------------------------------------------------------------- @@ -96,10 +97,10 @@ def categorize(risk: float) -> str: for name, node in pipeline.compiled_nodes.items(): print(f" {name}: {type(node).__name__}") -print("\n Source nodes (PersistentSourceNode):") +print("\n Source nodes (SourceNode):") for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - print(f" cache_path: {n.cache_path}") + if isinstance(n, SourceNode): + print(f" label: {n.label}, stream: {type(n.stream).__name__}") # --- Access nodes by label --- print(f"\n pipeline.join_data -> {type(pipeline.join_data).__name__}") @@ -127,12 +128,6 @@ def categorize(risk: float) -> str: print(f" {cat_table.to_pandas()[['patient_id', 'category']].to_string(index=False)}") # --- Show what's in the database --- -print(f"\n Source records in DB:") -for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - records = n.get_all_records() - print(f" {n.cache_path[-1]}: {records.num_rows} rows") - fn_records = pipeline.compute_risk.get_all_records() print(f" Function records (compute_risk): {fn_records.num_rows} rows") @@ -320,7 +315,7 @@ def categorize(risk: float) -> str: print("=" * 70) print(""" Pipeline wraps ALL nodes as persistent variants automatically: - - Leaf streams -> PersistentSourceNode (DB-backed cache) + - Leaf streams -> SourceNode (graph vertex wrapper, no caching) - Operator calls -> PersistentOperatorNode (DB-backed cache) - Function pod calls -> PersistentFunctionNode (DB-backed cache) diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index 
b6e645d4..1c5b6b4f 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -69,7 +69,6 @@ def __init__( input_stream: StreamProtocol, tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): if tracker_manager is None: @@ -79,11 +78,7 @@ def __init__( # FunctionPod used for the `producer` property and pipeline identity self._function_pod = function_pod - super().__init__( - label=label, - data_context=data_context, - config=config, - ) + super().__init__(label=label, config=config) # validate the input stream _, incoming_packet_types = input_stream.output_schema() @@ -118,6 +113,14 @@ def __init__( def producer(self) -> FunctionPodProtocol: return self._function_pod + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self._function_pod.data_context_key) + + @property + def data_context_key(self) -> str: + return self._function_pod.data_context_key + @property def executor(self) -> PacketFunctionExecutorProtocol | None: """The executor set on the underlying packet function.""" @@ -432,7 +435,6 @@ def __init__( pipeline_path_prefix: tuple[str, ...] = (), tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): super().__init__( @@ -440,7 +442,6 @@ def __init__( input_stream=input_stream, tracker_manager=tracker_manager, label=label, - data_context=data_context, config=config, ) diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py index 217309e7..dcf3f5f2 100644 --- a/src/orcapod/core/nodes/operator_node.py +++ b/src/orcapod/core/nodes/operator_node.py @@ -50,7 +50,6 @@ def __init__( input_streams: tuple[StreamProtocol, ...] | list[StreamProtocol], tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): if tracker_manager is None: @@ -60,11 +59,7 @@ def __init__( self._operator = operator self._input_streams = tuple(input_streams) - super().__init__( - label=label, - data_context=data_context, - config=config, - ) + super().__init__(label=label, config=config) # Validate inputs eagerly self._operator.validate_inputs(*self._input_streams) @@ -92,6 +87,14 @@ def pipeline_identity_structure(self) -> Any: def producer(self) -> OperatorPodProtocol: return self._operator + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self._operator.data_context_key) + + @property + def data_context_key(self) -> str: + return self._operator.data_context_key + @property def upstreams(self) -> tuple[StreamProtocol, ...]: return self._input_streams @@ -222,7 +225,6 @@ def __init__( pipeline_path_prefix: tuple[str, ...] 
= (), tracker_manager: TrackerManagerProtocol | None = None, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): super().__init__( @@ -230,7 +232,6 @@ def __init__( input_streams=input_streams, tracker_manager=tracker_manager, label=label, - data_context=data_context, config=config, ) diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py index 34a66483..4ffa63ea 100644 --- a/src/orcapod/core/nodes/source_node.py +++ b/src/orcapod/core/nodes/source_node.py @@ -25,16 +25,19 @@ def __init__( self, stream: cp.StreamProtocol, label: str | None = None, - data_context: str | contexts.DataContext | None = None, config: Config | None = None, ): - super().__init__( - label=label, - data_context=data_context, - config=config, - ) + super().__init__(label=label, config=config) self.stream = stream + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self.stream.data_context_key) + + @property + def data_context_key(self) -> str: + return self.stream.data_context_key + def computed_label(self) -> str | None: return self.stream.label diff --git a/src/orcapod/core/sources/__init__.py b/src/orcapod/core/sources/__init__.py index 0cb810e0..ccb0e782 100644 --- a/src/orcapod/core/sources/__init__.py +++ b/src/orcapod/core/sources/__init__.py @@ -1,24 +1,24 @@ from .base import RootSource from .arrow_table_source import ArrowTableSource +from .cached_source import CachedSource from .csv_source import CSVSource from .data_frame_source import DataFrameSource from .delta_table_source import DeltaTableSource from .derived_source import DerivedSource from .dict_source import DictSource from .list_source import ListSource -from .persistent_source import PersistentSource from .source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry __all__ = [ "RootSource", "ArrowTableSource", + "CachedSource", "CSVSource", "DataFrameSource", "DeltaTableSource", "DerivedSource", "DictSource", "ListSource", - "PersistentSource", "SourceRegistry", "GLOBAL_SOURCE_REGISTRY", ] diff --git a/src/orcapod/core/sources/base.py b/src/orcapod/core/sources/base.py index d1f213a2..dd7093fa 100644 --- a/src/orcapod/core/sources/base.py +++ b/src/orcapod/core/sources/base.py @@ -1,11 +1,15 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from orcapod.core.streams.base import StreamBase from orcapod.errors import FieldNotResolvableError from orcapod.protocols.core_protocols import StreamProtocol +if TYPE_CHECKING: + from orcapod.core.sources.cached_source import CachedSource + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + class RootSource(StreamBase): """ @@ -137,3 +141,32 @@ def producer(self) -> None: def upstreams(self) -> tuple[StreamProtocol, ...]: """Sources have no upstream dependencies.""" return () + + # ------------------------------------------------------------------------- + # Convenience — caching + # ------------------------------------------------------------------------- + + def cached( + self, + cache_database: ArrowDatabaseProtocol, + cache_path_prefix: tuple[str, ...] = (), + **kwargs: Any, + ) -> CachedSource: + """Return a ``CachedSource`` wrapping this source. + + Args: + cache_database: Database to store cached records in. + cache_path_prefix: Path prefix for the cache table. + **kwargs: Additional keyword arguments passed to ``CachedSource``. 
+ + Returns: + A ``CachedSource`` that caches this source's output. + """ + from orcapod.core.sources.cached_source import CachedSource + + return CachedSource( + source=self, + cache_database=cache_database, + cache_path_prefix=cache_path_prefix, + **kwargs, + ) diff --git a/src/orcapod/core/sources/persistent_source.py b/src/orcapod/core/sources/cached_source.py similarity index 80% rename from src/orcapod/core/sources/persistent_source.py rename to src/orcapod/core/sources/cached_source.py index e53d3218..e96a7a4b 100644 --- a/src/orcapod/core/sources/persistent_source.py +++ b/src/orcapod/core/sources/cached_source.py @@ -21,25 +21,29 @@ logger = logging.getLogger(__name__) -class PersistentSource(RootSource): - """ - DB-backed wrapper around a RootSource that caches every packet. +class CachedSource(RootSource): + """DB-backed wrapper around a ``RootSource`` that caches every packet. - Implements StreamProtocol transparently so downstream consumers + Implements ``StreamProtocol`` transparently so downstream consumers are unaware of caching. Cache table is scoped to the source's ``content_hash()`` — each unique source gets its own table. - Behavior - -------- - - Cache is **always on**. - - On first access, live source data is stored in the cache table - (deduped by per-row content hash). - - Returns the union of all cached data (cumulative across runs). - - Semantic guarantee - ------------------ - The cache is a correct cumulative record. The union of cache + live - packets is the full set of data ever available from that source. + Behavior: + - Cache is **always on**. + - On first access, live source data is stored in the cache table + (deduped by per-row content hash). + - Returns the union of all cached data (cumulative across runs). + + Semantic guarantee: + The cache is a correct cumulative record. The union of cache + live + packets is the full set of data ever available from that source. + + Example:: + + source = ArrowTableSource(table, tag_columns=["id"]) + cached = CachedSource(source, cache_database=db) + # or equivalently: + cached = source.cached(cache_database=db) """ HASH_COLUMN_NAME = "_record_hash" @@ -153,8 +157,23 @@ def _build_merged_stream(self) -> ArrowTableStream: tag_keys = self._source.keys()[0] return ArrowTableStream(all_records, tag_columns=tag_keys) + @property + def is_stale(self) -> bool: + """True if the wrapped source has been modified since the last build. + + Overrides ``StreamBase.is_stale`` because CachedSource is a RootSource + (no upstreams/producer) yet still depends on the wrapped ``_source``. 
+ """ + own_time = self.last_modified + if own_time is None: + return True + src_time = self._source.last_modified + return src_time is None or src_time > own_time + def _ensure_stream(self) -> None: - """Build the merged stream on first access.""" + """Build the merged stream on first access, or rebuild if source is stale.""" + if self._cached_stream is not None and self.is_stale: + self._cached_stream = None if self._cached_stream is None: self._cached_stream = self._build_merged_stream() self._update_modified_time() @@ -163,10 +182,6 @@ def clear_cache(self) -> None: """Discard in-memory cached stream (forces rebuild on next access).""" self._cached_stream = None - def run(self) -> None: - """Eagerly populate the cache with live source data.""" - self._ensure_stream() - def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: self._ensure_stream() assert self._cached_stream is not None diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 7495be39..278acf33 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -1,9 +1,7 @@ from .graph import Pipeline -from .nodes import PersistentSourceNode from .orchestrator import AsyncPipelineOrchestrator __all__ = [ "AsyncPipelineOrchestrator", "Pipeline", - "PersistentSourceNode", ] diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index c6538c01..f8ece817 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -5,9 +5,8 @@ import tempfile from typing import TYPE_CHECKING, Any -from orcapod.core.nodes import GraphNode +from orcapod.core.nodes import GraphNode, SourceNode from orcapod.core.tracker import GraphTracker -from orcapod.pipeline.nodes import PersistentSourceNode from orcapod.protocols import core_protocols as cp from orcapod.protocols import database_protocols as dbp from orcapod.types import PipelineConfig @@ -55,12 +54,16 @@ class Pipeline(GraphTracker): During the ``with`` block, operator and function pod invocations are recorded as non-persistent nodes (same as ``GraphTracker``). On context - exit, ``compile()`` replaces every node with its persistent variant: + exit, ``compile()`` rewires the graph into execution-ready nodes: - - Leaf streams → ``PersistentSourceNode`` + - Leaf streams → ``SourceNode`` (thin wrapper for graph vertex) - Function pod invocations → ``PersistentFunctionNode`` - Operator invocations → ``PersistentOperatorNode`` + Source caching is not a pipeline concern — sources that need caching + should be wrapped in a ``CachedSource`` before being used in the + pipeline. + All persistent nodes share the same ``pipeline_database`` and use ``pipeline_name`` as path prefix, scoping their cache tables. 
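Illustration of the docstring above — a minimal sketch of opting into source caching *before* handing streams to a ``Pipeline``, mirroring the ``test_pipeline_with_cached_source_input`` test further down; the in-memory databases, the ``demo`` names, and the elided operator/pod calls are placeholders, not part of this patch::

    import pyarrow as pa

    from orcapod.core.sources import ArrowTableSource
    from orcapod.databases import InMemoryArrowDatabase
    from orcapod.pipeline import Pipeline

    source_db = InMemoryArrowDatabase()     # holds the source cache
    pipeline_db = InMemoryArrowDatabase()   # holds pod/operator caches only

    demo_table = pa.table({"id": [1, 2], "val": [10, 20]})
    src = ArrowTableSource(demo_table, tag_columns=["id"], source_id="demo")
    cached_src = src.cached(cache_database=source_db)  # caching opted into here

    pipeline = Pipeline(name="demo", pipeline_database=pipeline_db)
    with pipeline:
        ...  # operator / function pod calls on `cached_src` are recorded here

    # compile() (run on context exit) wraps `cached_src` in a plain SourceNode;
    # the pipeline database never sees source rows.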
@@ -139,7 +142,7 @@ def compile(self) -> None: Walks the graph in topological order and creates: - - ``PersistentSourceNode`` for every leaf stream + - ``SourceNode`` for every leaf stream - ``PersistentFunctionNode`` for every function pod invocation - ``PersistentOperatorNode`` for every operator invocation @@ -175,13 +178,9 @@ def compile(self) -> None: continue if node_hash not in self._node_lut: - # -- Leaf stream: wrap in PersistentSourceNode -- + # -- Leaf stream: wrap in SourceNode -- stream = self._upstreams[node_hash] - persistent_node = PersistentSourceNode( - stream=stream, - cache_database=self._pipeline_database, - cache_path_prefix=self._pipeline_path_prefix, - ) + persistent_node = SourceNode(stream=stream) persistent_node_map[node_hash] = persistent_node else: node = self._node_lut[node_hash] @@ -260,7 +259,7 @@ def compile(self) -> None: continue attrs = self._hash_graph.nodes[node_hash] if not attrs.get("node_type"): - if isinstance(node, PersistentSourceNode): + if isinstance(node, SourceNode): attrs["node_type"] = "source" elif isinstance(node, PersistentFunctionNode): attrs["node_type"] = "function" diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index a8e7d505..5abc288c 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -1,164 +1,3 @@ -from __future__ import annotations - -import logging -from collections.abc import Iterator, Sequence -from typing import TYPE_CHECKING, Any - -from orcapod import contexts -from orcapod.channels import ReadableChannel, WritableChannel -from orcapod.config import Config -from orcapod.core.nodes import SourceNode -from orcapod.core.streams.arrow_table_stream import ArrowTableStream -from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol -from orcapod.protocols.database_protocols import ArrowDatabaseProtocol -from orcapod.types import ColumnConfig -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -logger = logging.getLogger(__name__) - - -class PersistentSourceNode(SourceNode): - """ - DB-backed wrapper around any stream, used by ``Pipeline.compile()`` - to cache leaf stream data. - - Extends ``SourceNode`` (which delegates identity/schema to the wrapped - stream) and adds: - - - Materialization of the stream's output into a cache database - - Per-row deduplication via content hash - - Cached ``ArrowTableStream`` for downstream consumption - - Cache path structure:: - - cache_path_prefix / source / node:{content_hash} - """ - - HASH_COLUMN_NAME = "_record_hash" - - def __init__( - self, - stream: StreamProtocol, - cache_database: ArrowDatabaseProtocol, - cache_path_prefix: tuple[str, ...] 
= (), - label: str | None = None, - data_context: str | contexts.DataContext | None = None, - config: Config | None = None, - ) -> None: - super().__init__( - stream=stream, - label=label, - data_context=data_context, - config=config, - ) - self._cache_database = cache_database - self._cache_path_prefix = cache_path_prefix - self._cached_stream: ArrowTableStream | None = None - - # ------------------------------------------------------------------------- - # Cache path - # ------------------------------------------------------------------------- - - @property - def cache_path(self) -> tuple[str, ...]: - """Cache table path, scoped to the wrapped stream's content hash.""" - return self._cache_path_prefix + ( - "source", - f"node:{self.stream.content_hash().to_string()}", - ) - - # ------------------------------------------------------------------------- - # Caching logic - # ------------------------------------------------------------------------- - - def _build_cached_stream(self) -> ArrowTableStream: - """ - Materialize the wrapped stream, store rows in the cache DB - (deduped by per-row hash), and return the cached table as an - ``ArrowTableStream``. - """ - live_table = self.stream.as_table(columns={"source": True, "system_tags": True}) - - # Per-row content hashes for dedup - arrow_hasher = self.data_context.arrow_hasher - record_hashes: list[str] = [] - for batch in live_table.to_batches(): - for i in range(len(batch)): - record_hashes.append( - arrow_hasher.hash_table(batch.slice(i, 1)).to_hex() - ) - - live_with_hash = live_table.add_column( - 0, - self.HASH_COLUMN_NAME, - pa.array(record_hashes, type=pa.large_string()), - ) - - self._cache_database.add_records( - self.cache_path, - live_with_hash, - record_id_column=self.HASH_COLUMN_NAME, - skip_duplicates=True, - ) - self._cache_database.flush() - - # Load all cached records (union of current + prior runs) - all_records = self._cache_database.get_all_records(self.cache_path) - assert all_records is not None, ( - "Cache should contain records after storing live data." 
- ) - - tag_keys = self.stream.keys()[0] - return ArrowTableStream(all_records, tag_columns=tag_keys) - - def _ensure_stream(self) -> None: - """Build the cached stream on first access.""" - if self._cached_stream is None: - self._cached_stream = self._build_cached_stream() - self._update_modified_time() - - # ------------------------------------------------------------------------- - # Stream interface overrides - # ------------------------------------------------------------------------- - - def run(self) -> None: - """Eagerly populate the cache with live stream data.""" - self._ensure_stream() - - def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: - self._ensure_stream() - assert self._cached_stream is not None - return self._cached_stream.iter_packets() - - def as_table( - self, - *, - columns: ColumnConfig | dict[str, Any] | None = None, - all_info: bool = False, - ) -> "pa.Table": - self._ensure_stream() - assert self._cached_stream is not None - return self._cached_stream.as_table(columns=columns, all_info=all_info) - - async def async_execute( - self, - inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], - output: WritableChannel[tuple[TagProtocol, PacketProtocol]], - ) -> None: - """Materialize to cache DB, then push cached rows to the output channel.""" - try: - self._ensure_stream() - assert self._cached_stream is not None - for tag, packet in self._cached_stream.iter_packets(): - await output.send((tag, packet)) - finally: - await output.close() - - def get_all_records(self) -> "pa.Table | None": - """Retrieve all stored records from the cache database.""" - return self._cache_database.get_all_records(self.cache_path) +# Pipeline-specific node types are defined here. +# PersistentFunctionNode and PersistentOperatorNode live in core/nodes/. +# SourceNode (thin graph vertex wrapper) lives in core/nodes/source_node.py. 
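The node changes above (function_node.py, operator_node.py, source_node.py) all follow the same shape flagged in the DESIGN_ISSUES T2 entry: drop ``data_context`` from ``__init__`` and resolve it lazily from the wrapped entity. A minimal sketch of that delegation pattern, with ``wrapped`` standing in for the stream, operator, or function pod (the sketch class is illustrative only; the real implementations are the node classes in this patch)::

    from orcapod import contexts

    class DelegatingNodeSketch:
        """Illustrative only — see SourceNode, FunctionNode, OperatorNode."""

        def __init__(self, wrapped, label=None, config=None):
            self._wrapped = wrapped  # no data_context argument any more
            self.label = label
            self.config = config

        @property
        def data_context_key(self) -> str:
            # Transparent pass-through: the wrapped entity owns the context.
            return self._wrapped.data_context_key

        @property
        def data_context(self) -> contexts.DataContext:
            # Same resolution call the real nodes use.
            return contexts.resolve_context(self._wrapped.data_context_key)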
diff --git a/src/orcapod/types.py b/src/orcapod/types.py index 54627506..82d066c3 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -339,6 +339,7 @@ class CacheMode(Enum): REPLAY = "replay" + @dataclass(frozen=True, slots=True) class ColumnConfig: """ diff --git a/tests/test_core/sources/test_persistent_source.py b/tests/test_core/sources/test_cached_source.py similarity index 72% rename from tests/test_core/sources/test_persistent_source.py rename to tests/test_core/sources/test_cached_source.py index af0f22dc..5f27075c 100644 --- a/tests/test_core/sources/test_persistent_source.py +++ b/tests/test_core/sources/test_cached_source.py @@ -1,5 +1,5 @@ """ -Tests for PersistentSource covering: +Tests for CachedSource covering: - Construction and transparent StreamProtocol implementation - Cache path scoped to source's content_hash - Cumulative caching: data from prior runs is preserved @@ -17,7 +17,7 @@ import pyarrow as pa import pytest -from orcapod.core.sources import ArrowTableSource, PersistentSource +from orcapod.core.sources import ArrowTableSource, CachedSource from orcapod.core.streams import ArrowTableStream from orcapod.databases import InMemoryArrowDatabase from orcapod.protocols.core_protocols import StreamProtocol @@ -55,29 +55,29 @@ def db(): # --------------------------------------------------------------------------- -class TestPersistentSourceConstruction: +class TestCachedSourceConstruction: def test_source_id_delegated(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.source_id == simple_source.source_id def test_stream_protocol_conformance(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert isinstance(ps, StreamProtocol) def test_pipeline_element_conformance(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert isinstance(ps, PipelineElementProtocol) def test_identity_delegated(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.identity_structure() == simple_source.identity_structure() def test_content_hash_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.content_hash() == simple_source.content_hash() def test_pipeline_hash_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.pipeline_hash() == simple_source.pipeline_hash() @@ -86,16 +86,16 @@ def test_pipeline_hash_matches_source(self, simple_source, db): # --------------------------------------------------------------------------- -class TestPersistentSourceCachePath: +class TestCachedSourceCachePath: def test_cache_path_contains_content_hash(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) path = ps.cache_path content_hash_str = simple_source.content_hash().to_string() assert any(content_hash_str in segment for segment in path) def test_cache_path_prefix(self, simple_source, db): prefix = ("my_project", "v1") - ps = PersistentSource( + ps = CachedSource( simple_source, cache_database=db, cache_path_prefix=prefix ) assert 
ps.cache_path[:2] == prefix @@ -104,8 +104,8 @@ def test_same_source_same_cache_path(self, simple_table, db): """Identical sources produce the same cache path.""" s1 = ArrowTableSource(simple_table, tag_columns=["name"], source_id="src") s2 = ArrowTableSource(simple_table, tag_columns=["name"], source_id="src") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path == ps2.cache_path def test_same_name_same_schema_same_cache_path(self, db): @@ -114,8 +114,8 @@ def test_same_name_same_schema_same_cache_path(self, db): t2 = pa.table({"k": ["b"], "v": [2]}) s1 = ArrowTableSource(t1, tag_columns=["k"], source_id="s") s2 = ArrowTableSource(t2, tag_columns=["k"], source_id="s") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path == ps2.cache_path def test_different_name_different_cache_path(self, db): @@ -123,8 +123,8 @@ def test_different_name_different_cache_path(self, db): t1 = pa.table({"k": ["a"], "v": [1]}) s1 = ArrowTableSource(t1, tag_columns=["k"], source_id="src_a") s2 = ArrowTableSource(t1, tag_columns=["k"], source_id="src_b") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path != ps2.cache_path def test_unnamed_different_data_different_cache_path(self, db): @@ -133,8 +133,8 @@ def test_unnamed_different_data_different_cache_path(self, db): t2 = pa.table({"k": ["b"], "v": [2]}) s1 = ArrowTableSource(t1, tag_columns=["k"]) s2 = ArrowTableSource(t2, tag_columns=["k"]) - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) assert ps1.cache_path != ps2.cache_path @@ -143,19 +143,19 @@ def test_unnamed_different_data_different_cache_path(self, db): # --------------------------------------------------------------------------- -class TestPersistentSourceSchema: +class TestCachedSourceSchema: def test_output_schema_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.output_schema() == simple_source.output_schema() def test_output_schema_with_system_tags(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.output_schema( columns={"system_tags": True} ) == simple_source.output_schema(columns={"system_tags": True}) def test_keys_match_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) assert ps.keys() == simple_source.keys() @@ -164,28 +164,28 @@ def test_keys_match_source(self, simple_source, db): # --------------------------------------------------------------------------- -class TestPersistentSourceStreaming: +class TestCachedSourceStreaming: def test_as_table_matches_source(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) ps_table = ps.as_table() src_table = simple_source.as_table() assert ps_table.num_rows == src_table.num_rows 
assert set(ps_table.column_names) == set(src_table.column_names) def test_iter_packets_count(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) packets = list(ps.iter_packets()) assert len(packets) == 3 def test_iter_packets_tags_and_packets(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) for tag, packet in ps.iter_packets(): assert "name" in tag.keys() assert "age" in packet.keys() def test_system_tags_preserved(self, simple_source, db): """System tags flow through the cache correctly.""" - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) table = ps.as_table(columns={"system_tags": True}) sys_tag_cols = [ c for c in table.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX) @@ -206,7 +206,7 @@ def test_system_tags_preserved(self, simple_source, db): def test_source_info_preserved(self, simple_source, db): """Source info columns flow through the cache correctly.""" - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) table = ps.as_table(columns={"source": True}) source_cols = [ c for c in table.column_names if c.startswith(constants.SOURCE_PREFIX) @@ -219,24 +219,40 @@ def test_source_info_preserved(self, simple_source, db): # --------------------------------------------------------------------------- -class TestPersistentSourceCumulative: +class TestCachedSourceCumulative: def test_dedup_on_same_data(self, simple_source, db): """Running twice with the same data produces no duplicates.""" - ps1 = PersistentSource(simple_source, cache_database=db) - ps1.run() - ps2 = PersistentSource(simple_source, cache_database=db) - ps2.run() + ps1 = CachedSource(simple_source, cache_database=db) + ps1.flow() + ps2 = CachedSource(simple_source, cache_database=db) + ps2.flow() table = ps2.as_table() assert table.num_rows == 3 # no duplicates def test_clear_cache_rebuilds(self, simple_source, db): """clear_cache forces a fresh merge from DB on next access.""" - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) t1 = ps.as_table() ps.clear_cache() t2 = ps.as_table() assert t1.num_rows == t2.num_rows + def test_source_modified_time_triggers_rebuild(self, simple_source, db): + """Updating the wrapped source's modified time triggers a cache rebuild.""" + ps = CachedSource(simple_source, cache_database=db) + # First access: build and cache + t1 = ps.as_table() + assert t1.num_rows == 3 + + # Simulate the source being updated (e.g. 
new data loaded) + simple_source._update_modified_time() + + # Next access should detect staleness and rebuild + t2 = ps.as_table() + assert t2.num_rows == 3 + # Verify CachedSource's own modified time was updated past the source's + assert not ps.is_stale + def test_cumulative_across_runs(self, db): """Data from different runs accumulates in the cache.""" # Use a single source_id to make them share the same canonical identity @@ -248,19 +264,19 @@ def test_cumulative_across_runs(self, db): # Different data → different content_hash → different cache_paths # So cumulative within the SAME cache_path requires same content_hash - ps1 = PersistentSource(s1, cache_database=db) - ps1.run() + ps1 = CachedSource(s1, cache_database=db) + ps1.flow() assert ps1.as_table().num_rows == 2 # Same data source: should dedup s1_again = ArrowTableSource(t1, tag_columns=["k"], source_id="shared") - ps1_again = PersistentSource(s1_again, cache_database=db) - ps1_again.run() + ps1_again = CachedSource(s1_again, cache_database=db) + ps1_again.flow() assert ps1_again.as_table().num_rows == 2 # Different source (s2) has different cache_path - ps2 = PersistentSource(s2, cache_database=db) - ps2.run() + ps2 = CachedSource(s2, cache_database=db) + ps2.flow() assert ps2.as_table().num_rows == 3 @@ -269,9 +285,9 @@ def test_cumulative_across_runs(self, db): # --------------------------------------------------------------------------- -class TestPersistentSourceFieldResolution: +class TestCachedSourceFieldResolution: def test_resolve_field_delegates(self, simple_source, db): - ps = PersistentSource(simple_source, cache_database=db) + ps = CachedSource(simple_source, cache_database=db) value = ps.resolve_field("row_0", "age") expected = simple_source.resolve_field("row_0", "age") assert value == expected @@ -286,7 +302,7 @@ def test_resolve_field_with_record_id_column(self, db): source = ArrowTableSource( table, tag_columns=["user_id"], record_id_column="user_id", source_id="test" ) - ps = PersistentSource(source, cache_database=db) + ps = CachedSource(source, cache_database=db) assert ps.resolve_field("user_id=u1", "score") == 100 @@ -295,9 +311,9 @@ def test_resolve_field_with_record_id_column(self, db): # --------------------------------------------------------------------------- -class TestPersistentSourceIntegration: - def test_join_with_persistent_source(self, db): - """PersistentSource can be joined with another stream.""" +class TestCachedSourceIntegration: + def test_join_with_cached_source(self, db): + """CachedSource can be joined with another stream.""" from orcapod.core.operators import Join t1 = pa.table({"id": [1, 2, 3], "val_a": [10, 20, 30]}) @@ -305,8 +321,8 @@ def test_join_with_persistent_source(self, db): s1 = ArrowTableSource(t1, tag_columns=["id"], source_id="a") s2 = ArrowTableSource(t2, tag_columns=["id"], source_id="b") - ps1 = PersistentSource(s1, cache_database=db) - ps2 = PersistentSource(s2, cache_database=db) + ps1 = CachedSource(s1, cache_database=db) + ps2 = CachedSource(s2, cache_database=db) joined = Join()(ps1, ps2) table = joined.as_table() @@ -314,8 +330,8 @@ def test_join_with_persistent_source(self, db): assert "val_a" in table.column_names assert "val_b" in table.column_names - def test_function_pod_with_persistent_source(self, db): - """PersistentSource works as input to a FunctionPod.""" + def test_function_pod_with_cached_source(self, db): + """CachedSource works as input to a FunctionPod.""" from orcapod.core.function_pod import FunctionPod from orcapod.core.packet_function 
import PythonPacketFunction @@ -327,10 +343,34 @@ def double_age(age: int) -> int: table = pa.table({"name": ["Alice", "Bob"], "age": [30, 25]}) source = ArrowTableSource(table, tag_columns=["name"], source_id="test") - ps = PersistentSource(source, cache_database=db) + ps = CachedSource(source, cache_database=db) result = pod(ps) packets = list(result.iter_packets()) assert len(packets) == 2 ages = [p.as_dict()["doubled_age"] for _, p in packets] assert sorted(ages) == [50, 60] + + +class TestCachedConvenienceMethod: + """Test the ``RootSource.cached()`` convenience method.""" + + def test_cached_returns_cached_source(self, simple_source, db): + cached = simple_source.cached(cache_database=db) + assert isinstance(cached, CachedSource) + + def test_cached_with_path_prefix(self, simple_source, db): + cached = simple_source.cached( + cache_database=db, + cache_path_prefix=("my", "prefix"), + ) + assert cached.cache_path[:2] == ("my", "prefix") + + def test_cached_data_matches_source(self, simple_source, db): + cached = simple_source.cached(cache_database=db) + original_table = simple_source.as_table() + cached_table = cached.as_table() + + # Same column names and row count + assert set(original_table.column_names) == set(cached_table.column_names) + assert original_table.num_rows == cached_table.num_rows diff --git a/tests/test_core/test_caching_integration.py b/tests/test_core/test_caching_integration.py index 0b103b60..e1b0ab16 100644 --- a/tests/test_core/test_caching_integration.py +++ b/tests/test_core/test_caching_integration.py @@ -2,7 +2,7 @@ Integration tests: all three pod caching strategies working end-to-end. Covers: -1. PersistentSource — always-on cache scoped to content_hash() +1. CachedSource — always-on cache scoped to content_hash() - DeltaTableSource with canonical source_id (defaults to dir name) - Named sources: same name + same schema = same identity (data-independent) - Unnamed sources: identity determined by table hash (data-dependent) @@ -24,7 +24,7 @@ from orcapod.core.nodes import PersistentFunctionNode, PersistentOperatorNode from orcapod.core.operators import Join from orcapod.core.packet_function import PythonPacketFunction -from orcapod.core.sources import ArrowTableSource, DeltaTableSource, PersistentSource +from orcapod.core.sources import ArrowTableSource, DeltaTableSource, CachedSource from orcapod.databases import InMemoryArrowDatabase from orcapod.types import CacheMode @@ -121,7 +121,7 @@ def pod(): # --------------------------------------------------------------------------- -# 1. PersistentSource — source pod caching +# 1. 
CachedSource — source pod caching # --------------------------------------------------------------------------- @@ -135,39 +135,39 @@ def test_delta_source_id_defaults_to_dir_name(self, clinic_a): def test_different_sources_get_different_cache_paths(self, clinic_a, source_db): patients_path, labs_path = clinic_a - patients = PersistentSource( + patients = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_path, tag_columns=["patient_id"]), cache_database=source_db, ) assert patients.cache_path != labs.cache_path - def test_cache_populates_on_run(self, clinic_a, source_db): + def test_cache_populates_on_flow(self, clinic_a, source_db): patients_path, _ = clinic_a - ps = PersistentSource( + ps = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps.run() + ps.flow() records = ps.get_all_records() assert records is not None assert records.num_rows == 3 def test_dedup_on_rerun(self, clinic_a, source_db): patients_path, _ = clinic_a - ps1 = PersistentSource( + ps1 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps1.run() - ps2 = PersistentSource( + ps1.flow() + ps2 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps2.run() + ps2.flow() assert ps2.get_all_records().num_rows == 3 def test_named_source_same_name_same_schema_same_identity( @@ -176,7 +176,7 @@ def test_named_source_same_name_same_schema_same_identity( """Same dir name + same schema = same content_hash regardless of data.""" patients_path, _ = clinic_a src1 = DeltaTableSource(patients_path, tag_columns=["patient_id"]) - ps1 = PersistentSource(src1, cache_database=source_db) + ps1 = CachedSource(src1, cache_database=source_db) # Overwrite with different data, same schema _write_delta( @@ -192,7 +192,7 @@ def test_named_source_same_name_same_schema_same_identity( mode="overwrite", ) src2 = DeltaTableSource(patients_path, tag_columns=["patient_id"]) - ps2 = PersistentSource(src2, cache_database=source_db) + ps2 = CachedSource(src2, cache_database=source_db) assert src1.source_id == src2.source_id assert ps1.content_hash() == ps2.content_hash() @@ -201,11 +201,11 @@ def test_named_source_same_name_same_schema_same_identity( def test_cumulative_caching_across_data_updates(self, clinic_a, source_db): """New rows from updated data accumulate in the same cache table.""" patients_path, _ = clinic_a - ps1 = PersistentSource( + ps1 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps1.run() + ps1.flow() assert ps1.get_all_records().num_rows == 3 # Update Delta table: add p4 @@ -221,11 +221,11 @@ def test_cumulative_caching_across_data_updates(self, clinic_a, source_db): ), mode="overwrite", ) - ps2 = PersistentSource( + ps2 = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - ps2.run() + ps2.flow() # 3 original + 1 new, existing rows deduped assert ps2.get_all_records().num_rows == 4 @@ -270,11 +270,11 @@ def test_function_node_stores_records( self, clinic_a, source_db, pipeline_db, result_db, pod ): patients_path, labs_path = clinic_a - patients = PersistentSource( + patients = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( 
DeltaTableSource(labs_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -300,11 +300,11 @@ def test_cross_source_sharing_same_pipeline_path( patients_b, labs_b = clinic_b # Pipeline A - pa_src = PersistentSource( + pa_src = CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ) - la_src = PersistentSource( + la_src = CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -316,11 +316,11 @@ def test_cross_source_sharing_same_pipeline_path( ) # Pipeline B - pb_src = PersistentSource( + pb_src = CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ) - lb_src = PersistentSource( + lb_src = CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -344,11 +344,11 @@ def test_cross_source_records_accumulate_in_shared_table( fn_a = PersistentFunctionNode( function_pod=pod, input_stream=Join()( - PersistentSource( + CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ), - PersistentSource( + CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ), @@ -363,11 +363,11 @@ def test_cross_source_records_accumulate_in_shared_table( fn_b = PersistentFunctionNode( function_pod=pod, input_stream=Join()( - PersistentSource( + CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ), - PersistentSource( + CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ), @@ -388,11 +388,11 @@ def test_cross_source_records_accumulate_in_shared_table( class TestOperatorPodCaching: def _make_joined_streams(self, clinic_a, source_db): patients_path, labs_path = clinic_a - patients = PersistentSource( + patients = CachedSource( DeltaTableSource(patients_path, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_path, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -469,19 +469,19 @@ def test_content_hash_scoping_isolates_source_combinations( patients_a, labs_a = clinic_a patients_b, labs_b = clinic_b - pa_src = PersistentSource( + pa_src = CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ) - la_src = PersistentSource( + la_src = CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ) - pb_src = PersistentSource( + pb_src = CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ) - lb_src = PersistentSource( + lb_src = CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -511,17 +511,17 @@ def test_full_pipeline_source_to_function_to_operator( self, clinic_a, clinic_b, source_db, pipeline_db, result_db, operator_db, pod ): """ - Full pipeline: DeltaTableSource → PersistentSource → Join → + Full pipeline: DeltaTableSource → CachedSource → Join → PersistentFunctionNode → PersistentOperatorNode (LOG + REPLAY). 
""" patients_a, labs_a = clinic_a - # Step 1: PersistentSource - patients = PersistentSource( + # Step 1: CachedSource + patients = CachedSource( DeltaTableSource(patients_a, tag_columns=["patient_id"]), cache_database=source_db, ) - labs = PersistentSource( + labs = CachedSource( DeltaTableSource(labs_a, tag_columns=["patient_id"]), cache_database=source_db, ) @@ -563,11 +563,11 @@ def test_full_pipeline_source_to_function_to_operator( fn_node_b = PersistentFunctionNode( function_pod=pod, input_stream=Join()( - PersistentSource( + CachedSource( DeltaTableSource(patients_b, tag_columns=["patient_id"]), cache_database=source_db, ), - PersistentSource( + CachedSource( DeltaTableSource(labs_b, tag_columns=["patient_id"]), cache_database=source_db, ), diff --git a/tests/test_core/test_tracker.py b/tests/test_core/test_tracker.py index 142c54a6..d4d902bc 100644 --- a/tests/test_core/test_tracker.py +++ b/tests/test_core/test_tracker.py @@ -143,6 +143,84 @@ def test_delegates_keys(self): node = SourceNode(stream=stream) assert node.keys() == stream.keys() + def test_delegates_as_table(self): + stream = _make_stream() + node = SourceNode(stream=stream) + node_table = node.as_table() + stream_table = stream.as_table() + assert node_table.equals(stream_table) + + def test_delegates_iter_packets(self): + stream = _make_stream() + node = SourceNode(stream=stream) + node_packets = list(node.iter_packets()) + stream_packets = list(stream.iter_packets()) + assert len(node_packets) == len(stream_packets) + + def test_run_is_noop(self): + stream = _make_stream() + node = SourceNode(stream=stream) + # run() should succeed without side effects + node.run() + # Data is still accessible after run() + assert node.as_table().num_rows == stream.as_table().num_rows + + def test_delegates_data_context_key(self): + stream = _make_stream() + node = SourceNode(stream=stream) + assert node.data_context_key == stream.data_context_key + + def test_delegates_data_context(self): + stream = _make_stream() + node = SourceNode(stream=stream) + assert node.data_context.context_key == stream.data_context_key + + +# --------------------------------------------------------------------------- +# Node context delegation +# --------------------------------------------------------------------------- + + +class TestNodeContextDelegation: + """All node types delegate data_context to their wrapped entity.""" + + def test_source_node_context_matches_stream(self): + stream = _make_stream() + node = SourceNode(stream=stream) + assert node.data_context_key == stream.data_context_key + assert node.data_context.context_key == stream.data_context_key + + def test_function_node_context_matches_pod(self): + stream = _make_stream() + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + node = FunctionNode(function_pod=pod, input_stream=stream) + assert node.data_context_key == pod.data_context_key + assert node.data_context.context_key == pod.data_context_key + + def test_operator_node_context_matches_operator(self): + stream = _make_two_col_stream() + op = SelectTagColumns("id") + node = OperatorNode(operator=op, input_streams=[stream]) + assert node.data_context_key == op.data_context_key + assert node.data_context.context_key == op.data_context_key + + def test_source_node_hash_consistent_with_stream(self): + stream = _make_stream() + node = SourceNode(stream=stream) + # Both should use the same hasher (from the same data context) + assert node.content_hash() == stream.content_hash() + assert 
node.pipeline_hash() == stream.pipeline_hash() + + def test_function_node_hash_uses_pod_context(self): + stream = _make_stream() + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + node = FunctionNode(function_pod=pod, input_stream=stream) + # Node should produce stable hashes without error + assert node.content_hash() is not None + assert node.pipeline_hash() is not None + # --------------------------------------------------------------------------- # BasicTrackerManager diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 38d1f4bb..2e3bc490 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -2,8 +2,8 @@ Tests for the Pipeline class. Verifies that Pipeline (a GraphTracker subclass) correctly wraps all nodes -as persistent variants during compile(): -- Leaf streams → PersistentSourceNode +during compile(): +- Leaf streams → SourceNode - Function pod invocations → PersistentFunctionNode - Operator invocations → PersistentOperatorNode """ @@ -17,12 +17,16 @@ from orcapod.core.executors import PacketFunctionExecutorBase from orcapod.core.function_pod import FunctionPod -from orcapod.core.nodes import PersistentFunctionNode, PersistentOperatorNode +from orcapod.core.nodes import ( + PersistentFunctionNode, + PersistentOperatorNode, + SourceNode, +) from orcapod.core.operators import Join, MapPackets from orcapod.core.packet_function import PythonPacketFunction from orcapod.core.sources import ArrowTableSource from orcapod.databases import InMemoryArrowDatabase -from orcapod.pipeline import PersistentSourceNode, Pipeline +from orcapod.pipeline import Pipeline from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol # --------------------------------------------------------------------------- @@ -70,7 +74,7 @@ def function_db(): # --------------------------------------------------------------------------- -# Tests: compile wraps leaf streams as PersistentSourceNode +# Tests: compile wraps leaf streams as SourceNode # --------------------------------------------------------------------------- @@ -88,37 +92,14 @@ def test_compile_wraps_leaf_streams_as_persistent_source_node(self, pipeline_db) assert len(pipeline.compiled_nodes) > 0 assert pipeline._node_graph is not None - # The node graph should contain PersistentSourceNode instances + # The node graph should contain SourceNode instances source_nodes = [ n for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) ] assert len(source_nodes) == 2 - def test_persistent_source_node_cache_path_prefix(self, pipeline_db): - src_a, _ = _make_two_sources() - pipeline = Pipeline(name="my_pipeline", pipeline_database=pipeline_db) - - with pipeline: - # Use a simple unary operator to trigger a recording - MapPackets({"value": "val"})(src_a, label="mapper") - - # Find the PersistentSourceNode - assert pipeline._node_graph is not None - source_nodes = [ - n - for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) - ] - assert len(source_nodes) == 1 - sn = source_nodes[0] - - # cache_path should start with pipeline name prefix - assert sn.cache_path[:1] == ("my_pipeline",) - assert sn.cache_path[1] == "source" - assert sn.cache_path[2].startswith("node:") - # --------------------------------------------------------------------------- # Tests: compile creates PersistentFunctionNode @@ -360,11 +341,6 @@ def 
test_pipeline_path_prefix_scoping(self, pipeline_db): # Check function node adder = pipeline.adder assert adder.pipeline_path[0] == "scoped" - assert pipeline._node_graph is not None - # Check source nodes - for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - assert n.cache_path[0] == "scoped" # --------------------------------------------------------------------------- @@ -411,14 +387,6 @@ def test_end_to_end_source_join_function(self, pipeline_db): # Run the pipeline pipeline.run() - assert pipeline._node_graph is not None - # Source nodes should have cached data - for n in pipeline._node_graph.nodes(): - if isinstance(n, PersistentSourceNode): - records = n.get_all_records() - assert records is not None - assert records.num_rows == 2 - # Function node should have results fn_records = pipeline.adder.get_all_records() assert fn_records is not None @@ -548,10 +516,10 @@ def test_second_pipeline_from_first_pipeline_node(self, pipeline_db): 220, ] - # pipe_b's source nodes wrap pipe_a.adder as a PersistentSourceNode + # pipe_b's source nodes wrap pipe_a.adder as a SourceNode assert pipe_b._node_graph is not None source_nodes = [ - n for n in pipe_b._node_graph.nodes() if isinstance(n, PersistentSourceNode) + n for n in pipe_b._node_graph.nodes() if isinstance(n, SourceNode) ] assert len(source_nodes) == 1 @@ -1009,7 +977,7 @@ def test_recompile_preserves_existing_node_objects(self, pipeline_db): assert "renamer" in pipeline.compiled_nodes def test_recompile_preserves_existing_source_nodes(self, pipeline_db): - """PersistentSourceNode objects from first compile survive second compile.""" + """SourceNode objects from first compile survive second compile.""" src_a, src_b = _make_two_sources() pipeline = Pipeline(name="incr_src", pipeline_database=pipeline_db) @@ -1021,7 +989,7 @@ def test_recompile_preserves_existing_source_nodes(self, pipeline_db): first_source_nodes = { id(n) for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) } # Extend with another operation @@ -1031,7 +999,7 @@ def test_recompile_preserves_existing_source_nodes(self, pipeline_db): second_source_nodes = { id(n) for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) + if isinstance(n, SourceNode) } # All original source nodes should be preserved (same object ids) @@ -1076,7 +1044,7 @@ class TestCompileDoesNotTriggerExecution: triggering upstream iter_packets / run / as_table materialisation.""" def test_compile_does_not_trigger_source_materialization(self, pipeline_db): - """Operators followed by a function pod: compile should not execute anything.""" + """Compile should not trigger any computation or database writes.""" src_a, src_b = _make_two_sources() pf = PythonPacketFunction(add_values, output_keys="total") pod = FunctionPod(pf) @@ -1086,23 +1054,9 @@ def test_compile_does_not_trigger_source_materialization(self, pipeline_db): joined = Join()(src_a, src_b) pod(joined, label="adder") - # After compile, persistent source nodes should NOT have materialised - # their cache yet (i.e. _cached_stream is still None). 
- assert pipeline._node_graph is not None - source_nodes = [ - n - for n in pipeline._node_graph.nodes() - if isinstance(n, PersistentSourceNode) - ] - for sn in source_nodes: - assert sn._cached_stream is None, ( - "PersistentSourceNode should not have materialised during compile()" - ) - - # The pipeline and function databases should be empty — nothing has - # been written yet because no upstream was triggered. - assert pipeline_db.get_all_records(source_nodes[0].cache_path) is None - assert pipeline_db.get_all_records(source_nodes[1].cache_path) is None + # After compile, the pipeline database should be empty — compile() + # only builds the graph, it doesn't execute any nodes. + assert pipeline.adder.get_all_records() is None # Running the pipeline should still work correctly after lazy compile. pipeline.run() @@ -1235,3 +1189,64 @@ def test_per_node_opts_override_pipeline_opts(self, pipeline_db): # Node opts win: executor should have been created with num_cpus=2 assert pipeline.doubler.executor.opts.get("num_cpus") == 2 + + +class TestSourceNodeNoCaching: + """Verify that SourceNode does not cache — caching is a source-level concern.""" + + def test_source_nodes_do_not_write_to_db(self, pipeline_db): + """Source nodes should not write anything to the pipeline DB.""" + src_a, src_b = _make_two_sources() + pf = PythonPacketFunction(add_values, output_keys="total") + pod = FunctionPod(packet_function=pf) + + pipeline = Pipeline(name="test_pipe", pipeline_database=pipeline_db) + with pipeline: + joined = Join()(src_a, src_b) + pod(joined, label="adder") + + pipeline.run() + + # Function node should have computed results (pipeline works) + records = pipeline.adder.get_all_records() + assert records is not None + assert records.num_rows == 2 + + # Source nodes are plain SourceNode — no caching, no DB writes + source_nodes = [ + n + for n in pipeline._node_graph.nodes() + if isinstance(n, SourceNode) + ] + assert len(source_nodes) == 2 + for sn in source_nodes: + assert not hasattr(sn, "cache_path") + assert not hasattr(sn, "get_all_records") + + def test_pipeline_with_cached_source_input(self, pipeline_db): + """CachedSource as pipeline input: source caching + pipeline execution.""" + from orcapod.core.sources import CachedSource + + src_a, src_b = _make_two_sources() + source_db = InMemoryArrowDatabase() + cached_a = CachedSource(src_a, cache_database=source_db) + cached_b = CachedSource(src_b, cache_database=source_db) + + pf = PythonPacketFunction(add_values, output_keys="total") + pod = FunctionPod(packet_function=pf) + + pipeline = Pipeline(name="cached_src", pipeline_database=pipeline_db) + with pipeline: + joined = Join()(cached_a, cached_b) + pod(joined, label="adder") + + pipeline.run() + + # Function node computed results + records = pipeline.adder.get_all_records() + assert records is not None + assert records.num_rows == 2 + + # Source data was cached in source_db (not pipeline_db) + assert cached_a.get_all_records() is not None + assert cached_b.get_all_records() is not None
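Taken together, the staleness and caching tests above imply the following end-to-end behaviour for ``CachedSource``; a hedged sketch, assuming an in-memory database and the ``flow()`` materialisation call used by the tests (the intermediate ``is_stale`` assertion is implied by the mechanism rather than asserted verbatim in the suite)::

    import pyarrow as pa

    from orcapod.core.sources import ArrowTableSource
    from orcapod.databases import InMemoryArrowDatabase

    db = InMemoryArrowDatabase()
    table = pa.table({"name": ["Alice", "Bob"], "age": [30, 25]})
    source = ArrowTableSource(table, tag_columns=["name"], source_id="people")

    cached = source.cached(cache_database=db)  # shorthand for CachedSource(source, cache_database=db)
    cached.flow()                              # eagerly materialise into the cache table
    assert cached.as_table().num_rows == 2

    # When the wrapped source reports a newer last_modified, is_stale flips on
    # and the next access rebuilds the merged stream (cache + live, row-hash deduped).
    source._update_modified_time()
    assert cached.is_stale
    assert cached.as_table().num_rows == 2
    assert not cached.is_stale                 # rebuild refreshed the cached stream

This keeps ``Pipeline.compile()`` free of caching side effects, consistent with the ``TestSourceNodeNoCaching`` checks above.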