Skip to content

Commit

Permalink
Allow for reference-able node-configuration (#253)
Browse files Browse the repository at this point in the history
* reactivedataflow 0.1.3

* Allow primitives in config values.

* cut version
  • Loading branch information
darthtrevino committed Jun 30, 2024
1 parent a01d494 commit f3e1813
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/reactivedataflow/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reactivedataflow"
version = "0.1.3"
version = "0.1.4"
description = "Reactive Dataflow Graphs"
license = "MIT"
authors = ["Chris Trevino <chtrevin@microsoft.com>"]
Expand Down
10 changes: 9 additions & 1 deletion python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
RequiredNodeInputNotFoundError,
)
from .execution_graph import ExecutionGraph
from .model import Graph, Output
from .model import Graph, Output, ValRef
from .nodes import ExecutionNode, InputNode, Node
from .registry import Registry

Expand Down Expand Up @@ -166,6 +166,14 @@ def build_nodes() -> dict[str, Node]:

registration = registry.get(node["verb"])
node_config = node.get("config", {}) or {}
for key, value in node_config.items():
if isinstance(value, ValRef):
if value.reference:
node_config[key] = config[value.reference]
else:
node_config[key] = value.value
else:
node_config[key] = value

# Set up an execution node
verb = registry.get_verb_function(node["verb"])
Expand Down
12 changes: 11 additions & 1 deletion python/reactivedataflow/reactivedataflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,22 @@ class InputNode(BaseModel):
id: str = Field(..., description="Node identifier.")


class ValRef(BaseModel):
"""A model containing either a value or a reference to a global configuration."""

value: Any | None = Field(default=None, description="The value.")
reference: str | None = Field(
default=None, description="The name of the global configuration to reference."
)
type: str | None = Field(default=None, description="The type of the value.")


class Node(BaseModel):
"""Processing Node Model."""

id: str = Field(..., description="Node identifier.")
verb: str = Field(..., description="The verb name to use.")
config: dict[str, Any] = Field(
config: dict[str, str | int | float | bool | ValRef] = Field(
default_factory=dict, description="Configuration parameters."
)

Expand Down
33 changes: 29 additions & 4 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import pytest
import reactivex as rx

from reactivedataflow import Config, GraphBuilder, Input, NamedInputs, Registry, verb
from reactivedataflow import (
Config,
GraphBuilder,
Input,
NamedInputs,
Registry,
verb,
)
from reactivedataflow.errors import (
GraphHasCyclesError,
InputNotFoundError,
Expand All @@ -19,7 +26,7 @@
RequiredNodeConfigNotFoundError,
RequiredNodeInputNotFoundError,
)
from reactivedataflow.model import Edge, Graph, InputNode, Node, Output
from reactivedataflow.model import Edge, Graph, InputNode, Node, Output, ValRef

from .define_math_ops import define_math_ops

Expand Down Expand Up @@ -307,7 +314,7 @@ def test_graph_builder_from_schema():
],
nodes=[
Node(id="c3", verb="constant", config={"value": 3}),
Node(id="c5", verb="constant", config={"value": 5}),
Node(id="c5", verb="constant", config={"value": ValRef(value=5)}),
Node(id="first_add", verb="add"),
Node(id="second_add", verb="add"),
Node(id="product", verb="multiply"),
Expand Down Expand Up @@ -340,6 +347,20 @@ def test_graph_builder_from_schema():
assert graph.output_value("result") == 40


def test_config_reference():
registry = Registry()
define_math_ops(registry)

graph = (
GraphBuilder()
.add_node("c1", "constant", config={"value": ValRef(reference="x")})
.add_output("c1")
.build(registry=registry, config={"x": 1})
)

assert graph.output_value("c1") == 1


def test_strict_mode():
registry = Registry()
define_math_ops(registry)
Expand Down Expand Up @@ -372,7 +393,11 @@ def constant_strict(value: int) -> int:

# Pass in a bad config value to a node
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3})
builder.add_node(
"c1",
"constant_strict",
config={"value": 1, "UNKNOWN": 3},
)
with pytest.raises(NodeConfigNotDefinedError):
builder.build(registry=registry)

Expand Down

0 comments on commit f3e1813

Please sign in to comment.