diff --git a/python/reactivedataflow/reactivedataflow/bindings.py b/python/reactivedataflow/reactivedataflow/bindings.py index d6976241..b2f7ba34 100644 --- a/python/reactivedataflow/reactivedataflow/bindings.py +++ b/python/reactivedataflow/reactivedataflow/bindings.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. """reactivedataflow PortMapper class.""" +from functools import cached_property from typing import Any from pydantic import BaseModel, Field @@ -120,42 +121,42 @@ def bindings(self) -> list[Binding]: """Return the bindings.""" return self._bindings - @property + @cached_property def config(self) -> list[Config]: """Return the configuration bindings.""" return [b for b in self.bindings if isinstance(b, Config)] - @property + @cached_property def input(self) -> list[Input]: """Return the input bindings.""" return [b for b in self.bindings if isinstance(b, Input)] - @property + @cached_property def outputs(self) -> list[Output]: """Return the output bindings.""" return [b for b in self._bindings if isinstance(b, Output)] - @property + @cached_property def array_input(self) -> ArrayInput | None: """Return the array input binding.""" return next((p for p in self._bindings if isinstance(p, ArrayInput)), None) - @property + @cached_property def named_inputs(self) -> NamedInputs | None: """Return the named inputs binding.""" return next((p for p in self._bindings if isinstance(p, NamedInputs)), None) - @property + @cached_property def input_names(self) -> set[str]: """Return the names of the inputs.""" return {p.name for p in self.input} - @property + @cached_property def config_names(self) -> set[str]: """Return the names of the config.""" return {p.name for p in self.config} - @property + @cached_property def output_names(self) -> set[str]: """Return the names of the outputs.""" result = {p.name for p in self.outputs} @@ -163,7 +164,7 @@ def output_names(self) -> set[str]: result.add(default_output) return result - @property + @cached_property def required_input_names(self) -> set[str]: """Return the required named inputs.""" result = {p.name for p in self.input if p.required} @@ -171,7 +172,7 @@ def required_input_names(self) -> set[str]: result.update(self.named_inputs.required) return result - @property + @cached_property def required_config_names(self) -> set[str]: """Return the required named inputs.""" return {p.name for p in self.config if p.required} diff --git a/python/reactivedataflow/reactivedataflow/graph_builder.py b/python/reactivedataflow/reactivedataflow/graph_builder.py index 38e7ad1a..5c9b520f 100644 --- a/python/reactivedataflow/reactivedataflow/graph_builder.py +++ b/python/reactivedataflow/reactivedataflow/graph_builder.py @@ -154,17 +154,24 @@ def build_nodes() -> dict[str, Node]: nodes: dict[str, Node] = {} for nid in self._graph.nodes: node = self._graph.nodes[nid] - node_config = node.get("config", {}) or {} - # Check the `input` flag in the nx graph to determine if this is an input node. if node.get("input"): nodes[nid] = InputNode(nid) continue + registration = registry.get(node["verb"]) + node_config = node.get("config", {}) or {} + # Set up an execution node verb = registry.get_verb_function(node["verb"]) - node_final_config = {**config, **node_config} - execution_node = ExecutionNode(nid, verb, node_final_config) + node_global_config = { + key: value + for key, value in config.items() + if key in registration.bindings.config_names + } + node_config = {**node_global_config, **node_config} + + execution_node = ExecutionNode(nid, verb, node_config) nodes[nid] = execution_node return nodes diff --git a/python/reactivedataflow/tests/unit/test_graph_builder.py b/python/reactivedataflow/tests/unit/test_graph_builder.py index dc27f71f..4b04fa35 100644 --- a/python/reactivedataflow/tests/unit/test_graph_builder.py +++ b/python/reactivedataflow/tests/unit/test_graph_builder.py @@ -350,11 +350,18 @@ def add_strict(a: int, b: int) -> int: def constant_strict(value: int) -> int: return value + # Global Config values aren't strictly checked + builder = GraphBuilder() + builder.add_node("c1", "constant_strict", config={"value": 1}) + builder.build(config={"hey": "there"}, registry=registry) + + # Pass in a bad config value to a node builder = GraphBuilder() builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3}) with pytest.raises(NodeConfigNotDefinedError): builder.build(registry=registry) + # Pass in a bad input port value builder = GraphBuilder() builder.add_node("c1", "constant_strict", config={"value": 1}) builder.add_node("c2", "constant_strict", config={"value": 2}) @@ -366,6 +373,7 @@ def constant_strict(value: int) -> int: with pytest.raises(NodeInputNotDefinedError): builder.build(registry=registry) + # Wire in a bad output port builder = GraphBuilder() builder.add_node("c1", "constant_strict", config={"value": 1}) builder.add_node("c2", "constant_strict", config={"value": 2})