Use shallow-copies of node-inputs on recompute (#257)
* update execution_node tasks to use shallow copies of fn inputs

* use toposort in drain order

* version bump
darthtrevino committed Jul 1, 2024
1 parent c8e1eb6 commit 52d87e1
Showing 5 changed files with 46 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/reactivedataflow/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "reactivedataflow"
-version = "0.1.7"
+version = "0.1.8"
 description = "Reactive Dataflow Graphs"
 license = "MIT"
 authors = ["Chris Trevino <chtrevin@microsoft.com>"]
24 changes: 24 additions & 0 deletions python/reactivedataflow/reactivedataflow/conditions.py
@@ -5,6 +5,7 @@
 
 from reactivedataflow.nodes import EmitCondition, FireCondition, VerbInput, VerbOutput
 
+from .constants import default_output
 from .utils.equality import IsEqualCheck, default_is_equal
 
 
@@ -61,6 +62,29 @@ def array_input_values_are_defined() -> FireCondition:
 T = TypeVar("T")
 
 
+def array_result_non_empty(name: str = default_output) -> EmitCondition:
+    """Create an emit condition to emit when the given array output is non-empty."""
+
+    def check_array_results_non_empty(_inputs: VerbInput, outputs: VerbOutput) -> bool:
+        return (
+            name in outputs.outputs
+            and outputs.outputs[name]
+            and isinstance(outputs.outputs[name], list)
+            and len(outputs.outputs[name]) > 0
+        )
+
+    return check_array_results_non_empty
+
+
+def output_is_not_none(name: str) -> EmitCondition:
+    """Create an emit condition to emit when the given output is not None."""
+
+    def check_output_is_not_none(_inputs: VerbInput, outputs: VerbOutput) -> bool:
+        return name in outputs.outputs and outputs.outputs[name] is not None
+
+    return check_output_is_not_none
+
+
 def output_changed(
     output_name: str, is_equal: IsEqualCheck[T] = default_is_equal
 ) -> EmitCondition:
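
The two new emit conditions are factories: each returns a predicate over a verb's outputs that a node can consult before emitting. Below is a minimal sketch of how those predicates behave, assuming the import path follows the file layout above and using duck-typed SimpleNamespace stand-ins (hypothetical) in place of the library's real VerbInput/VerbOutput constructors, which this diff does not show:

```python
# Illustrative only: the predicates read just the `.outputs` mapping, so
# SimpleNamespace stand-ins are enough to exercise them here.
from types import SimpleNamespace

from reactivedataflow.conditions import array_result_non_empty, output_is_not_none

inputs = SimpleNamespace()  # both predicates ignore their inputs argument
empty = SimpleNamespace(outputs={"items": []})
filled = SimpleNamespace(outputs={"items": [1, 2, 3], "maybe": None})

emit_if_non_empty = array_result_non_empty("items")
emit_if_not_none = output_is_not_none("maybe")

print(bool(emit_if_non_empty(inputs, empty)))   # False: "items" is an empty list
print(bool(emit_if_non_empty(inputs, filled)))  # True: non-empty list under "items"
print(bool(emit_if_not_none(inputs, filled)))   # False: "maybe" is present but None
```
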
13 changes: 8 additions & 5 deletions python/reactivedataflow/reactivedataflow/execution_graph.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 """The reactivedataflow Library."""
 
-import asyncio
 from typing import Any
 
 import reactivex as rx
@@ -16,16 +15,21 @@ class ExecutionGraph:
 
     _nodes: dict[str, Node]
     _outputs: dict[str, Output]
+    _order: list[str]
 
-    def __init__(self, nodes: dict[str, Node], outputs: dict[str, Output]):
+    def __init__(
+        self, nodes: dict[str, Node], outputs: dict[str, Output], order: list[str]
+    ):
         """Initialize the execution graph.
 
         Args:
             nodes: The nodes in the graph.
             outputs: The outputs of the graph.
+            order: The topological order of the nodes, starting with input nodes.
         """
         self._nodes = nodes
         self._outputs = outputs
+        self._order = order
 
     async def dispose(self) -> None:
         """Dispose of all nodes."""
@@ -35,9 +39,8 @@ async def dispose(self) -> None:
 
     async def drain(self) -> None:
         """Drain the task queue."""
-        drains = [node.drain() for node in self._nodes.values()]
-        if len(drains) > 0:
-            await asyncio.gather(*drains)
+        for node_id in self._order:
+            await self._nodes[node_id].drain()
 
     def output(self, name: str) -> rx.Observable[Any]:
         """Read the output of a node."""
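
The drain change above replaces a concurrent asyncio.gather over all nodes with a sequential walk in topological order, so each node flushes before the nodes that depend on it. A standalone sketch of that behavior, using a hypothetical FakeNode stand-in rather than the library's Node class:

```python
# Standalone sketch (not the library's Node class): draining in topological
# order guarantees upstream nodes flush before their dependents.
import asyncio


class FakeNode:
    """Hypothetical stand-in for a dataflow node with an async drain()."""

    def __init__(self, name: str, log: list[str]) -> None:
        self._name = name
        self._log = log

    async def drain(self) -> None:
        self._log.append(self._name)  # record the drain order


async def main() -> None:
    log: list[str] = []
    nodes = {name: FakeNode(name, log) for name in ("input", "transform", "sink")}
    order = ["input", "transform", "sink"]  # topological order: sources first

    # Mirrors the new ExecutionGraph.drain(): sequential, in topological order.
    for node_id in order:
        await nodes[node_id].drain()

    assert log == order


asyncio.run(main())
```
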
4 changes: 3 additions & 1 deletion python/reactivedataflow/reactivedataflow/graph_builder.py
@@ -303,4 +303,6 @@ def validate_node_requirements():
 
         validate_node_requirements()
 
-        return ExecutionGraph(nodes, self._outputs)
+        visit_order = list(nx.topological_sort(self._graph))
+
+        return ExecutionGraph(nodes, self._outputs, visit_order)
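
The builder now derives that visit order from its internal networkx graph. For reference, a small standalone example of what nx.topological_sort produces for a three-node DAG (the node names here are made up):

```python
# A small demonstration of the visit-order computation used above:
# networkx's topological_sort yields nodes so that every edge points forward.
import networkx as nx

graph = nx.DiGraph()
graph.add_edge("load", "normalize")    # load -> normalize
graph.add_edge("normalize", "report")  # normalize -> report
graph.add_edge("load", "report")       # load -> report

visit_order = list(nx.topological_sort(graph))
print(visit_order)  # ['load', 'normalize', 'report']
```
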
21 changes: 10 additions & 11 deletions python/reactivedataflow/reactivedataflow/nodes/execution_node.py
@@ -133,11 +133,11 @@ def attach(
 
         def on_named_value(value: Any, name: str) -> None:
             self._named_input_values[name] = value
-            self._schedule_recompute("named_value")
+            self._schedule_recompute(f"input: {name}")
 
         def on_array_value(value: Any, i: int) -> None:
             self._array_input_values[i] = value
-            self._schedule_recompute("array_value")
+            self._schedule_recompute(f"array_value@{i}")
 
         # Detach from inputs
         self.detach()
@@ -157,19 +157,18 @@ def on_array_value(value: Any, i: int) -> None:
 
     def _schedule_recompute(self, cause: str | None) -> None:
         _log.debug(f"recompute scheduled for {self._id} due to {cause or 'unknown'}")
-        task = asyncio.create_task(self._recompute())
-        task.add_done_callback(lambda _: self._tasks.remove(task))
-        self._tasks.append(task)
-
-    async def _recompute(self) -> None:
-        """Recompute the node."""
         inputs = VerbInput(
-            config=self._config,
-            named_inputs=self._named_input_values,
-            array_inputs=self._array_input_values,
+            config=self._config.copy(),
+            named_inputs=self._named_input_values.copy(),
+            array_inputs=self._array_input_values.copy(),
             previous_output={name: obs.value for name, obs in self._outputs.items()},
         )
+        task = asyncio.create_task(self._recompute(inputs))
+        task.add_done_callback(lambda _: self._tasks.remove(task))
+        self._tasks.append(task)
+
+    async def _recompute(self, inputs: VerbInput) -> None:
+        """Recompute the node."""
         result = await self._fn(inputs)
         if not result.no_output:
             for name, value in result.outputs.items():
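
The core of this change is that _schedule_recompute now snapshots the config and input maps with a shallow .copy() before handing them to the recompute task, so values that arrive while the task is still queued cannot alter what it reads. A minimal asyncio sketch of the difference, independent of the library's ExecutionNode:

```python
# Minimal sketch of why the inputs are snapshotted with .copy() at schedule
# time: a task created now may run only after the source dict has been mutated.
import asyncio


async def recompute(inputs: dict[str, int], label: str) -> None:
    await asyncio.sleep(0)  # yield, as a real verb invocation would
    print(label, inputs)


async def main() -> None:
    named_inputs = {"x": 1}

    # Without a copy, the task observes whatever the dict holds when it runs.
    live = asyncio.create_task(recompute(named_inputs, "live view:"))
    # With a shallow copy, the task sees the values as of scheduling time.
    snapshot = asyncio.create_task(recompute(named_inputs.copy(), "snapshot:"))

    named_inputs["x"] = 2  # a new value arrives before the tasks run
    await asyncio.gather(live, snapshot)
    # live view: {'x': 2}
    # snapshot: {'x': 1}


asyncio.run(main())
```

Because the copies are shallow, the snapshot only guards against new assignments into the maps, not against in-place mutation of the values they contain.
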
