Skip to content

Commit

Permalink
feat: use full connection data to route I/O (#148)
Browse files Browse the repository at this point in the history
* fix sample components

* make sum variadic

* separate queue and buffer

* all works but loops & variadics together

* fix some tests

* fix some tests

* all tests green

* clean up code a bit

* refactor code

* fix tests

* fix self loops

* fix reused sockets bug

* add distinct loops

* add distinct loops test

* break out some code from run()

* docstring

* improve variadics drawing

* black

* document the deepcopy

* re-arrange connection dataclass and add tests

* consumer -> receiver

* fix typing

* move Connection-related code under component package

* clean up connect()

* cosmetics and typing

* fix linter, make Connection a dataclass again

* fix typing

* add test case for #105

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
  • Loading branch information
ZanSara and masci committed Nov 14, 2023
1 parent 8b61f66 commit 835acff
Show file tree
Hide file tree
Showing 40 changed files with 971 additions and 658 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Expand Up @@ -133,4 +133,4 @@ dmypy.json
# Canals
drafts/
.canals_debug/
test/**/*.png
test/**/*.png
14 changes: 11 additions & 3 deletions canals/component/component.py
Expand Up @@ -72,6 +72,7 @@
import inspect
from typing import Protocol, runtime_checkable, Any
from types import new_class
from copy import deepcopy

from canals.component.sockets import InputSocket, OutputSocket
from canals.errors import ComponentError
Expand Down Expand Up @@ -121,10 +122,16 @@ def __call__(cls, *args, **kwargs):
# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets

# If the __init__ called component.set_output_types(), __canals_output__ is already populated
# If `component.set_output_types()` was called in the component constructor,
# `__canals_output__` is already populated, no need to do anything.
if not hasattr(instance, "__canals_output__"):
# if the run method was decorated, it has a _output_types_cache field assigned
instance.__canals_output__ = getattr(instance.run, "_output_types_cache", {})
# If that's not the case, we need to populate `__canals_output__`
#
# If the `run` method was decorated, it has a `_output_types_cache` field assigned
# that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data.
instance.__canals_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {}))

# If the __init__ called component.set_input_types(), __canals_input__ is already populated
if not hasattr(instance, "__canals_input__"):
Expand All @@ -134,6 +141,7 @@ def __call__(cls, *args, **kwargs):
param: InputSocket(
name=param,
type=run_signature.parameters[param].annotation,
is_mandatory=run_signature.parameters[param].default == inspect.Parameter.empty,
)
for param in list(run_signature.parameters)[1:] # First is 'self' and it doesn't matter.
}
Expand Down
167 changes: 167 additions & 0 deletions canals/component/connection.py
@@ -0,0 +1,167 @@
import itertools
from typing import Optional, List, Tuple
from dataclasses import dataclass

from canals.component.sockets import InputSocket, OutputSocket
from canals.type_utils import _type_name, _types_are_compatible
from canals.errors import PipelineConnectError


