Improved Async Support (#254)
* allow reactive dataflows to be used in proper async contexts

* cut version
darthtrevino committed Jul 1, 2024
1 parent f3e1813 commit 3c3e5ce
Showing 18 changed files with 227 additions and 128 deletions.
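
The headline change is that verbs may now be coroutine functions: nodes schedule and await them instead of forcing `asyncio.run`. A minimal sketch of an async verb follows; the import path and the `VerbInput`/`VerbOutput` field names are assumptions inferred from the modules touched below, not confirmed public API.

```python
import asyncio

# Assumption: VerbInput/VerbOutput are importable from reactivedataflow.nodes,
# as suggested by the nodes/__init__.py changes in this commit.
from reactivedataflow.nodes import VerbInput, VerbOutput


async def double(inputs: VerbInput) -> VerbOutput:
    """An async verb: await real work before emitting an output."""
    value = inputs.named_inputs.get("value", 0)  # attribute name assumed
    await asyncio.sleep(0.01)  # stand-in for awaited I/O (HTTP, DB, etc.)
    return VerbOutput(outputs={"result": value * 2})  # constructor kwarg assumed
```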
20 changes: 19 additions & 1 deletion python/reactivedataflow/poetry.lock

Some generated files are not rendered by default.

6 changes: 4 additions & 2 deletions python/reactivedataflow/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reactivedataflow"
version = "0.1.4"
version = "0.1.5"
description = "Reactive Dataflow Graphs"
license = "MIT"
authors = ["Chris Trevino <chtrevin@microsoft.com>"]
@@ -20,6 +20,7 @@ pytest = "^8.2.2"
pyright = "^1.1.366"
coverage = "^7.5.3"
numpy = "<2"
pytest-asyncio = "^0.23.7"

[build-system]
requires = ["poetry-core"]
@@ -136,4 +137,5 @@ convention = "google"
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_mode = "auto"
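
With `pytest-asyncio` added as a dev dependency and `asyncio_mode = "auto"` set, coroutine test functions are collected and awaited without per-test markers. An illustrative test (not from this repository):

```python
import asyncio


# With asyncio_mode = "auto", no @pytest.mark.asyncio marker is needed;
# pytest-asyncio drives the event loop for any `async def` test.
async def test_coroutines_run_under_auto_mode():
    await asyncio.sleep(0)
    assert True
```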
2 changes: 0 additions & 2 deletions python/reactivedataflow/reactivedataflow/__init__.py
@@ -7,7 +7,6 @@
connect_output,
emit_conditions,
fire_conditions,
handle_async_output,
verb,
)
from .execution_graph import ExecutionGraph
@@ -65,6 +64,5 @@
"connect_output",
"emit_conditions",
"fire_conditions",
"handle_async_output",
"verb",
]
@@ -6,7 +6,6 @@
from .connect_output import connect_output
from .emit_conditions import emit_conditions
from .fire_conditions import fire_conditions
from .handle_async_output import handle_async_output
from .verb import verb

__all__ = [
@@ -17,6 +16,5 @@
"connect_output",
"emit_conditions",
"fire_conditions",
"handle_async_output",
"verb",
]
@@ -1,8 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
"""reactivedataflow Outputs Decorator."""

import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from inspect import iscoroutine
from typing import Any, ParamSpec, TypeVar

from reactivedataflow.nodes import (
@@ -13,17 +13,21 @@
T = TypeVar("T")


def handle_async_output() -> Callable[[Callable[P, Any]], Callable[P, VerbOutput]]:
def handle_async_output() -> (
Callable[[Callable[P, Any]], Callable[P, Awaitable[VerbOutput]]]
):
"""Unroll async output.
Args:
default_output (bool): The default output of the function.
"""

def wrap_fn(fn: Callable[P, Any]) -> Callable[P, VerbOutput]:
def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> VerbOutput:
def wrap_fn(fn: Callable[P, Any]) -> Callable[P, Awaitable[VerbOutput]]:
async def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> VerbOutput:
result = fn(*args, **kwargs)
return asyncio.run(result)
if iscoroutine(result):
return await result
return result

wrapped_fn.__qualname__ = f"{fn.__qualname__}_wrapasync"
return wrapped_fn
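
The wrapper is now itself a coroutine function and only awaits when the wrapped callable actually returned a coroutine, so sync and async verbs flow through one code path; the previous `asyncio.run` call would fail inside an already-running event loop. A behavioral sketch (the import path is an assumption based on the module layout in this diff, since the decorator is no longer re-exported from the package root):

```python
import asyncio

# Assumed import path; not part of the package's public exports after this commit.
from reactivedataflow.decorators.handle_async_output import handle_async_output


@handle_async_output()
def sync_step(x: int) -> int:
    return x + 1


@handle_async_output()
async def async_step(x: int) -> int:
    await asyncio.sleep(0)
    return x + 1


async def demo() -> None:
    # Both wrapped callables are awaited the same way.
    assert await sync_step(1) == await async_step(1) == 2


asyncio.run(demo())
```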
12 changes: 10 additions & 2 deletions python/reactivedataflow/reactivedataflow/execution_graph.py
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
"""The reactivedataflow Library."""

import asyncio
from typing import Any

import reactivex as rx
@@ -26,10 +27,17 @@ def __init__(self, nodes: dict[str, Node], outputs: dict[str, Output]):
self._nodes = nodes
self._outputs = outputs

def dispose(self) -> None:
async def dispose(self) -> None:
"""Dispose of all nodes."""
for node in self._nodes.values():
node.dispose()
node.detach()
await self.drain()

async def drain(self) -> None:
"""Drain the task queue."""
drains = [node.drain() for node in self._nodes.values()]
if len(drains) > 0:
await asyncio.gather(*drains)

def output(self, name: str) -> rx.Observable[Any]:
"""Read the output of a node."""
97 changes: 50 additions & 47 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
@@ -243,6 +243,55 @@ def validate_inputs():
if node.get("input") and nid not in inputs:
raise InputNotFoundError(nid)

def validate_node_requirements():
for nid in self._graph.nodes:
node = self._graph.nodes[nid]

if node.get("input"):
# skip input nodes, they've already been validated
continue

# Validate the inputs and config
registration = registry.get(node["verb"])
bindings = registration.ports
execution_node = nodes[nid]

if isinstance(execution_node, ExecutionNode):
input_names = execution_node.input_names
config_names = execution_node.config_names
num_array_inputs = execution_node.num_array_inputs
output_names = execution_node.output_names
array_input = bindings.array_input

if (
array_input
and array_input.required
and num_array_inputs < array_input.required
):
raise RequiredNodeArrayInputNotFoundError(nid)

for required_input in bindings.required_input_names:
if required_input not in input_names:
raise RequiredNodeInputNotFoundError(nid, required_input)

for required_config in bindings.required_config_names:
if required_config not in config_names:
raise RequiredNodeConfigNotFoundError(nid, required_config)

if registration.strict:
# Check that all inputs are accounted for
for input_name in input_names:
if input_name not in bindings.input_names:
raise NodeInputNotDefinedError(input_name)

for config_name in config_names:
if config_name not in bindings.config_names:
raise NodeConfigNotDefinedError(config_name)

for output_name in output_names:
if output_name not in bindings.output_names:
raise NodeOutputNotDefinedError(output_name)

nodes = build_nodes()
validate_inputs()
named_inputs, array_inputs = build_node_inputs(nodes)
@@ -252,52 +301,6 @@ def validate_inputs():
if not nx.is_directed_acyclic_graph(self._graph):
raise GraphHasCyclesError

for nid in self._graph.nodes:
node = self._graph.nodes[nid]

if node.get("input"):
# skip input nodes, they've already been validated
continue

# Validate the inputs and config
registration = registry.get(node["verb"])
bindings = registration.ports
execution_node = nodes[nid]

if isinstance(execution_node, ExecutionNode):
input_names = execution_node.input_names
config_names = execution_node.config_names
num_array_inputs = execution_node.num_array_inputs
output_names = execution_node.output_names
array_input = bindings.array_input

if (
array_input
and array_input.required
and num_array_inputs < array_input.required
):
raise RequiredNodeArrayInputNotFoundError(nid)

for required_input in bindings.required_input_names:
if required_input not in input_names:
raise RequiredNodeInputNotFoundError(nid, required_input)

for required_config in bindings.required_config_names:
if required_config not in config_names:
raise RequiredNodeConfigNotFoundError(nid, required_config)

if registration.strict:
# Check that all inputs are accounted for
for input_name in input_names:
if input_name not in bindings.input_names:
raise NodeInputNotDefinedError(input_name)

for config_name in config_names:
if config_name not in bindings.config_names:
raise NodeConfigNotDefinedError(config_name)

for output_name in output_names:
if output_name not in bindings.output_names:
raise NodeOutputNotDefinedError(output_name)
validate_node_requirements()

return ExecutionGraph(nodes, self._outputs)
11 changes: 9 additions & 2 deletions python/reactivedataflow/reactivedataflow/nodes/__init__.py
@@ -3,11 +3,18 @@

from .execution_node import ExecutionNode
from .input_node import InputNode
from .io import EmitMode, InputMode, OutputMode, VerbInput, VerbOutput
from .io import (
EmitMode,
InputMode,
OutputMode,
VerbInput,
VerbOutput,
)
from .node import Node
from .types import EmitCondition, FireCondition, VerbFunction
from .types import AsyncVerbFunction, EmitCondition, FireCondition, VerbFunction

__all__ = [
"AsyncVerbFunction",
"EmitCondition",
"EmitMode",
"ExecutionNode",
36 changes: 27 additions & 9 deletions python/reactivedataflow/reactivedataflow/nodes/execution_node.py
@@ -1,13 +1,16 @@
# Copyright (c) 2024 Microsoft Corporation.
"""The reactivedataflow ExecutionNode class."""

from typing import Any
import asyncio
from collections.abc import Awaitable
from inspect import iscoroutine
from typing import Any, cast

import reactivex as rx

from reactivedataflow.constants import default_output

from .io import VerbInput
from .io import VerbInput, VerbOutput
from .node import Node
from .types import VerbFunction

@@ -28,6 +31,7 @@ class ExecutionNode(Node):

# Output Observable
_outputs: dict[str, rx.subject.BehaviorSubject]
_tasks: list

def __init__(
self,
@@ -53,8 +57,9 @@ def __init__(
self._subscriptions = []
# Output
self._outputs = {}
self._tasks = []
# fire a recompute
self._recompute()
self._schedule_recompute()

def _output(self, name: str) -> rx.subject.BehaviorSubject:
"""Get the subject of a given output."""
@@ -76,7 +81,7 @@ def config(self) -> dict[str, Any]:
def config(self, value: dict[str, Any]) -> None:
"""Set the configuration of the node."""
self._config = value
self._recompute()
self._schedule_recompute()

def output(self, name: str = default_output) -> rx.Observable[Any]:
"""Get the observable of a given output."""
@@ -86,7 +91,12 @@ def output_value(self, name: str = default_output) -> Any:
"""Get the observable of a given output."""
return self._output(name).value

def dispose(self) -> None:
async def drain(self) -> None:
"""Drain the tasks."""
if len(self._tasks) > 0:
await asyncio.gather(*self._tasks)

def detach(self) -> None:
"""Detach the node from all inputs."""
if len(self._subscriptions) > 0:
for subscription in self._subscriptions:
@@ -122,14 +132,14 @@ def attach(

def on_named_value(value: Any, name: str) -> None:
self._named_input_values[name] = value
self._recompute()
self._schedule_recompute()

def on_array_value(value: Any, i: int) -> None:
self._array_input_values[i] = value
self._recompute()
self._schedule_recompute()

# Detach from inputs
self.dispose()
self.detach()

if named_inputs:
for name, source in named_inputs.items():
Expand All @@ -144,7 +154,13 @@ def on_array_value(value: Any, i: int) -> None:
sub = source.subscribe(lambda v, i=i: on_array_value(v, i))
self._subscriptions.append(sub)

def _recompute(self) -> None:
def _schedule_recompute(self) -> None:
task = asyncio.create_task(self._recompute())
task.add_done_callback(lambda _: self._tasks.remove(task))
self._tasks.append(task)

async def _recompute(self) -> None:
"""Recompute the node."""
inputs = VerbInput(
config=self._config,
named_inputs=self._named_input_values,
@@ -153,6 +169,8 @@ def _recompute(self) -> None:
)

result = self._fn(inputs)
if iscoroutine(result):
result = await cast(Awaitable[VerbOutput], result)
if not result.no_output:
for name, value in result.outputs.items():
self._output(name).on_next(value)
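
Recomputes are now scheduled as tracked `asyncio` tasks rather than run inline: a done-callback prunes the tracking list, and `drain` gathers whatever is still pending. The same pattern in a self-contained form (an illustrative class, not the library's code):

```python
import asyncio


class TaskTracker:
    """Fire-and-forget scheduling that can still be drained later,
    mirroring _schedule_recompute / drain above."""

    def __init__(self) -> None:
        self._tasks: list[asyncio.Task] = []

    def schedule(self, coro) -> None:
        task = asyncio.create_task(coro)  # requires a running event loop
        # Prune the task from the list as soon as it completes.
        task.add_done_callback(lambda _: self._tasks.remove(task))
        self._tasks.append(task)

    async def drain(self) -> None:
        if self._tasks:
            await asyncio.gather(*self._tasks)


async def main() -> None:
    tracker = TaskTracker()
    tracker.schedule(asyncio.sleep(0.01))
    await tracker.drain()  # returns once the scheduled task has finished


asyncio.run(main())
```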
