From 8ead273c55ca40f83a80820a67409884445d59e6 Mon Sep 17 00:00:00 2001 From: Chris Trevino Date: Mon, 8 Jul 2024 10:24:22 -0700 Subject: [PATCH] implement 'config_providers' graph construction option --- python/reactivedataflow/pyproject.toml | 8 ++-- .../reactivedataflow/config_provider.py | 18 ++++++++ .../reactivedataflow/errors.py | 10 ++++ .../reactivedataflow/graph_builder.py | 31 +++++++++++-- .../reactivedataflow/nodes/execution_node.py | 13 +++++- .../reactivedataflow/types.py | 2 + .../tests/unit/test_graph_builder.py | 46 +++++++++++++++++++ 7 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 python/reactivedataflow/reactivedataflow/config_provider.py diff --git a/python/reactivedataflow/pyproject.toml b/python/reactivedataflow/pyproject.toml index a9c64700..25c6a67b 100644 --- a/python/reactivedataflow/pyproject.toml +++ b/python/reactivedataflow/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "reactivedataflow" -version = "0.1.9" +version = "0.1.10" description = "Reactive Dataflow Graphs" license = "MIT" authors = ["Chris Trevino "] @@ -42,9 +42,9 @@ format = ['_sort_imports', '_format_code'] test = "pytest tests" _test_with_coverage = 'coverage run --source=reactivedataflow -m pytest tests/unit' -_coverage_report = 'coverage report --fail-under=100 --show-missing --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py"' -_generate_coverage_xml = 'coverage xml --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py"' -_generate_coverage_html = 'coverage html --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py"' +_coverage_report = 'coverage report --fail-under=100 --show-missing --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py,reactivedataflow/config_provider.py"' +_generate_coverage_xml = 'coverage xml --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py,reactivedataflow/config_provider.py"' +_generate_coverage_html = 'coverage html --omit="reactivedataflow/nodes/node.py,reactivedataflow/types.py,reactivedataflow/config_provider.py"' test_coverage = [ '_test_with_coverage', '_generate_coverage_xml', diff --git a/python/reactivedataflow/reactivedataflow/config_provider.py b/python/reactivedataflow/reactivedataflow/config_provider.py new file mode 100644 index 00000000..f4f614ca --- /dev/null +++ b/python/reactivedataflow/reactivedataflow/config_provider.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +"""reactivedataflow ConfigProvider Protocol.""" + +from typing import Generic, Protocol, TypeVar + +T = TypeVar("T", covariant=True) + + +class ConfigProvider(Protocol, Generic[T]): + """A protocol for providing configuration values. + + ConfigProviders are evaluated lazily, as they have no way of triggering an ExecutionNode update. + When an ExecutionNode receives an update signal, it will invoke the ConfigProvider.get method to get the latest value of the configuration to assemble a VerbInput. + """ + + def get(self) -> T: + """Get the configuration value for the given name.""" + ... diff --git a/python/reactivedataflow/reactivedataflow/errors.py b/python/reactivedataflow/reactivedataflow/errors.py index 23605c7a..ba4e1ca0 100644 --- a/python/reactivedataflow/reactivedataflow/errors.py +++ b/python/reactivedataflow/reactivedataflow/errors.py @@ -2,6 +2,16 @@ """reactivedataflow Error Types.""" +class ConfigReferenceNotFoundError(ValueError): + """An exception for when a configuration reference is not found in the global configuration.""" + + def __init__(self, reference: str): + """Initialize the ConfigReferenceNotFoundError.""" + super().__init__( + f"Configuration reference '{reference}' not found in global configuration." + ) + + class NodeInputNotDefinedError(ValueError): """An exception for input not defined.""" diff --git a/python/reactivedataflow/reactivedataflow/graph_builder.py b/python/reactivedataflow/reactivedataflow/graph_builder.py index 468184b6..9f4e648d 100644 --- a/python/reactivedataflow/reactivedataflow/graph_builder.py +++ b/python/reactivedataflow/reactivedataflow/graph_builder.py @@ -6,8 +6,10 @@ import networkx as nx import reactivex as rx +from .config_provider import ConfigProvider from .constants import default_output from .errors import ( + ConfigReferenceNotFoundError, GraphHasCyclesError, InputNotFoundError, NodeAlreadyDefinedError, @@ -142,6 +144,7 @@ def build( self, inputs: dict[str, rx.Observable[Any]] | None = None, config: dict[str, Any] | None = None, + config_providers: dict[str, ConfigProvider[Any]] | None = None, registry: Registry | None = None, ) -> ExecutionGraph: """Build the graph. @@ -149,11 +152,13 @@ def build( Args: inputs: The inputs to the graph. config: The global configuration for the graph. + config_providers: Configuration providers, dict[str, ConfigProvider] (see the ConfigProvider protocol). registry: The registry to use for verb lookup. """ - inputs = inputs or {} registry = registry or Registry.get_instance() + inputs = inputs or {} config = config or {} + config_providers = config_providers or {} def build_nodes() -> dict[str, Node]: nodes: dict[str, Node] = {} @@ -166,10 +171,19 @@ def build_nodes() -> dict[str, Node]: registration = registry.get(node["verb"]) node_config = node.get("config", {}) or {} + node_config_providers: dict[str, ConfigProvider[Any]] = {} + for key, value in node_config.items(): if isinstance(value, ValRef): if value.reference: - node_config[key] = config[value.reference] + if value.reference in config: + node_config[key] = config[value.reference] + elif value.reference in config_providers: + node_config_providers[key] = config_providers[ + value.reference + ] + else: + raise ConfigReferenceNotFoundError(key) else: node_config[key] = value.value else: @@ -182,9 +196,20 @@ def build_nodes() -> dict[str, Node]: for key, value in config.items() if key in registration.ports.config_names } + node_global_config_providers = { + key: value + for key, value in config_providers.items() + if key in registration.ports.config_names + } node_config = {**node_global_config, **node_config} + node_config_providers = { + **node_global_config_providers, + **node_config_providers, + } - execution_node = ExecutionNode(nid, verb, node_config) + execution_node = ExecutionNode( + nid, verb, node_config, node_config_providers + ) nodes[nid] = execution_node return nodes diff --git a/python/reactivedataflow/reactivedataflow/nodes/execution_node.py b/python/reactivedataflow/reactivedataflow/nodes/execution_node.py index b1fb0fb6..c9835674 100644 --- a/python/reactivedataflow/reactivedataflow/nodes/execution_node.py +++ b/python/reactivedataflow/reactivedataflow/nodes/execution_node.py @@ -7,6 +7,7 @@ import reactivex as rx +from reactivedataflow.config_provider import ConfigProvider from reactivedataflow.constants import default_output from .io import VerbInput @@ -22,6 +23,7 @@ class ExecutionNode(Node): _id: str _fn: VerbFunction _config: dict[str, Any] + _config_providers: dict[str, ConfigProvider[Any]] # Input Observables _named_inputs: dict[str, rx.Observable] @@ -39,6 +41,7 @@ def __init__( nid: str, fn: VerbFunction, config: dict[str, Any] | None = None, + config_providers: dict[str, ConfigProvider[Any]] | None = None, ): """Initialize the ExecutionNode. @@ -46,10 +49,12 @@ def __init__( nid (str): The node identifier. fn (VerbFunction): The execution logic for the function. The input is a dictionary of input names to their latest values. config (dict[str, Any], optional): The configuration for the node. Defaults to None. + config_providers (dict[str, ConfigProvider[Any]], optional): The configuration providers for the node. Defaults to None. """ self._id = nid self._fn = fn self._config = config or {} + self._config_providers = config_providers or {} # Inputs self._named_inputs = {} self._named_input_values = {} @@ -157,8 +162,14 @@ def on_array_value(value: Any, i: int) -> None: def _schedule_recompute(self, cause: str | None) -> None: _log.debug(f"recompute scheduled for {self._id} due to {cause or 'unknown'}") + + # Copy the config; wire in the config providers + config = self._config.copy() + for name, provider in self._config_providers.items(): + config[name] = provider.get() + inputs = VerbInput( - config=self._config.copy(), + config=config, named_inputs=self._named_input_values.copy(), array_inputs=self._array_input_values.copy(), previous_output={name: obs.value for name, obs in self._outputs.items()}, diff --git a/python/reactivedataflow/reactivedataflow/types.py b/python/reactivedataflow/reactivedataflow/types.py index 6b42edac..bafbaf20 100644 --- a/python/reactivedataflow/reactivedataflow/types.py +++ b/python/reactivedataflow/reactivedataflow/types.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. """reactivedataflow Types.""" +from .config_provider import ConfigProvider from .decorators import AnyFn, Decorator from .nodes import EmitCondition, FireCondition, VerbFunction from .ports import PortBinding @@ -9,6 +10,7 @@ __all__ = [ "AnyFn", + "ConfigProvider", "Decorator", "EmitCondition", "FireCondition", diff --git a/python/reactivedataflow/tests/unit/test_graph_builder.py b/python/reactivedataflow/tests/unit/test_graph_builder.py index b80dea9e..73de1a6f 100644 --- a/python/reactivedataflow/tests/unit/test_graph_builder.py +++ b/python/reactivedataflow/tests/unit/test_graph_builder.py @@ -15,6 +15,7 @@ verb, ) from reactivedataflow.errors import ( + ConfigReferenceNotFoundError, GraphHasCyclesError, InputNotFoundError, NodeAlreadyDefinedError, @@ -29,6 +30,7 @@ RequiredNodeInputNotFoundError, ) from reactivedataflow.model import Edge, Graph, InputNode, Node, Output, ValRef +from reactivedataflow.types import ConfigProvider from .define_math_ops import define_math_ops @@ -194,6 +196,21 @@ def test_throws_on_add_edge_with_unknown_nodes(): builder.add_edge(from_node="n2", to_node="n1") +def test_throws_on_unknown_reference(): + registry = Registry() + define_math_ops(registry) + + with pytest.raises(ConfigReferenceNotFoundError): + ( + GraphBuilder() + .add_node( + "c1", "constant", config={"value": ValRef(reference="value_provider")} + ) + .add_output("c1") + .build(registry=registry) + ) + + async def test_simple_graph(): registry = Registry() define_math_ops(registry) @@ -209,6 +226,35 @@ async def test_simple_graph(): await graph.dispose() +async def test_config_provider(): + registry = Registry() + define_math_ops(registry) + + value = 1 + + class ValueProvider(ConfigProvider[int]): + def get(self) -> int: + return value + + provider = ValueProvider() + graph = ( + GraphBuilder() + .add_node( + "c1", "constant", config={"value": ValRef(reference="value_provider")} + ) + .add_output("c1") + .build(registry=registry, config_providers={"value_provider": provider}) + ) + await graph.drain() + assert graph.output_value("c1") == 1 + value = 2 + assert provider.get() == 2 + await graph.drain() + # Value is not pushed + assert graph.output_value("c1") == 1 + await graph.dispose() + + async def test_math_op_graph(): registry = Registry() define_math_ops(registry)