Skip to content

Commit

Permalink
fix error in strict-mode config injection (#249)
Browse files Browse the repository at this point in the history
* fix error in strict-mode config injection

* use cached properties in bindings class
  • Loading branch information
darthtrevino committed Jun 29, 2024
1 parent 9bfc808 commit 0d02b61
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
21 changes: 11 additions & 10 deletions python/reactivedataflow/reactivedataflow/bindings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
"""reactivedataflow PortMapper class."""

from functools import cached_property
from typing import Any

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -120,58 +121,58 @@ def bindings(self) -> list[Binding]:
"""Return the bindings."""
return self._bindings

@property
@cached_property
def config(self) -> list[Config]:
"""Return the configuration bindings."""
return [b for b in self.bindings if isinstance(b, Config)]

@property
@cached_property
def input(self) -> list[Input]:
"""Return the input bindings."""
return [b for b in self.bindings if isinstance(b, Input)]

@property
@cached_property
def outputs(self) -> list[Output]:
"""Return the output bindings."""
return [b for b in self._bindings if isinstance(b, Output)]

@property
@cached_property
def array_input(self) -> ArrayInput | None:
"""Return the array input binding."""
return next((p for p in self._bindings if isinstance(p, ArrayInput)), None)

@property
@cached_property
def named_inputs(self) -> NamedInputs | None:
"""Return the named inputs binding."""
return next((p for p in self._bindings if isinstance(p, NamedInputs)), None)

@property
@cached_property
def input_names(self) -> set[str]:
"""Return the names of the inputs."""
return {p.name for p in self.input}

@property
@cached_property
def config_names(self) -> set[str]:
"""Return the names of the config."""
return {p.name for p in self.config}

@property
@cached_property
def output_names(self) -> set[str]:
"""Return the names of the outputs."""
result = {p.name for p in self.outputs}
if self._has_default_output:
result.add(default_output)
return result

@property
@cached_property
def required_input_names(self) -> set[str]:
"""Return the required named inputs."""
result = {p.name for p in self.input if p.required}
if self.named_inputs:
result.update(self.named_inputs.required)
return result

@property
@cached_property
def required_config_names(self) -> set[str]:
"""Return the required named inputs."""
return {p.name for p in self.config if p.required}
15 changes: 11 additions & 4 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,24 @@ def build_nodes() -> dict[str, Node]:
nodes: dict[str, Node] = {}
for nid in self._graph.nodes:
node = self._graph.nodes[nid]
node_config = node.get("config", {}) or {}

# Check the `input` flag in the nx graph to determine if this is an input node.
if node.get("input"):
nodes[nid] = InputNode(nid)
continue

registration = registry.get(node["verb"])
node_config = node.get("config", {}) or {}

# Set up an execution node
verb = registry.get_verb_function(node["verb"])
node_final_config = {**config, **node_config}
execution_node = ExecutionNode(nid, verb, node_final_config)
node_global_config = {
key: value
for key, value in config.items()
if key in registration.bindings.config_names
}
node_config = {**node_global_config, **node_config}

execution_node = ExecutionNode(nid, verb, node_config)
nodes[nid] = execution_node
return nodes

Expand Down
8 changes: 8 additions & 0 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,18 @@ def add_strict(a: int, b: int) -> int:
def constant_strict(value: int) -> int:
return value

# Global Config values aren't strictly checked
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1})
builder.build(config={"hey": "there"}, registry=registry)

# Pass in a bad config value to a node
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3})
with pytest.raises(NodeConfigNotDefinedError):
builder.build(registry=registry)

# Pass in a bad input port value
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1})
builder.add_node("c2", "constant_strict", config={"value": 2})
Expand All @@ -366,6 +373,7 @@ def constant_strict(value: int) -> int:
with pytest.raises(NodeInputNotDefinedError):
builder.build(registry=registry)

# Wire in a bad output port
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1})
builder.add_node("c2", "constant_strict", config={"value": 2})
Expand Down

0 comments on commit 0d02b61

Please sign in to comment.