Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions haystack/core/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import collections.abc
from typing import Any, TypeVar, Union, get_args, get_origin

T = TypeVar("T")
Expand Down Expand Up @@ -61,15 +62,37 @@ def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-
sender_args = get_args(sender)
receiver_args = get_args(receiver)

# Handle Callable types
if sender_origin == receiver_origin == collections.abc.Callable:
return _check_callable_compatibility(sender_args, receiver_args)

# Handle bare types
if not sender_args and sender_origin:
sender_args = (Any,)
if not receiver_args and receiver_origin:
receiver_args = (Any,) * (len(sender_args) if sender_args else 1)
if len(sender_args) > len(receiver_args):
return False

return all(_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
return not (len(sender_args) > len(receiver_args)) and all(
_strict_types_are_compatible(*args) for args in zip(sender_args, receiver_args)
)


def _check_callable_compatibility(sender_args, receiver_args):
"""Helper function to check compatibility of Callable types"""
if not receiver_args:
return True
if not sender_args:
sender_args = ([Any] * len(receiver_args[0]), Any)
# Standard Callable has two elements in args: argument list and return type
if len(sender_args) != 2 or len(receiver_args) != 2:
return False
# Return types must be compatible
if not _strict_types_are_compatible(sender_args[1], receiver_args[1]):
return False
# Input Arguments must be of same length
if len(sender_args[0]) != len(receiver_args[0]):
return False
return all(_strict_types_are_compatible(sender_args[0][i], receiver_args[0][i]) for i in range(len(sender_args[0])))


def _type_name(type_):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add compatibility for Callable types.
60 changes: 59 additions & 1 deletion test/core/pipeline/test_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union

import pytest

Expand Down Expand Up @@ -384,6 +385,7 @@ def test_container_of_primitive_to_bare_container_strict(sender_type, receiver_t
pytest.param(Dict[Any, Any], Dict, id="dict-of-any-to-bare-dict"),
pytest.param(Set[Any], Set, id="set-of-any-to-bare-set"),
pytest.param(Tuple[Any], Tuple, id="tuple-of-any-to-bare-tuple"),
pytest.param(Callable[[Any], Any], Callable, id="callable-of-any-to-bare-callable"),
],
)
def test_container_of_any_to_bare_container_strict(sender_type, receiver_type):
Expand Down Expand Up @@ -438,3 +440,59 @@ def test_nested_container_compatibility(sender_type, receiver_type):
assert _types_are_compatible(sender_type, receiver_type)
# Bare container types should not be compatible with their typed counterparts
assert not _types_are_compatible(receiver_type, sender_type)


@pytest.mark.parametrize(
"sender_type,receiver_type",
[
pytest.param(Callable[[int, str], bool], Callable, id="callable-to-bare-callable"),
pytest.param(Callable[[List], int], Callable[[List], Any], id="callable-list-int-to-any"),
],
)
def test_callable_compatibility(sender_type, receiver_type):
assert _types_are_compatible(sender_type, receiver_type)
assert not _types_are_compatible(receiver_type, sender_type)


@pytest.mark.parametrize(
"sender_type,receiver_type",
[
pytest.param(
Callable[[Callable[[int], str]], List[str]], Callable, id="callable-with-nested-types-to-bare-callable"
),
pytest.param(
Callable[[Callable[[int], str]], List[str]],
Callable[[Callable[[Any], Any]], Any],
id="double-nested-callable",
),
pytest.param(
Callable[[Callable[[Callable[[int], str]], bool]], str],
Callable[[Callable[[Callable[[Any], Any]], Any]], Any],
id="triple-nested-callable",
),
],
)
def test_nested_callable_compatibility(sender_type, receiver_type):
assert _types_are_compatible(sender_type, receiver_type)
# Bare callable container types should not be compatible with their typed counterparts
assert not _types_are_compatible(receiver_type, sender_type)


@pytest.mark.parametrize(
"sender_type,receiver_type",
[
pytest.param(
Callable[[int, str], bool], Callable[[int, str], List], id="callable-to-callable-with-different-return-type"
),
pytest.param(
Callable[[int, int], bool], Callable[[int, str], bool], id="callable-to-callable-with-different-args-type"
),
pytest.param(
Callable[[int, str], bool], Callable[[int, str, int], bool], id="callable-to-callable-with-different-args"
),
pytest.param(Callable[[int, str], bool], Callable[[int], bool], id="callable-to-callable-with-fewer-args"),
],
)
def test_always_incompatible_callable_types(sender_type, receiver_type):
assert not _types_are_compatible(sender_type, receiver_type)
assert not _types_are_compatible(receiver_type, sender_type)
Loading