@dataclass
class Connection:
    """
    A link from one output socket of a sender component to one input socket of a receiver component.

    Any of the four fields may be `None`. A connection without a sender is rendered as "input needed" and a
    connection without a receiver is rendered as "output" (see `__repr__`).
    """

    # Name of the component that sends the value, if any
    sender: Optional[str]
    # Output socket of the sender the value comes from, if any
    sender_socket: Optional[OutputSocket]
    # Name of the component that receives the value, if any
    receiver: Optional[str]
    # Input socket of the receiver the value goes to, if any
    receiver_socket: Optional[InputSocket]

    def __post_init__(self):
        """
        Validates the connection and records it on both sockets.

        Nothing is checked or recorded unless sender, sender socket, receiver and receiver socket are all set.

        Raises `PipelineConnectError` if the receiving socket already has a sender and is not variadic.
        """
        if self.sender and self.sender_socket and self.receiver and self.receiver_socket:
            # Make sure the receiving socket isn't already connected, unless it's variadic. Sending sockets can be
            # connected as many times as needed, so they don't need this check
            if self.receiver_socket.senders and not self.receiver_socket.is_variadic:
                raise PipelineConnectError(
                    f"Cannot connect '{self.sender}.{self.sender_socket.name}' with '{self.receiver}.{self.receiver_socket.name}': "
                    f"{self.receiver}.{self.receiver_socket.name} is already connected to {self.receiver_socket.senders}.\n"
                )

            # Each socket keeps track of the names of the components on the other end
            self.sender_socket.receivers.append(self.receiver)
            self.receiver_socket.senders.append(self.sender)

    def __repr__(self):
        """
        Returns a one-line, human-readable description such as `a.out (int) --> (int) b.inp`.
        """
        if self.sender and self.sender_socket:
            sender_repr = f"{self.sender}.{self.sender_socket.name} ({_type_name(self.sender_socket.type)})"
        else:
            sender_repr = "input needed"

        if self.receiver and self.receiver_socket:
            receiver_repr = f"({_type_name(self.receiver_socket.type)}) {self.receiver}.{self.receiver_socket.name}"
        else:
            receiver_repr = "output"

        return f"{sender_repr} --> {receiver_repr}"

    def __hash__(self):
        """
        Connection is used as a dictionary key in Pipeline, it must be hashable.

        The hash is defined explicitly because a non-frozen dataclass that generates `__eq__` would otherwise
        have its `__hash__` set to `None`. It only depends on component and socket names.
        """
        return hash(
            "-".join(
                [
                    self.sender if self.sender else "input",
                    self.sender_socket.name if self.sender_socket else "",
                    self.receiver if self.receiver else "output",
                    self.receiver_socket.name if self.receiver_socket else "",
                ]
            )
        )

    @property
    def is_mandatory(self) -> bool:
        """
        Returns True if the connection goes to a mandatory input socket, False otherwise
        (including when there is no receiving socket at all).
        """
        if self.receiver_socket:
            return self.receiver_socket.is_mandatory
        return False

    @staticmethod
    def from_list_of_sockets(
        sender_node: str, sender_sockets: List[OutputSocket], receiver_node: str, receiver_sockets: List[InputSocket]
    ) -> "Connection":
        """
        Find one single possible connection between two lists of sockets.

        Candidate pairs are all sender/receiver socket combinations with compatible types:

        - if there is exactly one candidate, it is used;
        - if there are several candidates, the one whose sender and receiver sockets share the same name is used,
          provided that exactly one such pair exists.

        Raises `PipelineConnectError` if no candidate exists, or if several candidates exist and matching by
        name does not single out exactly one of them.
        """
        # List all sender/receiver combinations of sockets that match by type
        possible_connections = [
            (sender_sock, receiver_sock)
            for sender_sock, receiver_sock in itertools.product(sender_sockets, receiver_sockets)
            if _types_are_compatible(sender_sock.type, receiver_sock.type)
        ]

        # No connections seem to be possible
        if not possible_connections:
            connections_status_str = _connections_status(
                sender_node=sender_node,
                sender_sockets=sender_sockets,
                receiver_node=receiver_node,
                receiver_sockets=receiver_sockets,
            )

            # Both sockets were specified: explain why the types don't match
            if len(sender_sockets) == len(receiver_sockets) and len(sender_sockets) == 1:
                raise PipelineConnectError(
                    f"Cannot connect '{sender_node}.{sender_sockets[0].name}' with '{receiver_node}.{receiver_sockets[0].name}': "
                    f"their declared input and output types do not match.\n{connections_status_str}"
                )

            # Not both sockets were specified: explain there's no possible match on any pair
            raise PipelineConnectError(
                f"Cannot connect '{sender_node}' with '{receiver_node}': "
                f"no matching connections available.\n{connections_status_str}"
            )

        match = possible_connections[0]

        # There's more than one possible connection
        if len(possible_connections) > 1:
            # Try to match by name
            name_matches = [
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
            ]
            if len(name_matches) != 1:
                # TODO allow for multiple connections at once if there is no ambiguity?
                # TODO give priority to sockets that have no default values?
                connections_status_str = _connections_status(
                    sender_node=sender_node,
                    sender_sockets=sender_sockets,
                    receiver_node=receiver_node,
                    receiver_sockets=receiver_sockets,
                )
                raise PipelineConnectError(
                    f"Cannot connect '{sender_node}' with '{receiver_node}': more than one connection is possible "
                    "between these components. Please specify the connection name, like: "
                    f"pipeline.connect('{sender_node}.{possible_connections[0][0].name}', "
                    f"'{receiver_node}.{possible_connections[0][1].name}').\n{connections_status_str}"
                )

            # The name disambiguated the candidates: use that pair, not just the first type-compatible one
            match = name_matches[0]

        return Connection(sender_node, match[0], receiver_node, match[1])


