Skip to content

Commit

Permalink
add default_output as default argument to output conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jul 8, 2024
1 parent cc4ce49 commit c5a8391
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
2 changes: 1 addition & 1 deletion python/reactivedataflow/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reactivedataflow"
version = "0.1.10"
version = "0.1.11"
description = "Reactive Dataflow Graphs"
license = "MIT"
authors = ["Chris Trevino <chtrevin@microsoft.com>"]
Expand Down
33 changes: 23 additions & 10 deletions python/reactivedataflow/reactivedataflow/conditions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# Copyright (c) 2024 Microsoft Corporation.
"""reactivedataflow Firing Conditions."""

import logging
from typing import Any, TypeVar, cast

from reactivedataflow.nodes import EmitCondition, FireCondition, VerbInput, VerbOutput

from .constants import default_output
from .utils.equality import IsEqualCheck, default_is_equal

log = logging.getLogger(__name__)


def _check_array_input_not_empty(inputs: VerbInput):
return len(inputs.array_inputs) > 0 if inputs.array_inputs else False
Expand All @@ -32,19 +35,23 @@ def check_required_inputs(inputs: VerbInput):
def is_input_present(input_name: str):
return _is_value_in_dict(input_name, inputs.named_inputs)

return all(is_input_present(input_name) for input_name in required_inputs)
result = all(is_input_present(input_name) for input_name in required_inputs)
log.debug("...checking required inputs %s: %s", required_inputs, result)
return result

return check_required_inputs


def require_config(*required_inputs: str) -> FireCondition:
def require_config(*required_config: str) -> FireCondition:
"""Create a fire condition to require the given configuration values to be present for the function to fire."""

def check_required_config(inputs: VerbInput):
def is_config_present(input_name: str):
return _is_value_in_dict(input_name, inputs.config)
def is_config_present(config_name: str):
return _is_value_in_dict(config_name, inputs.config)

return all(is_config_present(input_name) for input_name in required_inputs)
result = all(is_config_present(config_name) for config_name in required_config)
log.debug("...checking required config %s: %s", required_config, result)
return result

return check_required_config

Expand All @@ -66,27 +73,31 @@ def array_result_not_empty(name: str = default_output) -> EmitCondition:
"""Create an emit condition to emit when the given array output is non-empty."""

def check_array_results_non_empty(_inputs: VerbInput, outputs: VerbOutput) -> bool:
return bool(
result = bool(
name in outputs.outputs
and outputs.outputs[name]
and isinstance(outputs.outputs[name], list)
and len(outputs.outputs[name]) > 0
)
log.debug("...checking array results not empty: %s", result)
return result

return check_array_results_non_empty


def output_is_not_none(name: str) -> EmitCondition:
def output_is_not_none(name: str = default_output) -> EmitCondition:
"""Create an emit condition to emit when the given output is not None."""

def check_output_is_not_none(_inputs: VerbInput, outputs: VerbOutput) -> bool:
return name in outputs.outputs and outputs.outputs[name] is not None
result = name in outputs.outputs and outputs.outputs[name] is not None
log.debug("...checking output is not None: %s", result)
return result

return check_output_is_not_none


def output_changed(
output_name: str, is_equal: IsEqualCheck[T] = default_is_equal
output_name: str = default_output, is_equal: IsEqualCheck[T] = default_is_equal
) -> EmitCondition:
"""Create an emit condition to emit when the given output has changed."""

Expand All @@ -102,6 +113,8 @@ def check_output_changed(inputs: VerbInput, outputs: VerbOutput):
previous = cast(T, previous)
current = cast(T, current)

return not is_equal(previous, current)
result = not is_equal(previous, current)
log.debug("...checking output changed: %s", result)
return result

return check_output_changed
15 changes: 10 additions & 5 deletions python/reactivedataflow/reactivedataflow/nodes/execution_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,21 @@ def on_array_value(value: Any, i: int) -> None:
def _schedule_recompute(self, cause: str | None) -> None:
_log.debug(f"recompute scheduled for {self._id} due to {cause or 'unknown'}")

# Copy the config; wire in the config providers
# Generate a shallow copy of the inputs and configuration
previous_output = {name: obs.value for name, obs in self._outputs.items()}
named_inputs = self._named_input_values.copy()
array_inputs = self._array_input_values.copy()
config = self._config.copy()
for name, provider in self._config_providers.items():
config[name] = provider.get()
value = provider.get()
_log.debug("inject config from provider %s, value=%s", name, value)
config[name] = value

inputs = VerbInput(
config=config,
named_inputs=self._named_input_values.copy(),
array_inputs=self._array_input_values.copy(),
previous_output={name: obs.value for name, obs in self._outputs.items()},
named_inputs=named_inputs,
array_inputs=array_inputs,
previous_output=previous_output,
)
task = asyncio.create_task(self._recompute(inputs))
task.add_done_callback(lambda _: self._tasks.remove(task))
Expand Down

0 comments on commit c5a8391

Please sign in to comment.