Skip to content

Commit

Permalink
fix error in strict-mode config injection
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jun 29, 2024
1 parent 9bfc808 commit 367e77e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,24 @@ def build_nodes() -> dict[str, Node]:
nodes: dict[str, Node] = {}
for nid in self._graph.nodes:
node = self._graph.nodes[nid]
node_config = node.get("config", {}) or {}

# Check the `input` flag in the nx graph to determine if this is an input node.
if node.get("input"):
nodes[nid] = InputNode(nid)
continue

registration = registry.get(node["verb"])
node_config = node.get("config", {}) or {}

# Set up an execution node
verb = registry.get_verb_function(node["verb"])
node_final_config = {**config, **node_config}
execution_node = ExecutionNode(nid, verb, node_final_config)
node_global_config = {
key: value
for key, value in config.items()
if key in registration.bindings.config_names
}
node_config = {**node_global_config, **node_config}

execution_node = ExecutionNode(nid, verb, node_config)
nodes[nid] = execution_node
return nodes

Expand Down
8 changes: 8 additions & 0 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,18 @@ def add_strict(a: int, b: int) -> int:
def constant_strict(value: int) -> int:
return value

# Global Config values aren't strictly checked
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1})
builder.build(config={"hey": "there"}, registry=registry)

# Pass in a bad config value to a node
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3})
with pytest.raises(NodeConfigNotDefinedError):
builder.build(registry=registry)

# Pass in a bad input port value
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1})
builder.add_node("c2", "constant_strict", config={"value": 2})
Expand All @@ -366,6 +373,7 @@ def constant_strict(value: int) -> int:
with pytest.raises(NodeInputNotDefinedError):
builder.build(registry=registry)

# Wire in a bad output port
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1})
builder.add_node("c2", "constant_strict", config={"value": 2})
Expand Down

0 comments on commit 367e77e

Please sign in to comment.