Skip to content

Commit

Permalink
update execution_node tasks to use shallow copies of fn inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jul 1, 2024
1 parent c8e1eb6 commit 5b0ab94
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
24 changes: 24 additions & 0 deletions python/reactivedataflow/reactivedataflow/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from reactivedataflow.nodes import EmitCondition, FireCondition, VerbInput, VerbOutput

from .constants import default_output
from .utils.equality import IsEqualCheck, default_is_equal


Expand Down Expand Up @@ -61,6 +62,29 @@ def array_input_values_are_defined() -> FireCondition:
T = TypeVar("T")


def array_result_non_empty(name: str = default_output) -> EmitCondition:
"""Create an emit condition to emit when the given array output is non-empty."""

def check_array_results_non_empty(_inputs: VerbInput, outputs: VerbOutput) -> bool:
return (
name in outputs.outputs
and outputs.outputs[name]
and isinstance(outputs.outputs[name], list)
and len(outputs.outputs[name]) > 0
)

return check_array_results_non_empty


def output_is_not_none(name: str) -> EmitCondition:
"""Create an emit condition to emit when the given output is not None."""

def check_output_is_not_none(_inputs: VerbInput, outputs: VerbOutput) -> bool:
return name in outputs.outputs and outputs.outputs[name] is not None

return check_output_is_not_none


def output_changed(
output_name: str, is_equal: IsEqualCheck[T] = default_is_equal
) -> EmitCondition:
Expand Down
21 changes: 10 additions & 11 deletions python/reactivedataflow/reactivedataflow/nodes/execution_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def attach(

def on_named_value(value: Any, name: str) -> None:
self._named_input_values[name] = value
self._schedule_recompute("named_value")
self._schedule_recompute(f"input: {name}")

def on_array_value(value: Any, i: int) -> None:
self._array_input_values[i] = value
self._schedule_recompute("array_value")
self._schedule_recompute(f"array_value@{i}")

# Detach from inputs
self.detach()
Expand All @@ -157,19 +157,18 @@ def on_array_value(value: Any, i: int) -> None:

def _schedule_recompute(self, cause: str | None) -> None:
_log.debug(f"recompute scheduled for {self._id} due to {cause or 'unknown'}")
task = asyncio.create_task(self._recompute())
task.add_done_callback(lambda _: self._tasks.remove(task))
self._tasks.append(task)

async def _recompute(self) -> None:
"""Recompute the node."""
inputs = VerbInput(
config=self._config,
named_inputs=self._named_input_values,
array_inputs=self._array_input_values,
config=self._config.copy(),
named_inputs=self._named_input_values.copy(),
array_inputs=self._array_input_values.copy(),
previous_output={name: obs.value for name, obs in self._outputs.items()},
)
task = asyncio.create_task(self._recompute(inputs))
task.add_done_callback(lambda _: self._tasks.remove(task))
self._tasks.append(task)

async def _recompute(self, inputs: VerbInput) -> None:
"""Recompute the node."""
result = await self._fn(inputs)
if not result.no_output:
for name, value in result.outputs.items():
Expand Down

0 comments on commit 5b0ab94

Please sign in to comment.