Skip to content

Commit

Permalink
named, array input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jun 28, 2024
1 parent 0bb218e commit 052a14a
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 22 deletions.
10 changes: 10 additions & 0 deletions python/reactivedataflow/reactivedataflow/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,13 @@ def array_input(self) -> ArrayInput | None:
def named_inputs(self) -> NamedInputs | None:
"""Return the named inputs binding."""
return next((p for p in self._bindings if isinstance(p, NamedInputs)), None)

@property
def required_input_names(self) -> set[str]:
"""Return the required named inputs."""
return {p.name for p in self.input if p.required}

@property
def required_config_names(self) -> set[str]:
"""Return the required named inputs."""
return {p.name for p in self.config if p.required}
31 changes: 31 additions & 0 deletions python/reactivedataflow/reactivedataflow/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
"""reactivedataflow Error Types."""


class RequiredNodeInputNotFoundError(ValueError):
"""An exception for required input not found."""

def __init__(self, nid: str, input_name: str):
"""Initialize the RequiredNodeInputNotFoundError."""
super().__init__(f"Node {nid} is missing required input '{input_name}'.")

class RequiredNodeArrayInputNotFoundError(ValueError):
"""An exception for required array input not found."""

def __init__(self, nid: str):
"""Initialize the RequiredNodeInputNotFoundError."""
super().__init__(f"Node {nid} is missing required array input.")


class RequiredNodeConfigNotFoundError(ValueError):
"""An exception for required input not found."""

def __init__(self, nid: str, config_key: str):
"""Initialize the RequiredNodeInputNotFoundError."""
super().__init__(f"Node {nid} is missing required config '{config_key}'.")


class NodeAlreadyDefinedError(ValueError):
"""An exception for adding a node that already exists."""

Expand All @@ -26,6 +49,14 @@ def __init__(self, name: str):
super().__init__(f"Output '{name}' is already defined.")


class InputNotFoundError(ValueError):
"""An exception for input not defined."""

def __init__(self, input_name: str):
"""Initialize the InputNotFoundError."""
super().__init__(f"Input '{input_name}' not found.")


class OutputNotFoundError(ValueError):
"""An exception for output not defined."""

Expand Down
51 changes: 43 additions & 8 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

