Skip to content

Commit

Permalink
handle multiple connections between two nodes (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Jun 29, 2024
1 parent f0ea9b7 commit 7ac0260
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
49 changes: 28 additions & 21 deletions python/reactivedataflow/reactivedataflow/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,12 @@ def add_edge(
if not self._graph.has_node(to_node):
raise NodeNotFoundError(to_node)

self._graph.add_edge(from_node, to_node, from_port=from_port, to_port=to_port)
port_connection = {"from_port": from_port, "to_port": to_port}
if self._graph.has_edge(from_node, to_node):
edge = self._graph.get_edge_data(from_node, to_node)
edge["ports"].append(port_connection)
else:
self._graph.add_edge(from_node, to_node, ports=[port_connection])

return self

Expand Down Expand Up @@ -186,26 +191,28 @@ def build_node_inputs(
for edge in self._graph.edges(data=True):
# Unpack the edge details
from_node, to_node, data = edge
from_port = data.get("from_port") or default_output
to_port = data.get("to_port")

# Find the appropriate observable the "from" side of the edge represents.
input_source = (
inputs[from_node]
if from_node in inputs
else nodes[from_node].output(from_port)
)

if to_port:
# to_port is defined, this is a named input
if to_node not in named_inputs:
named_inputs[to_node] = {}
named_inputs[to_node][to_port] = input_source
else:
# to_port is not defined, this is an array input
if to_node not in array_inputs:
array_inputs[to_node] = []
array_inputs[to_node].append(input_source)
ports = data.get("ports", [])
for port_connection in ports:
from_port = port_connection.get("from_port") or default_output
to_port = port_connection.get("to_port")

# Find the appropriate observable the "from" side of the edge represents.
input_source = (
inputs[from_node]
if from_node in inputs
else nodes[from_node].output(from_port)
)

if to_port:
# to_port is defined, this is a named input
if to_node not in named_inputs:
named_inputs[to_node] = {}
named_inputs[to_node][to_port] = input_source
else:
# to_port is not defined, this is an array input
if to_node not in array_inputs:
array_inputs[to_node] = []
array_inputs[to_node].append(input_source)
return named_inputs, array_inputs

def bind_inputs(
Expand Down
15 changes: 15 additions & 0 deletions python/reactivedataflow/tests/unit/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,21 @@ def set_value(v):
graph.dispose()


def test_multiple_edges_on_different_ports():
registry = Registry()
define_math_ops(registry)
graph = (
GraphBuilder()
.add_node("c1", "constant", config={"value": 2})
.add_node("m1", "multiply")
.add_edge("c1", "m1", to_port="a")
.add_edge("c1", "m1", to_port="b")
.add_output("result", "m1")
.build(registry=registry)
)
assert graph.output_value("result") == 4


def test_graph_builder():
registry = Registry()
define_math_ops(registry)
Expand Down

0 comments on commit 7ac0260

Please sign in to comment.