Skip to content

Commit

Permalink
implement 'config_providers' graph construction option
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jul 8, 2024
1 parent 93665e4 commit 8ead273
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 8 deletions.
8 changes: 4 additions & 4 deletions python/reactivedataflow/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reactivedataflow"
version = "0.1.9"
version = "0.1.10"
description = "Reactive Dataflow Graphs"
license = "MIT"
authors = ["Chris Trevino <chtrevin@microsoft.com>"]
Expand Down Expand Up @@ -42,9 +42,9 @@ format = ['_sort_imports', '_format_code']
test = "pytest tests"

_test_with_coverage = 'coverage run --source=reactivedataflow -m pytest tests/unit'
_coverage_report = 'coverage report --fail-under=100 --show-missing --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py"'
_generate_coverage_xml = 'coverage xml --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py"'
_generate_coverage_html = 'coverage html --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py"'
_coverage_report = 'coverage report --fail-under=100 --show-missing --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py,reactivedataflow/config_provider.py"'
_generate_coverage_xml = 'coverage xml --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py,reactivedataflow/config_provider.py"'
_generate_coverage_html = 'coverage html --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py,reactivedataflow/config_provider.py"'
test_coverage = [
'_test_with_coverage',
'_generate_coverage_xml',
Expand Down
18 changes: 18 additions & 0 deletions python/reactivedataflow/reactivedataflow/config_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
"""reactivedataflow ConfigProvider Protocol."""

from typing import Generic, Protocol, TypeVar

T = TypeVar("T", covariant=True)


class ConfigProvider(Protocol, Generic[T]):
"""A protocol for providing configuration values.
ConfigProviders are evaluated lazily, as they have no way of triggering an ExecutionNode update.
When an ExecutionNode receives an update signal, it will invoke the ConfigProvider.get method to get the latest value of the configuration to assemble a VerbInput.
"""

def get(self) -> T:
"""Get the configuration value for the given name."""
...
10 changes: 10 additions & 0 deletions python/reactivedataflow/reactivedataflow/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
"""reactivedataflow Error Types."""


class ConfigReferenceNotFoundError(ValueError):
"""An exception for when a configuration reference is not found in the global configuration."""

def __init__(self, reference: str):
"""Initialize the ConfigReferenceNotFoundError."""
super().__init__(
f"Configuration reference '{reference}' not found in global configuration."
)


class NodeInputNotDefinedError(ValueError):
"""An exception for input not defined."""

Expand Down
31 changes: 28 additions & 3 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import networkx as nx
import reactivex as rx

from .config_provider import ConfigProvider
from .constants import default_output
from .errors import (
ConfigReferenceNotFoundError,
GraphHasCyclesError,
InputNotFoundError,
NodeAlreadyDefinedError,
Expand Down Expand Up @@ -142,18 +144,21 @@ def build(
self,
inputs: dict[str, rx.Observable[Any]] | None = None,
config: dict[str, Any] | None = None,
config_providers: dict[str, ConfigProvider[Any]] | None = None,
registry: Registry | None = None,
) -> ExecutionGraph:
"""Build the graph.
Args:
inputs: The inputs to the graph.
config: The global configuration for the graph.
config_providers: Configuration providers, dict[str, ConfigProvider] (see the ConfigProvider protocol).
registry: The registry to use for verb lookup.
"""
inputs = inputs or {}
registry = registry or Registry.get_instance()
inputs = inputs or {}
config = config or {}
config_providers = config_providers or {}

def build_nodes() -> dict[str, Node]:
nodes: dict[str, Node] = {}
Expand All @@ -166,10 +171,19 @@ def build_nodes() -> dict[str, Node]:

registration = registry.get(node["verb"])
node_config = node.get("config", {}) or {}
node_config_providers: dict[str, ConfigProvider[Any]] = {}

for key, value in node_config.items():
if isinstance(value, ValRef):
if value.reference:
node_config[key] = config[value.reference]
if value.reference in config:
node_config[key] = config[value.reference]
elif value.reference in config_providers:
node_config_providers[key] = config_providers[
value.reference
]
else:
raise ConfigReferenceNotFoundError(key)
else:
node_config[key] = value.value
else:
Expand All @@ -182,9 +196,20 @@ def build_nodes() -> dict[str, Node]:
for key, value in config.items()
if key in registration.ports.config_names
}
node_global_config_providers = {
key: value
for key, value in config_providers.items()
if key in registration.ports.config_names
}
node_config = {**node_global_config, **node_config}
node_config_providers = {
**node_global_config_providers,
**node_config_providers,
}

execution_node = ExecutionNode(nid, verb, node_config)
execution_node = ExecutionNode(
nid, verb, node_config, node_config_providers
)
nodes[nid] = execution_node
return nodes

Expand Down
13 changes: 12 additions & 1 deletion python/reactivedataflow/reactivedataflow/nodes/execution_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import reactivex as rx

from reactivedataflow.config_provider import ConfigProvider
from reactivedataflow.constants import default_output

from .io import VerbInput
Expand All @@ -22,6 +23,7 @@ class ExecutionNode(Node):
_id: str
_fn: VerbFunction
_config: dict[str, Any]
_config_providers: dict[str, ConfigProvider[Any]]

# Input Observables
_named_inputs: dict[str, rx.Observable]
Expand All @@ -39,17 +41,20 @@ def __init__(
nid: str,
fn: VerbFunction,
config: dict[str, Any] | None = None,
config_providers: dict[str, ConfigProvider[Any]] | None = None,
):
"""Initialize the ExecutionNode.
Args:
nid (str): The node identifier.
fn (VerbFunction): The execution logic for the function. The input is a dictionary of input names to their latest values.
config (dict[str, Any], optional): The configuration for the node. Defaults to None.
config_providers (dict[str, ConfigProvider[Any]], optional): The configuration providers for the node. Defaults to None.
"""
self._id = nid
self._fn = fn
self._config = config or {}
self._config_providers = config_providers or {}
# Inputs
self._named_inputs = {}
self._named_input_values = {}
Expand Down Expand Up @@ -157,8 +162,14 @@ def on_array_value(value: Any, i: int) -> None:

def _schedule_recompute(self, cause: str | None) -> None:
_log.debug(f"recompute scheduled for {self._id} due to {cause or 'unknown'}")

# Copy the config; wire in the config providers
config = self._config.copy()
for name, provider in self._config_providers.items():
config[name] = provider.get()

inputs = VerbInput(
config=self._config.copy(),
config=config,
named_inputs=self._named_input_values.copy(),
array_inputs=self._array_input_values.copy(),
previous_output={name: obs.value for name, obs in self._outputs.items()},
Expand Down
2 changes: 2 additions & 0 deletions python/reactivedataflow/reactivedataflow/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
"""reactivedataflow Types."""

from .config_provider import ConfigProvider
from .decorators import AnyFn, Decorator
from .nodes import EmitCondition, FireCondition, VerbFunction
from .ports import PortBinding
Expand All @@ -9,6 +10,7 @@

__all__ = [
"AnyFn",
"ConfigProvider",
"Decorator",
"EmitCondition",
"FireCondition",
Expand Down
46 changes: 46 additions & 0 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
verb,
)
from reactivedataflow.errors import (
ConfigReferenceNotFoundError,
GraphHasCyclesError,
InputNotFoundError,
NodeAlreadyDefinedError,
Expand All @@ -29,6 +30,7 @@
RequiredNodeInputNotFoundError,
)
from reactivedataflow.model import Edge, Graph, InputNode, Node, Output, ValRef
from reactivedataflow.types import ConfigProvider

from .define_math_ops import define_math_ops

Expand Down Expand Up @@ -194,6 +196,21 @@ def test_throws_on_add_edge_with_unknown_nodes():
builder.add_edge(from_node="n2", to_node="n1")


def test_throws_on_unknown_reference():
registry = Registry()
define_math_ops(registry)

with pytest.raises(ConfigReferenceNotFoundError):
(
GraphBuilder()
.add_node(
"c1", "constant", config={"value": ValRef(reference="value_provider")}
)
.add_output("c1")
.build(registry=registry)
)


async def test_simple_graph():
registry = Registry()
define_math_ops(registry)
Expand All @@ -209,6 +226,35 @@ async def test_simple_graph():
await graph.dispose()


async def test_config_provider():
registry = Registry()
define_math_ops(registry)

value = 1

class ValueProvider(ConfigProvider[int]):
def get(self) -> int:
return value

provider = ValueProvider()
graph = (
GraphBuilder()
.add_node(
"c1", "constant", config={"value": ValRef(reference="value_provider")}
)
.add_output("c1")
.build(registry=registry, config_providers={"value_provider": provider})
)
await graph.drain()
assert graph.output_value("c1") == 1
value = 2
assert provider.get() == 2
await graph.drain()
# Value is not pushed
assert graph.output_value("c1") == 1
await graph.dispose()


async def test_math_op_graph():
registry = Registry()
define_math_ops(registry)
Expand Down

0 comments on commit 8ead273

Please sign in to comment.