23 changes: 23 additions & 0 deletions DESIGN_ISSUES.md
@@ -873,3 +873,26 @@ For derived sources (e.g., `DerivedSource`), the stream may not have a meaningful
`identity_structure()`. Needs an isinstance check or protocol-based dispatch.

---

## `src/orcapod/core/nodes/` — Config and context delegation chain

### T2 — `orcapod_config` not on any protocol; delegation chain needs review
**Status:** open
**Severity:** medium

Nodes (`SourceNode`, `FunctionNode`, `OperatorNode`) now delegate `data_context` to their
wrapped entity via property overrides, ensuring transparent context pass-through. However,
`orcapod_config` is only on `DataContextMixin` (concrete base), not on any protocol
(`StreamProtocol`, `PodProtocol`, etc.).
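
For reference, the delegation pattern as it now appears on the nodes, shown as a minimal self-contained sketch (simplified from `SourceNode`; the toy `DataContext`, `resolve_context`, and stream/node classes are stand-ins for the real `orcapod` types):

```python
class DataContext:
    """Toy stand-in for orcapod.contexts.DataContext."""

    def __init__(self, key: str) -> None:
        self.key = key


def resolve_context(key: str) -> DataContext:
    """Toy stand-in for contexts.resolve_context."""
    return DataContext(key)


class ToyStream:
    """Wrapped entity: owns the context key."""

    data_context_key = "std:v1"


class ToySourceNode:
    """Node: holds no context of its own; forwards to the wrapped stream."""

    def __init__(self, stream: ToyStream) -> None:
        self.stream = stream

    @property
    def data_context_key(self) -> str:
        return self.stream.data_context_key

    @property
    def data_context(self) -> DataContext:
        return resolve_context(self.stream.data_context_key)


node = ToySourceNode(ToyStream())
assert node.data_context_key == "std:v1"  # context resolved via the wrapped stream
```

There is no analogous property for `orcapod_config`, which is the gap this issue tracks.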

Open questions:
1. Should `orcapod_config` be added to a protocol (e.g. `TraceableProtocol`)?
Adding it couples protocol consumers to the `Config` type. Leaving it off means
nodes can't transparently delegate config the same way they delegate context.
2. Should `Pipeline` optionally hold `data_context` and/or `config` to allow
pipeline-level overrides that propagate to all nodes during `compile()`?
3. The current chain is: Pipeline (no context) → Node (delegates to wrapped entity) →
wrapped entity (owns context). Should there be a way for Pipeline to inject a
context override? One possible shape is sketched below.
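
One possible shape for such an override, purely as an illustration (the `compile()` behavior and the pipeline-level `data_context_key` field are assumptions about a potential design, not the current API):

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    label: str
    # None means "delegate to the wrapped entity", as nodes do today.
    context_override: str | None = None

    @property
    def data_context_key(self) -> str:
        return self.context_override or "wrapped-entity-context"


@dataclass
class ToyPipeline:
    nodes: list[ToyNode] = field(default_factory=list)
    data_context_key: str | None = None  # hypothetical pipeline-level override

    def compile(self) -> list[ToyNode]:
        for node in self.nodes:
            if self.data_context_key is not None:
                # Hypothetical hook: push the pipeline-level context down to each node.
                node.context_override = self.data_context_key
        return self.nodes


pipeline = ToyPipeline(nodes=[ToyNode("join_data")], data_context_key="pipeline-ctx")
assert pipeline.compile()[0].data_context_key == "pipeline-ctx"
```

Whether a hook like this belongs on the node protocol at all, or whether context should stay owned exclusively by the wrapped entities, is the open question.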

---
17 changes: 6 additions & 11 deletions demo_pipeline.py
@@ -23,7 +23,8 @@
from orcapod.core.packet_function import PythonPacketFunction
from orcapod.core.sources import ArrowTableSource
from orcapod.databases import DeltaTableDatabase, InMemoryArrowDatabase
from orcapod.pipeline import Pipeline, PersistentSourceNode
from orcapod.core.nodes import SourceNode
from orcapod.pipeline import Pipeline


# ---------------------------------------------------------------------------
@@ -96,10 +97,10 @@ def categorize(risk: float) -> str:
for name, node in pipeline.compiled_nodes.items():
print(f" {name}: {type(node).__name__}")

print("\n Source nodes (PersistentSourceNode):")
print("\n Source nodes (SourceNode):")
for n in pipeline._node_graph.nodes():
if isinstance(n, PersistentSourceNode):
print(f" cache_path: {n.cache_path}")
if isinstance(n, SourceNode):
print(f" label: {n.label}, stream: {type(n.stream).__name__}")

# --- Access nodes by label ---
print(f"\n pipeline.join_data -> {type(pipeline.join_data).__name__}")
@@ -127,12 +128,6 @@ def categorize(risk: float) -> str:
print(f" {cat_table.to_pandas()[['patient_id', 'category']].to_string(index=False)}")

# --- Show what's in the database ---
print(f"\n Source records in DB:")
for n in pipeline._node_graph.nodes():
if isinstance(n, PersistentSourceNode):
records = n.get_all_records()
print(f" {n.cache_path[-1]}: {records.num_rows} rows")

fn_records = pipeline.compute_risk.get_all_records()
print(f" Function records (compute_risk): {fn_records.num_rows} rows")

@@ -320,7 +315,7 @@ def categorize(risk: float) -> str:
print("=" * 70)
print("""
Pipeline wraps ALL nodes as persistent variants automatically:
- Leaf streams -> PersistentSourceNode (DB-backed cache)
- Leaf streams -> SourceNode (graph vertex wrapper, no caching)
- Operator calls -> PersistentOperatorNode (DB-backed cache)
- Function pod calls -> PersistentFunctionNode (DB-backed cache)

17 changes: 9 additions & 8 deletions src/orcapod/core/nodes/function_node.py
@@ -69,7 +69,6 @@ def __init__(
input_stream: StreamProtocol,
tracker_manager: TrackerManagerProtocol | None = None,
label: str | None = None,
data_context: str | contexts.DataContext | None = None,
config: Config | None = None,
):
if tracker_manager is None:
@@ -79,11 +78,7 @@

# FunctionPod used for the `producer` property and pipeline identity
self._function_pod = function_pod
super().__init__(
label=label,
data_context=data_context,
config=config,
)
super().__init__(label=label, config=config)

# validate the input stream
_, incoming_packet_types = input_stream.output_schema()
@@ -118,6 +113,14 @@ def __init__(
def producer(self) -> FunctionPodProtocol:
return self._function_pod

@property
def data_context(self) -> contexts.DataContext:
return contexts.resolve_context(self._function_pod.data_context_key)

@property
def data_context_key(self) -> str:
return self._function_pod.data_context_key

@property
def executor(self) -> PacketFunctionExecutorProtocol | None:
"""The executor set on the underlying packet function."""
@@ -432,15 +435,13 @@ def __init__(
pipeline_path_prefix: tuple[str, ...] = (),
tracker_manager: TrackerManagerProtocol | None = None,
label: str | None = None,
data_context: str | contexts.DataContext | None = None,
config: Config | None = None,
):
super().__init__(
function_pod=function_pod,
input_stream=input_stream,
tracker_manager=tracker_manager,
label=label,
data_context=data_context,
config=config,
)

17 changes: 9 additions & 8 deletions src/orcapod/core/nodes/operator_node.py
@@ -50,7 +50,6 @@ def __init__(
input_streams: tuple[StreamProtocol, ...] | list[StreamProtocol],
tracker_manager: TrackerManagerProtocol | None = None,
label: str | None = None,
data_context: str | contexts.DataContext | None = None,
config: Config | None = None,
):
if tracker_manager is None:
@@ -60,11 +59,7 @@
self._operator = operator
self._input_streams = tuple(input_streams)

super().__init__(
label=label,
data_context=data_context,
config=config,
)
super().__init__(label=label, config=config)

# Validate inputs eagerly
self._operator.validate_inputs(*self._input_streams)
@@ -92,6 +87,14 @@ def pipeline_identity_structure(self) -> Any:
def producer(self) -> OperatorPodProtocol:
return self._operator

@property
def data_context(self) -> contexts.DataContext:
return contexts.resolve_context(self._operator.data_context_key)

@property
def data_context_key(self) -> str:
return self._operator.data_context_key

@property
def upstreams(self) -> tuple[StreamProtocol, ...]:
return self._input_streams
@@ -222,15 +225,13 @@ def __init__(
pipeline_path_prefix: tuple[str, ...] = (),
tracker_manager: TrackerManagerProtocol | None = None,
label: str | None = None,
data_context: str | contexts.DataContext | None = None,
config: Config | None = None,
):
super().__init__(
operator=operator,
input_streams=input_streams,
tracker_manager=tracker_manager,
label=label,
data_context=data_context,
config=config,
)

15 changes: 9 additions & 6 deletions src/orcapod/core/nodes/source_node.py
@@ -25,16 +25,19 @@ def __init__(
self,
stream: cp.StreamProtocol,
label: str | None = None,
data_context: str | contexts.DataContext | None = None,
config: Config | None = None,
):
super().__init__(
label=label,
data_context=data_context,
config=config,
)
super().__init__(label=label, config=config)
self.stream = stream

@property
def data_context(self) -> contexts.DataContext:
return contexts.resolve_context(self.stream.data_context_key)

@property
def data_context_key(self) -> str:
return self.stream.data_context_key

def computed_label(self) -> str | None:
return self.stream.label

4 changes: 2 additions & 2 deletions src/orcapod/core/sources/__init__.py
@@ -1,24 +1,24 @@
from .base import RootSource
from .arrow_table_source import ArrowTableSource
from .cached_source import CachedSource
from .csv_source import CSVSource
from .data_frame_source import DataFrameSource
from .delta_table_source import DeltaTableSource
from .derived_source import DerivedSource
from .dict_source import DictSource
from .list_source import ListSource
from .persistent_source import PersistentSource
from .source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry

__all__ = [
"RootSource",
"ArrowTableSource",
"CachedSource",
"CSVSource",
"DataFrameSource",
"DeltaTableSource",
"DerivedSource",
"DictSource",
"ListSource",
"PersistentSource",
"SourceRegistry",
"GLOBAL_SOURCE_REGISTRY",
]
35 changes: 34 additions & 1 deletion src/orcapod/core/sources/base.py
@@ -1,11 +1,15 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

from orcapod.core.streams.base import StreamBase
from orcapod.errors import FieldNotResolvableError
from orcapod.protocols.core_protocols import StreamProtocol

if TYPE_CHECKING:
from orcapod.core.sources.cached_source import CachedSource
from orcapod.protocols.database_protocols import ArrowDatabaseProtocol


class RootSource(StreamBase):
"""
Expand Down Expand Up @@ -137,3 +141,32 @@ def producer(self) -> None:
def upstreams(self) -> tuple[StreamProtocol, ...]:
"""Sources have no upstream dependencies."""
return ()

# -------------------------------------------------------------------------
# Convenience — caching
# -------------------------------------------------------------------------

def cached(
self,
cache_database: ArrowDatabaseProtocol,
cache_path_prefix: tuple[str, ...] = (),
**kwargs: Any,
) -> CachedSource:
"""Return a ``CachedSource`` wrapping this source.
Comment on lines +149 to +155 (Copilot AI, Mar 12, 2026):

RootSource.cached() is annotated to return "RootSource" and accepts cache_database: Any, but it always constructs and returns a CachedSource and requires an ArrowDatabaseProtocol. Tightening the return type (to CachedSource or RootSource & CachedSource union) and parameter type will improve type-safety for callers and align with the actual behavior.

Args:
cache_database: Database to store cached records in.
cache_path_prefix: Path prefix for the cache table.
**kwargs: Additional keyword arguments passed to ``CachedSource``.

Returns:
A ``CachedSource`` that caches this source's output.
"""
from orcapod.core.sources.cached_source import CachedSource

return CachedSource(
source=self,
cache_database=cache_database,
cache_path_prefix=cache_path_prefix,
**kwargs,
)
src/orcapod/core/sources/{persistent_source.py → cached_source.py}
@@ -21,25 +21,29 @@
logger = logging.getLogger(__name__)


class PersistentSource(RootSource):
"""
DB-backed wrapper around a RootSource that caches every packet.
class CachedSource(RootSource):
"""DB-backed wrapper around a ``RootSource`` that caches every packet.

Implements StreamProtocol transparently so downstream consumers
Implements ``StreamProtocol`` transparently so downstream consumers
are unaware of caching. Cache table is scoped to the source's
``content_hash()`` — each unique source gets its own table.

Behavior
--------
- Cache is **always on**.
- On first access, live source data is stored in the cache table
(deduped by per-row content hash).
- Returns the union of all cached data (cumulative across runs).

Semantic guarantee
------------------
The cache is a correct cumulative record. The union of cache + live
packets is the full set of data ever available from that source.
Behavior:
- Cache is **always on**.
- On first access, live source data is stored in the cache table
(deduped by per-row content hash).
- Returns the union of all cached data (cumulative across runs).

Semantic guarantee:
The cache is a correct cumulative record. The union of cache + live
packets is the full set of data ever available from that source.

Example::

source = ArrowTableSource(table, tag_columns=["id"])
cached = CachedSource(source, cache_database=db)
# or equivalently:
cached = source.cached(cache_database=db)
"""

HASH_COLUMN_NAME = "_record_hash"
@@ -153,8 +157,23 @@ def _build_merged_stream(self) -> ArrowTableStream:
tag_keys = self._source.keys()[0]
return ArrowTableStream(all_records, tag_columns=tag_keys)

@property
def is_stale(self) -> bool:
"""True if the wrapped source has been modified since the last build.

Overrides ``StreamBase.is_stale`` because CachedSource is a RootSource
(no upstreams/producer) yet still depends on the wrapped ``_source``.
"""
own_time = self.last_modified
if own_time is None:
return True
src_time = self._source.last_modified
return src_time is None or src_time > own_time

def _ensure_stream(self) -> None:
"""Build the merged stream on first access."""
"""Build the merged stream on first access, or rebuild if source is stale."""
if self._cached_stream is not None and self.is_stale:
self._cached_stream = None
if self._cached_stream is None:
self._cached_stream = self._build_merged_stream()
self._update_modified_time()
@@ -163,10 +182,6 @@ def clear_cache(self) -> None:
"""Discard in-memory cached stream (forces rebuild on next access)."""
self._cached_stream = None

def run(self) -> None:
"""Eagerly populate the cache with live source data."""
self._ensure_stream()

def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:
self._ensure_stream()
assert self._cached_stream is not None
2 changes: 0 additions & 2 deletions src/orcapod/pipeline/__init__.py
@@ -1,9 +1,7 @@
from .graph import Pipeline
from .nodes import PersistentSourceNode
from .orchestrator import AsyncPipelineOrchestrator

__all__ = [
"AsyncPipelineOrchestrator",
"Pipeline",
"PersistentSourceNode",
]