From 367e77e29c3016dea91b0ac93da23a348b3ecf36 Mon Sep 17 00:00:00 2001 From: Chris Trevino Date: Fri, 28 Jun 2024 22:40:03 -0700 Subject: [PATCH] fix error in strict-mode config injection --- .../reactivedataflow/graph_builder.py | 15 +++++++++++---- .../tests/unit/test_graph_builder.py | 8 ++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/python/reactivedataflow/reactivedataflow/graph_builder.py b/python/reactivedataflow/reactivedataflow/graph_builder.py index 38e7ad1a..5c9b520f 100644 --- a/python/reactivedataflow/reactivedataflow/graph_builder.py +++ b/python/reactivedataflow/reactivedataflow/graph_builder.py @@ -154,17 +154,24 @@ def build_nodes() -> dict[str, Node]: nodes: dict[str, Node] = {} for nid in self._graph.nodes: node = self._graph.nodes[nid] - node_config = node.get("config", {}) or {} - # Check the `input` flag in the nx graph to determine if this is an input node. if node.get("input"): nodes[nid] = InputNode(nid) continue + registration = registry.get(node["verb"]) + node_config = node.get("config", {}) or {} + # Set up an execution node verb = registry.get_verb_function(node["verb"]) - node_final_config = {**config, **node_config} - execution_node = ExecutionNode(nid, verb, node_final_config) + node_global_config = { + key: value + for key, value in config.items() + if key in registration.bindings.config_names + } + node_config = {**node_global_config, **node_config} + + execution_node = ExecutionNode(nid, verb, node_config) nodes[nid] = execution_node return nodes diff --git a/python/reactivedataflow/tests/unit/test_graph_builder.py b/python/reactivedataflow/tests/unit/test_graph_builder.py index dc27f71f..4b04fa35 100644 --- a/python/reactivedataflow/tests/unit/test_graph_builder.py +++ b/python/reactivedataflow/tests/unit/test_graph_builder.py @@ -350,11 +350,18 @@ def add_strict(a: int, b: int) -> int: def constant_strict(value: int) -> int: return value + # Global Config values aren't strictly checked + builder = GraphBuilder() + builder.add_node("c1", "constant_strict", config={"value": 1}) + builder.build(config={"hey": "there"}, registry=registry) + + # Pass in a bad config value to a node builder = GraphBuilder() builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3}) with pytest.raises(NodeConfigNotDefinedError): builder.build(registry=registry) + # Pass in a bad input port value builder = GraphBuilder() builder.add_node("c1", "constant_strict", config={"value": 1}) builder.add_node("c2", "constant_strict", config={"value": 2}) @@ -366,6 +373,7 @@ def constant_strict(value: int) -> int: with pytest.raises(NodeInputNotDefinedError): builder.build(registry=registry) + # Wire in a bad output port builder = GraphBuilder() builder.add_node("c1", "constant_strict", config={"value": 1}) builder.add_node("c2", "constant_strict", config={"value": 2})