Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions haystack/core/super_component/super_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,20 @@ def _resolve_input_types_from_mapping(
aggregated_inputs[wrapper_input_name]["default"] = _delegate_default
continue

if not _is_compatible(existing_socket_info["type"], socket_info["type"]):
is_compatible, common_type = _is_compatible(existing_socket_info["type"], socket_info["type"])

if not is_compatible:
raise InvalidMappingTypeError(
f"Type conflict for input '{socket_name}' from component '{comp_name}'. "
f"Existing type: {existing_socket_info['type']}, new type: {socket_info['type']}."
)

# Use the common type for the aggregated input
aggregated_inputs[wrapper_input_name]["type"] = common_type

# If any socket requires mandatory inputs then the aggregated input is also considered mandatory.
# So we use the type of the mandatory input and remove the default value if it exists.
if socket_info["is_mandatory"]:
aggregated_inputs[wrapper_input_name]["type"] = socket_info["type"]
aggregated_inputs[wrapper_input_name].pop("default", None)

return aggregated_inputs
Expand Down
128 changes: 109 additions & 19 deletions haystack/core/super_component/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated, Any, TypeVar, Union, get_args, get_origin
from typing import Annotated, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast, get_args, get_origin

from haystack.core.component.types import HAYSTACK_GREEDY_VARIADIC_ANNOTATION, HAYSTACK_VARIADIC_ANNOTATION

Expand All @@ -14,33 +14,39 @@ class _delegate_default:
T = TypeVar("T")


def _is_compatible(type1: T, type2: T, unwrap_nested: bool = True) -> bool:
def _is_compatible(type1: T, type2: T, unwrap_nested: bool = True) -> Tuple[bool, Optional[T]]:
"""
Check if two types are compatible (bidirectional/symmetric check).

:param type1: First type to compare
:param type2: Second type to compare
:param unwrap_nested: If True, recursively unwraps nested Optional and Variadic types.
If False, only unwraps at the top level.
:return: True if types are compatible, False otherwise
:return: Tuple of (True if types are compatible, common type if compatible)
"""
type1_unwrapped = _unwrap_all(type1, recursive=unwrap_nested)
type2_unwrapped = _unwrap_all(type2, recursive=unwrap_nested)

return _types_are_compatible(type1_unwrapped, type2_unwrapped)


def _types_are_compatible(type1: T, type2: T) -> bool:
def _types_are_compatible(type1: T, type2: T) -> Tuple[bool, Optional[T]]:
"""
Core type compatibility check implementing symmetric matching.

:param type1: First unwrapped type to compare
:param type2: Second unwrapped type to compare
:return: True if types are compatible, False otherwise
"""
# Handle Any type and direct equality
if type1 is Any or type2 is Any or type1 == type2:
return True
# Handle Any type
if type1 is Any:
return True, _convert_to_typing_type(type2)
if type2 is Any:
return True, _convert_to_typing_type(type1)

# Direct equality
if type1 == type2:
return True, _convert_to_typing_type(type1)

type1_origin = get_origin(type1)
type2_origin = get_origin(type2)
Expand All @@ -53,34 +59,84 @@ def _types_are_compatible(type1: T, type2: T) -> bool:
return _check_non_union_compatibility(type1, type2, type1_origin, type2_origin)


def _check_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool:
def _check_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> Tuple[bool, Optional[T]]:
"""Handle all Union type compatibility cases."""
if type1_origin is Union and type2_origin is not Union:
return any(_types_are_compatible(union_arg, type2) for union_arg in get_args(type1))
if type2_origin is Union and type1_origin is not Union:
return any(_types_are_compatible(type1, union_arg) for union_arg in get_args(type2))
# Both are Union types. Check all type combinations are compatible.
return any(any(_types_are_compatible(arg1, arg2) for arg2 in get_args(type2)) for arg1 in get_args(type1))

# Find all compatible types from the union
compatible_types = []
for union_arg in get_args(type1):
is_compat, common = _types_are_compatible(union_arg, type2)
if is_compat and common is not None:
compatible_types.append(common)
if compatible_types:
# The constructed Union or single type must be cast to Optional[T]
# to satisfy mypy, as T is specific to this function's call context.
result_type = Union[tuple(compatible_types)] if len(compatible_types) > 1 else compatible_types[0]
return True, cast(Optional[T], result_type)
return False, None

def _check_non_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool:
if type2_origin is Union and type1_origin is not Union:
# Find all compatible types from the union
compatible_types = []
for union_arg in get_args(type2):
is_compat, common = _types_are_compatible(type1, union_arg)
if is_compat and common is not None:
compatible_types.append(common)
if compatible_types:
# The constructed Union or single type must be cast to Optional[T]
# to satisfy mypy, as T is specific to this function's call context.
result_type = Union[tuple(compatible_types)] if len(compatible_types) > 1 else compatible_types[0]
return True, cast(Optional[T], result_type)
return False, None

# Both are Union types
compatible_types = []
for arg1 in get_args(type1):
for arg2 in get_args(type2):
is_compat, common = _types_are_compatible(arg1, arg2)
if is_compat and common is not None:
compatible_types.append(common)

if compatible_types:
# The constructed Union or single type must be cast to Optional[T]
# to satisfy mypy, as T is specific to this function's call context.
result_type = Union[tuple(compatible_types)] if len(compatible_types) > 1 else compatible_types[0]
return True, cast(Optional[T], result_type)
return False, None


def _check_non_union_compatibility(
type1: T, type2: T, type1_origin: Any, type2_origin: Any
) -> Tuple[bool, Optional[T]]:
"""Handle non-Union type compatibility cases."""
# If no origin, compare types directly
if not type1_origin and not type2_origin:
return type1 == type2
if type1 == type2:
return True, type1
return False, None

# Both must have origins and they must be equal
if not (type1_origin and type2_origin and type1_origin == type2_origin):
return False
return False, None

# Compare generic type arguments
type1_args = get_args(type1)
type2_args = get_args(type2)

if len(type1_args) != len(type2_args):
return False
return False, None

# Check if all arguments are compatible
common_args = []
for t1_arg, t2_arg in zip(type1_args, type2_args):
is_compat, common = _types_are_compatible(t1_arg, t2_arg)
if not is_compat:
return False, None
common_args.append(common)

return all(_types_are_compatible(t1_arg, t2_arg) for t1_arg, t2_arg in zip(type1_args, type2_args))
# Reconstruct the type with common arguments
typing_type = _convert_to_typing_type(type1_origin)
return True, cast(Optional[T], typing_type[tuple(common_args)])


def _unwrap_all(t: T, recursive: bool) -> T:
Expand Down Expand Up @@ -167,3 +223,37 @@ def _unwrap_optionals(t: T, recursive: bool) -> T:
if recursive:
return _unwrap_all(result, recursive) # type: ignore
return result # type: ignore


def _convert_to_typing_type(t: Any) -> Any:
"""
Convert built-in Python types to their typing equivalents.

:param t: Type to convert
:return: The type using typing module types
"""
origin = get_origin(t)
args = get_args(t)

# Mapping of built-in types to their typing equivalents
type_converters = {
list: lambda: List if not args else List[Any],
dict: lambda: Dict if not args else Dict[Any, Any],
set: lambda: Set if not args else Set[Any],
tuple: lambda: Tuple if not args else Tuple[Any, ...],
}

# Recursive argument handling
if origin in type_converters:
result = type_converters[origin]()
if args:
if origin == list:
return List[_convert_to_typing_type(args[0])] # type: ignore
if origin == dict:
return Dict[_convert_to_typing_type(args[0]), _convert_to_typing_type(args[1])] # type: ignore
if origin == set:
return Set[_convert_to_typing_type(args[0])] # type: ignore
if origin == tuple:
return Tuple[tuple(_convert_to_typing_type(arg) for arg in args)]
return result
return t
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Enhance SuperComponent's type compatibility check to return the detected common type between two input types.
69 changes: 68 additions & 1 deletion test/core/super_component/test_super_component.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List
from typing import Any, List, Optional, Union

import pytest
from haystack import Document, SuperComponent, Pipeline, AsyncPipeline, component, super_component
Expand Down Expand Up @@ -366,3 +366,70 @@ def test_draw_with_default_parameters(self, mock_draw, sample_super_component, t

sample_super_component.draw(path=path)
mock_draw.assert_called_once_with(path=path, server_url="https://mermaid.ink", params=None, timeout=30)

def test_input_types_reconciliation(self):
"""Test that input types are properly reconciled when they are compatible but not identical."""

@component
class TypeTestComponent:
@component.output_types(result_int=int, result_any=Any)
def run(self, input_int: int, input_any: Any):
return {"result_int": input_int, "result_any": input_any}

pipeline = Pipeline()
pipeline.add_component("test1", TypeTestComponent())
pipeline.add_component("test2", TypeTestComponent())

input_mapping = {"number": ["test1.input_int", "test2.input_any"]}
output_mapping = {"test2.result_int": "result_int"}
wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping, output_mapping=output_mapping)

input_sockets = wrapper.__haystack_input__._sockets_dict
assert "number" in input_sockets
assert input_sockets["number"].type == int

def test_union_type_reconciliation(self):
"""Test that Union types are properly reconciled when creating a SuperComponent."""

@component
class UnionTypeComponent1:
@component.output_types(result=Union[int, str])
def run(self, input: Union[int, str]):
return {"result": input}

@component
class UnionTypeComponent2:
@component.output_types(result=Union[float, str])
def run(self, input: Union[float, str]):
return {"result": input}

pipeline = Pipeline()
pipeline.add_component("test1", UnionTypeComponent1())
pipeline.add_component("test2", UnionTypeComponent2())

input_mapping = {"data": ["test1.input", "test2.input"]}
output_mapping = {"test2.result": "result"}
wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping, output_mapping=output_mapping)

input_sockets = wrapper.__haystack_input__._sockets_dict
assert "data" in input_sockets
assert input_sockets["data"].type == Union[str]

def test_input_types_with_any(self):
"""Test that Any type is properly handled when reconciling types."""

@component
class AnyTypeComponent:
@component.output_types(result=str)
def run(self, specific: str, generic: Any):
return {"result": specific}

pipeline = Pipeline()
pipeline.add_component("test", AnyTypeComponent())

input_mapping = {"text": ["test.specific", "test.generic"]}
wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping)

input_sockets = wrapper.__haystack_input__._sockets_dict
assert "text" in input_sockets
assert input_sockets["text"].type == str
Loading
Loading