def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
    """
    Builds a multi-line summary of the given sockets, meant to be appended to error messages.

    Every sender socket is listed with its name and type. Every receiver socket is listed with its name, its
    type and either the names of the components already sending to it or the word "available".
    """
    sender_block = "\n".join(f" - {out_sock.name}: {_type_name(out_sock.type)}" for out_sock in sender_sockets)

    receiver_block = "\n".join(
        f" - {in_sock.name}: {_type_name(in_sock.type)} "
        f"({f'sent by {chr(44).join(in_sock.senders)}' if in_sock.senders else 'available'})"
        for in_sock in receiver_sockets
    )

    return f"'{sender_node}':\n{sender_block}\n'{receiver_node}':\n{receiver_block}"


def parse_connect_string(connection: str) -> Tuple[str, Optional[str]]:
    """
    Splits a connect_to/from string into its component name and socket name.

    The string is cut at the first dot only: "comp.sock" gives ("comp", "sock") and "a.b.c" gives ("a", "b.c").
    When there is no dot at all, the whole string is the component name and the socket name is None.
    """
    component, dot, socket_name = connection.partition(".")
    if not dot:
        # No separator found: only the component was specified
        return connection, None
    return (component, socket_name)
8 changes: 4 additions & 4 deletions canals/component/sockets.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import get_origin, get_args, List, Type, Union
from typing import get_args, List, Type
import logging
from dataclasses import dataclass, field

Expand All @@ -15,12 +15,11 @@
class InputSocket:
name: str
type: Type
is_optional: bool = field(init=False)
is_mandatory: bool = True
is_variadic: bool = field(init=False)
sender: List[str] = field(default_factory=list)
senders: List[str] = field(default_factory=list)

def __post_init__(self):
self.is_optional = get_origin(self.type) is Union and type(None) in get_args(self.type)
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION
Expand All @@ -39,3 +38,4 @@ def __post_init__(self):
class OutputSocket:
name: str
type: type
receivers: List[str] = field(default_factory=list)
115 changes: 0 additions & 115 deletions canals/pipeline/connections.py

This file was deleted.

11 changes: 5 additions & 6 deletions canals/pipeline/descriptions.py
Expand Up @@ -19,19 +19,18 @@ def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSo
input sockets, including all such sockets with default values.
"""
return {
name: [socket for socket in data.get("input_sockets", {}).values() if not socket.sender]
name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic]
for name, data in graph.nodes(data=True)
}


def find_pipeline_outputs(graph) -> Dict[str, List[OutputSocket]]:
def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]:
"""
Collect components that have disconnected output sockets. They define the pipeline output.
"""
return {
node: list(data.get("output_sockets", {}).values())
for node, data in graph.nodes(data=True)
if not graph.out_edges(node)
name: [socket for socket in data.get("output_sockets", {}).values() if not socket.receivers]
for name, data in graph.nodes(data=True)
}


Expand All @@ -40,7 +39,7 @@ def describe_pipeline_inputs(graph: networkx.MultiDiGraph):
Returns a dictionary with the input names and types that this pipeline accepts.
"""
inputs = {
comp: {socket.name: {"type": socket.type, "is_optional": socket.is_optional} for socket in data}
comp: {socket.name: {"type": socket.type, "is_mandatory": socket.is_mandatory} for socket in data}
for comp, data in find_pipeline_inputs(graph).items()
if data
}
Expand Down
2 changes: 1 addition & 1 deletion canals/pipeline/draw/draw.py
Expand Up @@ -95,7 +95,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]
graph.add_node("input")
for node, in_sockets in find_pipeline_inputs(graph).items():
for in_socket in in_sockets:
if not in_socket.sender and not in_socket.is_optional:
if not in_socket.senders and in_socket.is_mandatory:
# If this socket has no sender it could be a socket that receives input
# directly when running the Pipeline. We can't know that for sure, in doubt
# we draw it as receiving input directly.
Expand Down

0 comments on commit 835acff

Please sign in to comment.