from .constants import default_output
from .errors import (
InputNotFoundError,
NodeAlreadyDefinedError,
NodeNotFoundError,
OutputAlreadyDefinedError,
RequiredNodeConfigNotFoundError,
RequiredNodeInputNotFoundError,
RequiredNodeArrayInputNotFoundError,
)
from .execution_graph import ExecutionGraph
from .model import Graph, Output
Expand Down Expand Up @@ -191,17 +195,48 @@ def build_node_inputs(
array_inputs[to_node].append(input_source)
return named_inputs, array_inputs

def bind_inputs(
nodes: dict[str, Node],
named_inputs: dict[str, dict[str, rx.Observable[Any]]],
array_inputs: dict[str, list[rx.Observable[Any]]],
):
for nid in self._graph.nodes:
node = nodes[nid]
if isinstance(node, InputNode):
node.attach(inputs[nid])
if isinstance(node, ExecutionNode):
named_in = named_inputs.get(nid)
array_in = array_inputs.get(nid)
node.attach(named_inputs=named_in, array_inputs=array_in)

def validate_inputs():
for nid in self._graph.nodes:
node = self._graph.nodes[nid]
if node.get("input") and nid not in inputs:
raise InputNotFoundError(nid)

nodes = build_nodes()
validate_inputs()
named_inputs, array_inputs = build_node_inputs(nodes)

# Bind the Inputs
bind_inputs(nodes, named_inputs, array_inputs)

# Validate the graph
for nid in self._graph.nodes:
node = nodes[nid]
if isinstance(node, InputNode):
node.attach(inputs[nid])
if isinstance(node, ExecutionNode):
named_in = named_inputs.get(nid)
array_in = array_inputs.get(nid)
node.attach(named_inputs=named_in, array_inputs=array_in)
node = self._graph.nodes[nid]

if not node.get("input"):
# This is not an input node, validate the inputs and config
bindings = registry.get(node["verb"]).bindings
execution_node = nodes[nid]
if isinstance(execution_node, ExecutionNode):
if bindings.array_input and bindings.array_input.required and not execution_node.has_array_input():
raise RequiredNodeArrayInputNotFoundError(nid)
for required_input in bindings.required_input_names:
if not execution_node.has_input(required_input):
raise RequiredNodeInputNotFoundError(nid, required_input)
for required_config in bindings.required_config_names:
if not execution_node.has_config(required_config):
raise RequiredNodeConfigNotFoundError(nid, required_config)

return ExecutionGraph(nodes, self._outputs)
12 changes: 12 additions & 0 deletions python/reactivedataflow/reactivedataflow/nodes/execution_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ def dispose(self) -> None:
subscription.dispose()
self._subscriptions = []

def has_input(self, name: str) -> bool:
"""Check if the node has a given input."""
return name in self._named_inputs

def has_config(self, name: str) -> bool:
"""Check if the node has a given config."""
return name in self._config

def has_array_input(self) -> bool:
"""Check if the node has a given array input."""
return len(self._array_inputs) > 0

def attach(
self,
named_inputs: dict[str, rx.Observable[Any]] | None = None,
Expand Down
98 changes: 84 additions & 14 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024 Microsoft Corporation.
"""reactivedataflow Graph Assembler Tests."""
"""reactivedataflow Graph Builder Tests."""

import pytest
import reactivex as rx
Expand All @@ -9,26 +9,96 @@
Registry,
)
from reactivedataflow.errors import (
InputNotFoundError,
NodeAlreadyDefinedError,
NodeNotFoundError,
OutputAlreadyDefinedError,
OutputNotFoundError,
RequiredNodeConfigNotFoundError,
RequiredNodeInputNotFoundError,
RequiredNodeArrayInputNotFoundError,
)
from reactivedataflow.model import Edge, Graph, InputNode, Node, Output

from .define_math_ops import define_math_ops


def test_missing_input_raises_error():
builder = GraphBuilder()
builder.add_input("i1")

with pytest.raises(InputNotFoundError):
builder.build()


def test_missing_node_input_raises_error():
registry = Registry()
define_math_ops(registry)

builder = GraphBuilder()
builder.add_node("n", "multiply")

with pytest.raises(RequiredNodeInputNotFoundError):
builder.build(registry=registry)

builder.add_node("const1", "constant", config={"value": 1})
builder.add_edge(from_node="const1", to_node="n", to_port="a")

with pytest.raises(RequiredNodeInputNotFoundError):
builder.build(registry=registry)

builder.add_node("const2", "constant", config={"value": 2})
builder.add_edge(from_node="const2", to_node="n", to_port="b")
builder.add_output("n")

graph = builder.build(registry=registry)
assert graph.output_value("n") == 2

def test_missing_array_input_raises_error():
registry = Registry()
define_math_ops(registry)

builder = GraphBuilder()
builder.add_node("const1", "constant", config={"value": 1})
builder.add_node("n", "add")
builder.add_output("n")

with pytest.raises(RequiredNodeArrayInputNotFoundError):
builder.build(registry=registry)

builder.add_edge(from_node="const1", to_node="n")
graph = builder.build(registry=registry)
assert graph.output_value("n") == 1


def test_missing_node_config_raises_error():
registry = Registry()
define_math_ops(registry)

builder = GraphBuilder()
builder.add_node("n", "constant")

with pytest.raises(RequiredNodeConfigNotFoundError):
builder.build(registry=registry)

builder = GraphBuilder()
builder.add_node("n", "constant", config={"value": 1})
builder.add_output("n")

graph = builder.build(registry=registry)
assert graph.output_value("n") == 1


def test_double_add_node_raises_error():
registry = Registry()
define_math_ops(registry)

assembler = GraphBuilder().add_node("c1", "constant", config={"value": 1})
builder = GraphBuilder().add_node("c1", "constant", config={"value": 1})
with pytest.raises(NodeAlreadyDefinedError):
assembler.add_node("c1", "constant", config={"value": 2})
assembler.add_node("c1", "constant", config={"value": 2}, override=True)
assembler.add_output("c1")
graph = assembler.build(registry=registry)
builder.add_node("c1", "constant", config={"value": 2})
builder.add_node("c1", "constant", config={"value": 2}, override=True)
builder.add_output("c1")
graph = builder.build(registry=registry)

assert graph.output_value("c1") == 2

Expand Down Expand Up @@ -107,11 +177,11 @@ def set_value(v):

def test_input_node():
registry = Registry()
assembler = GraphBuilder()
assembler.add_input("i").add_output("i").add_output("fail_1", "i", "x")
builder = GraphBuilder()
builder.add_input("i").add_output("i").add_output("fail_1", "i", "x")

subject = rx.subject.BehaviorSubject(1)
graph = assembler.build(registry=registry, inputs={"i": subject})
graph = builder.build(registry=registry, inputs={"i": subject})

with pytest.raises(OutputNotFoundError):
graph.output_value("fail_1")
Expand All @@ -136,7 +206,7 @@ def set_value(v):
graph.dispose()


def test_graph_assembler():
def test_graph_builder():
registry = Registry()
define_math_ops(registry)

Expand Down Expand Up @@ -169,12 +239,12 @@ def test_graph_assembler():
graph.dispose()


def test_graph_assembler_from_schema():
def test_graph_builder_from_schema():
registry = Registry()
define_math_ops(registry)

assembler = GraphBuilder()
assembler.load(
builder = GraphBuilder()
builder.load(
Graph(
inputs=[
InputNode(id="input"),
Expand Down Expand Up @@ -203,7 +273,7 @@ def test_graph_assembler_from_schema():

# Build the graph
input_stream = rx.subject.BehaviorSubject(1)
graph = assembler.build(registry=registry, inputs={"input": input_stream})
graph = builder.build(registry=registry, inputs={"input": input_stream})

with pytest.raises(OutputNotFoundError):
graph.output_value("fail_1")
Expand Down

0 comments on commit 052a14a

Please sign in to comment.