Skip to content

Commit

Permalink
add cycle detection (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jun 29, 2024
1 parent 6324eb6 commit ae0a157
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
8 changes: 8 additions & 0 deletions python/reactivedataflow/reactivedataflow/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
"""reactivedataflow Error Types."""


class GraphHasCyclesError(ValueError):
"""An exception for a cycle detected in the graph."""

def __init__(self):
"""Initialize the CycleDetectedError."""
super().__init__("Cycle detected in the graph.")


class RequiredNodeInputNotFoundError(ValueError):
"""An exception for required input not found."""

Expand Down
45 changes: 26 additions & 19 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .constants import default_output
from .errors import (
GraphHasCyclesError,
InputNotFoundError,
NodeAlreadyDefinedError,
NodeNotFoundError,
Expand Down Expand Up @@ -222,27 +223,33 @@ def validate_inputs():
bind_inputs(nodes, named_inputs, array_inputs)

# Validate the graph
if not nx.is_directed_acyclic_graph(self._graph):
raise GraphHasCyclesError

for nid in self._graph.nodes:
node = self._graph.nodes[nid]

if not node.get("input"):
# This is not an input node, validate the inputs and config
bindings = registry.get(node["verb"]).bindings
execution_node = nodes[nid]
if isinstance(execution_node, ExecutionNode):
array_input = bindings.array_input

if (
array_input
and array_input.required
and execution_node.num_array_inputs() < array_input.required
):
raise RequiredNodeArrayInputNotFoundError(nid)
for required_input in bindings.required_input_names:
if not execution_node.has_input(required_input):
raise RequiredNodeInputNotFoundError(nid, required_input)
for required_config in bindings.required_config_names:
if not execution_node.has_config(required_config):
raise RequiredNodeConfigNotFoundError(nid, required_config)
if node.get("input"):
# skip input nodes, they've already been validated
continue

# Validate the inputs and config
bindings = registry.get(node["verb"]).bindings
execution_node = nodes[nid]
if isinstance(execution_node, ExecutionNode):
array_input = bindings.array_input

if (
array_input
and array_input.required
and execution_node.num_array_inputs() < array_input.required
):
raise RequiredNodeArrayInputNotFoundError(nid)
for required_input in bindings.required_input_names:
if not execution_node.has_input(required_input):
raise RequiredNodeInputNotFoundError(nid, required_input)
for required_config in bindings.required_config_names:
if not execution_node.has_config(required_config):
raise RequiredNodeConfigNotFoundError(nid, required_config)

return ExecutionGraph(nodes, self._outputs)
15 changes: 15 additions & 0 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from reactivedataflow import GraphBuilder, NamedInputs, Registry, verb
from reactivedataflow.errors import (
GraphHasCyclesError,
InputNotFoundError,
NodeAlreadyDefinedError,
NodeNotFoundError,
Expand Down Expand Up @@ -112,6 +113,20 @@ def test_missing_node_config_raises_error():
assert graph.output_value("n") == 1


def test_cyclic_graph_raises_error():
registry = Registry()
define_math_ops(registry)

builder = GraphBuilder()
builder.add_node("n1", "add")
builder.add_node("n2", "add")
builder.add_edge(from_node="n1", to_node="n2")
builder.add_edge(from_node="n2", to_node="n1")

with pytest.raises(GraphHasCyclesError):
builder.build(registry=registry)


def test_double_add_node_raises_error():
registry = Registry()
define_math_ops(registry)
Expand Down

0 comments on commit ae0a157

Please sign in to comment.