Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for reference-able node-configuration #253

Merged
merged 4 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/reactivedataflow/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reactivedataflow"
version = "0.1.3"
version = "0.1.4"
description = "Reactive Dataflow Graphs"
license = "MIT"
authors = ["Chris Trevino <chtrevin@microsoft.com>"]
Expand Down
10 changes: 9 additions & 1 deletion python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
RequiredNodeInputNotFoundError,
)
from .execution_graph import ExecutionGraph
from .model import Graph, Output
from .model import Graph, Output, ValRef
from .nodes import ExecutionNode, InputNode, Node
from .registry import Registry

Expand Down Expand Up @@ -166,6 +166,14 @@ def build_nodes() -> dict[str, Node]:

registration = registry.get(node["verb"])
node_config = node.get("config", {}) or {}
for key, value in node_config.items():
if isinstance(value, ValRef):
if value.reference:
node_config[key] = config[value.reference]
else:
node_config[key] = value.value
else:
node_config[key] = value

# Set up an execution node
verb = registry.get_verb_function(node["verb"])
Expand Down
12 changes: 11 additions & 1 deletion python/reactivedataflow/reactivedataflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,22 @@ class InputNode(BaseModel):
id: str = Field(..., description="Node identifier.")


class ValRef(BaseModel):
"""A model containing either a value or a reference to a global configuration."""

value: Any | None = Field(default=None, description="The value.")
reference: str | None = Field(
default=None, description="The name of the global configuration to reference."
)
type: str | None = Field(default=None, description="The type of the value.")


class Node(BaseModel):
"""Processing Node Model."""

id: str = Field(..., description="Node identifier.")
verb: str = Field(..., description="The verb name to use.")
config: dict[str, Any] = Field(
config: dict[str, str | int | float | bool | ValRef] = Field(
default_factory=dict, description="Configuration parameters."
)

Expand Down
33 changes: 29 additions & 4 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import pytest
import reactivex as rx

from reactivedataflow import Config, GraphBuilder, Input, NamedInputs, Registry, verb
from reactivedataflow import (
Config,
GraphBuilder,
Input,
NamedInputs,
Registry,
verb,
)
from reactivedataflow.errors import (
GraphHasCyclesError,
InputNotFoundError,
Expand All @@ -19,7 +26,7 @@
RequiredNodeConfigNotFoundError,
RequiredNodeInputNotFoundError,
)
from reactivedataflow.model import Edge, Graph, InputNode, Node, Output
from reactivedataflow.model import Edge, Graph, InputNode, Node, Output, ValRef

from .define_math_ops import define_math_ops

Expand Down Expand Up @@ -307,7 +314,7 @@ def test_graph_builder_from_schema():
],
nodes=[
Node(id="c3", verb="constant", config={"value": 3}),
Node(id="c5", verb="constant", config={"value": 5}),
Node(id="c5", verb="constant", config={"value": ValRef(value=5)}),
Node(id="first_add", verb="add"),
Node(id="second_add", verb="add"),
Node(id="product", verb="multiply"),
Expand Down Expand Up @@ -340,6 +347,20 @@ def test_graph_builder_from_schema():
assert graph.output_value("result") == 40


def test_config_reference():
registry = Registry()
define_math_ops(registry)

graph = (
GraphBuilder()
.add_node("c1", "constant", config={"value": ValRef(reference="x")})
.add_output("c1")
.build(registry=registry, config={"x": 1})
)

assert graph.output_value("c1") == 1


def test_strict_mode():
registry = Registry()
define_math_ops(registry)
Expand Down Expand Up @@ -372,7 +393,11 @@ def constant_strict(value: int) -> int:

# Pass in a bad config value to a node
builder = GraphBuilder()
builder.add_node("c1", "constant_strict", config={"value": 1, "UNKNOWN": 3})
builder.add_node(
"c1",
"constant_strict",
config={"value": 1, "UNKNOWN": 3},
)
with pytest.raises(NodeConfigNotDefinedError):
builder.build(registry=registry)

Expand Down
Loading