From 52d87e1a82e82c1d8f22d064bc80ca752588563a Mon Sep 17 00:00:00 2001 From: Chris Trevino Date: Mon, 1 Jul 2024 15:24:01 -0700 Subject: [PATCH] Use shallow-copies of node-inputs on recompute (#257) * update execution_node tasks to use shallow copies of fn inputs * use toposort in drain order * version bump --- python/reactivedataflow/pyproject.toml | 2 +- .../reactivedataflow/conditions.py | 24 +++++++++++++++++++ .../reactivedataflow/execution_graph.py | 13 ++++++---- .../reactivedataflow/graph_builder.py | 4 +++- .../reactivedataflow/nodes/execution_node.py | 21 ++++++++-------- 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/python/reactivedataflow/pyproject.toml b/python/reactivedataflow/pyproject.toml index 4335fdc6..059390a9 100644 --- a/python/reactivedataflow/pyproject.toml +++ b/python/reactivedataflow/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "reactivedataflow" -version = "0.1.7" +version = "0.1.8" description = "Reactive Dataflow Graphs" license = "MIT" authors = ["Chris Trevino "] diff --git a/python/reactivedataflow/reactivedataflow/conditions.py b/python/reactivedataflow/reactivedataflow/conditions.py index cdcc3d1b..ec77fd9d 100644 --- a/python/reactivedataflow/reactivedataflow/conditions.py +++ b/python/reactivedataflow/reactivedataflow/conditions.py @@ -5,6 +5,7 @@ from reactivedataflow.nodes import EmitCondition, FireCondition, VerbInput, VerbOutput +from .constants import default_output from .utils.equality import IsEqualCheck, default_is_equal @@ -61,6 +62,29 @@ def array_input_values_are_defined() -> FireCondition: T = TypeVar("T") +def array_result_non_empty(name: str = default_output) -> EmitCondition: + """Create an emit condition to emit when the given array output is non-empty.""" + + def check_array_results_non_empty(_inputs: VerbInput, outputs: VerbOutput) -> bool: + return ( + name in outputs.outputs + and outputs.outputs[name] + and isinstance(outputs.outputs[name], list) + and len(outputs.outputs[name]) > 0 + ) + + return check_array_results_non_empty + + +def output_is_not_none(name: str) -> EmitCondition: + """Create an emit condition to emit when the given output is not None.""" + + def check_output_is_not_none(_inputs: VerbInput, outputs: VerbOutput) -> bool: + return name in outputs.outputs and outputs.outputs[name] is not None + + return check_output_is_not_none + + def output_changed( output_name: str, is_equal: IsEqualCheck[T] = default_is_equal ) -> EmitCondition: diff --git a/python/reactivedataflow/reactivedataflow/execution_graph.py b/python/reactivedataflow/reactivedataflow/execution_graph.py index 52376bb4..59dfb40b 100644 --- a/python/reactivedataflow/reactivedataflow/execution_graph.py +++ b/python/reactivedataflow/reactivedataflow/execution_graph.py @@ -1,7 +1,6 @@ # Copyright (c) 2024 Microsoft Corporation. """The reactivedataflow Library.""" -import asyncio from typing import Any import reactivex as rx @@ -16,16 +15,21 @@ class ExecutionGraph: _nodes: dict[str, Node] _outputs: dict[str, Output] + _order: list[str] - def __init__(self, nodes: dict[str, Node], outputs: dict[str, Output]): + def __init__( + self, nodes: dict[str, Node], outputs: dict[str, Output], order: list[str] + ): """Initialize the execution graph. Args: nodes: The nodes in the graph. outputs: The outputs of the graph. + order: The topological order of the nodes, starting with input nodes. """ self._nodes = nodes self._outputs = outputs + self._order = order async def dispose(self) -> None: """Dispose of all nodes.""" @@ -35,9 +39,8 @@ async def dispose(self) -> None: async def drain(self) -> None: """Drain the task queue.""" - drains = [node.drain() for node in self._nodes.values()] - if len(drains) > 0: - await asyncio.gather(*drains) + for node_id in self._order: + await self._nodes[node_id].drain() def output(self, name: str) -> rx.Observable[Any]: """Read the output of a node.""" diff --git a/python/reactivedataflow/reactivedataflow/graph_builder.py b/python/reactivedataflow/reactivedataflow/graph_builder.py index 0cc74882..468184b6 100644 --- a/python/reactivedataflow/reactivedataflow/graph_builder.py +++ b/python/reactivedataflow/reactivedataflow/graph_builder.py @@ -303,4 +303,6 @@ def validate_node_requirements(): validate_node_requirements() - return ExecutionGraph(nodes, self._outputs) + visit_order = list(nx.topological_sort(self._graph)) + + return ExecutionGraph(nodes, self._outputs, visit_order) diff --git a/python/reactivedataflow/reactivedataflow/nodes/execution_node.py b/python/reactivedataflow/reactivedataflow/nodes/execution_node.py index d41da971..b1fb0fb6 100644 --- a/python/reactivedataflow/reactivedataflow/nodes/execution_node.py +++ b/python/reactivedataflow/reactivedataflow/nodes/execution_node.py @@ -133,11 +133,11 @@ def attach( def on_named_value(value: Any, name: str) -> None: self._named_input_values[name] = value - self._schedule_recompute("named_value") + self._schedule_recompute(f"input: {name}") def on_array_value(value: Any, i: int) -> None: self._array_input_values[i] = value - self._schedule_recompute("array_value") + self._schedule_recompute(f"array_value@{i}") # Detach from inputs self.detach() @@ -157,19 +157,18 @@ def on_array_value(value: Any, i: int) -> None: def _schedule_recompute(self, cause: str | None) -> None: _log.debug(f"recompute scheduled for {self._id} due to {cause or 'unknown'}") - task = asyncio.create_task(self._recompute()) - task.add_done_callback(lambda _: self._tasks.remove(task)) - self._tasks.append(task) - - async def _recompute(self) -> None: - """Recompute the node.""" inputs = VerbInput( - config=self._config, - named_inputs=self._named_input_values, - array_inputs=self._array_input_values, + config=self._config.copy(), + named_inputs=self._named_input_values.copy(), + array_inputs=self._array_input_values.copy(), previous_output={name: obs.value for name, obs in self._outputs.items()}, ) + task = asyncio.create_task(self._recompute(inputs)) + task.add_done_callback(lambda _: self._tasks.remove(task)) + self._tasks.append(task) + async def _recompute(self, inputs: VerbInput) -> None: + """Recompute the node.""" result = await self._fn(inputs) if not result.no_output: for name, value in result.outputs.items():