diff --git a/python/reactivedataflow/pyproject.toml b/python/reactivedataflow/pyproject.toml index cdf89c62..83dfa48c 100644 --- a/python/reactivedataflow/pyproject.toml +++ b/python/reactivedataflow/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "reactivedataflow" -version = "0.1.3" +version = "0.1.4" description = "Reactive Dataflow Graphs" license = "MIT" authors = ["Chris Trevino "] diff --git a/python/reactivedataflow/reactivedataflow/graph_builder.py b/python/reactivedataflow/reactivedataflow/graph_builder.py index 6a0c8acc..7ee62dbe 100644 --- a/python/reactivedataflow/reactivedataflow/graph_builder.py +++ b/python/reactivedataflow/reactivedataflow/graph_builder.py @@ -21,7 +21,7 @@ RequiredNodeInputNotFoundError, ) from .execution_graph import ExecutionGraph -from .model import Graph, Output +from .model import Graph, Output, ValRef from .nodes import ExecutionNode, InputNode, Node from .registry import Registry @@ -166,6 +166,14 @@ def build_nodes() -> dict[str, Node]: registration = registry.get(node["verb"]) node_config = node.get("config", {}) or {} + for key, value in node_config.items(): + if isinstance(value, ValRef): + if value.reference: + node_config[key] = config[value.reference] + else: + node_config[key] = value.value + else: + node_config[key] = value # Set up an execution node verb = registry.get_verb_function(node["verb"]) diff --git a/python/reactivedataflow/reactivedataflow/model.py b/python/reactivedataflow/reactivedataflow/model.py index ae68d59a..1ddc6db9 100644 --- a/python/reactivedataflow/reactivedataflow/model.py +++ b/python/reactivedataflow/reactivedataflow/model.py @@ -14,12 +14,22 @@ class InputNode(BaseModel): id: str = Field(..., description="Node identifier.") +class ValRef(BaseModel): + """A model containing either a value or a reference to a global configuration.""" + + value: Any | None = Field(default=None, description="The value.") + reference: str | None = Field( + default=None, description="The name of the global configuration to reference." + ) + type: str | None = Field(default=None, description="The type of the value.") + + class Node(BaseModel): """Processing Node Model.""" id: str = Field(..., description="Node identifier.") verb: str = Field(..., description="The verb name to use.") - config: dict[str, Any] = Field( + config: dict[str, str | int | float | bool | ValRef] = Field( default_factory=dict, description="Configuration parameters." ) diff --git a/python/reactivedataflow/tests/unit/test_graph_builder.py b/python/reactivedataflow/tests/unit/test_graph_builder.py index 5efbb8ac..6825f6d6 100644 --- a/python/reactivedataflow/tests/unit/test_graph_builder.py +++ b/python/reactivedataflow/tests/unit/test_graph_builder.py @@ -4,7 +4,14 @@ import pytest import reactivex as rx -from reactivedataflow import Config, GraphBuilder, Input, NamedInputs, Registry, verb +from reactivedataflow import ( + Config, + GraphBuilder, + Input, + NamedInputs, + Registry, + verb, +) from reactivedataflow.errors import ( GraphHasCyclesError, InputNotFoundError, @@ -19,7 +26,7 @@ RequiredNodeConfigNotFoundError, RequiredNodeInputNotFoundError, ) -from reactivedataflow.model import Edge, Graph, InputNode, Node, Output +from reactivedataflow.model import Edge, Graph, InputNode, Node, Output, ValRef from .define_math_ops import define_math_ops @@ -307,7 +314,7 @@ def test_graph_builder_from_schema(): ], nodes=[ Node(id="c3", verb="constant", config={"value": 3}), - Node(id="c5", verb="constant", config={"value": 5}), + Node(id="c5", verb="constant", config={"value": ValRef(value=5)}), Node(id="first_add", verb="add"), Node(id="second_add", verb="add"), Node(id="product", verb="multiply"), @@ -340,6 +347,20 @@ def test_graph_builder_from_schema(): assert graph.output_value("result") == 40 +def test_config_reference(): + registry = Registry() + define_math_ops(registry) + + graph = ( + GraphBuilder() + .add_node("c1", "constant", config={"value": ValRef(reference="x")}) + .add_output("c1") + .build(registry=registry, config={"x": 1}) + ) + + assert graph.output_value("c1") == 1 + + def test_strict_mode(): registry = Registry() define_math_ops(registry) @@ -372,7 +393,11 @@ def constant_strict(value: int) -> int: # Pass in a bad config value to a node builder = GraphBuilder() - builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3}) + builder.add_node( + "c1", + "constant_strict", + config={"value": 1, "UNKNOWN": 3}, + ) with pytest.raises(NodeConfigNotDefinedError): builder.build(registry=registry)