Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4d2d438
Add serialization to State
Amnah199 May 5, 2025
effb921
Add release notes
Amnah199 May 5, 2025
5e0f50a
Deprecate State in dataclasses
Amnah199 May 6, 2025
2de0e20
Fix tests
Amnah199 May 6, 2025
ebd7097
Merge branch 'main' into add-ser-for-state-dataclass
Amnah199 May 6, 2025
c840fb0
Remove state_utils test
Amnah199 May 6, 2025
7f0eae6
Merge branch 'add-ser-for-state-dataclass' of https://github.com/deep…
Amnah199 May 6, 2025
859914f
Fix linting
Amnah199 May 6, 2025
c8708c9
Fix formating
Amnah199 May 6, 2025
c25fcf2
Update tests and remove old state utils
Amnah199 May 6, 2025
0190e04
Update agents test
Amnah199 May 6, 2025
1e579c3
Update deserilaization per review
Amnah199 May 13, 2025
f64bd72
Linting
Amnah199 May 13, 2025
22aba3a
Add tests for edge case (custom class types)
Amnah199 May 13, 2025
9f48f25
Fix type serialization
Amnah199 May 15, 2025
9e4071f
PR comments
Amnah199 May 15, 2025
4d0d89c
Move State to agents
Amnah199 May 16, 2025
a25a0a8
Merge branch 'main' of https://github.com/deepset-ai/haystack into ad…
Amnah199 May 16, 2025
0c4204f
Fix tests
Amnah199 May 16, 2025
ff0a1c6
Update utils init
Amnah199 May 16, 2025
5f892e2
Improve seriliaztion/deser
Amnah199 May 19, 2025
3989eca
Update the release notes
Amnah199 May 19, 2025
7390b3c
Minor fix in docstrings
Amnah199 May 19, 2025
16007e5
PR comments
Amnah199 May 20, 2025
018325c
Add deprecation warnign for state utils
Amnah199 May 20, 2025
0d2d2ef
Recreate the serialization methods to use schema
Amnah199 May 21, 2025
8b3c2c4
Update key names
Amnah199 May 22, 2025
d607112
Make serialization methods private
Amnah199 May 23, 2025
7682ee0
Merge branch 'main' into add-ser-for-state-dataclass
Amnah199 May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion haystack/components/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from lazy_imports import LazyImporter

_import_structure = {"agent": ["Agent"]}
_import_structure = {"agent": ["Agent"], "state": ["State"]}

if TYPE_CHECKING:
from .agent import Agent
from .state import State

else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
5 changes: 3 additions & 2 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
from haystack.core.serialization import component_to_dict
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from haystack.dataclasses.state_utils import merge_lists
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace

from .state.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from .state.state_utils import merge_lists

logger = logging.getLogger(__name__)


Expand Down
8 changes: 8 additions & 0 deletions haystack/components/agents/state/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from .state import State
from .state_utils import merge_lists, replace_values

__all__ = ["State", "merge_lists", "replace_values"]
179 changes: 179 additions & 0 deletions haystack/components/agents/state/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

from haystack.dataclasses import ChatMessage
from haystack.utils import _deserialize_value_with_schema, _serialize_value_with_schema
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.type_serialization import deserialize_type, serialize_type

from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values


def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a schema dictionary to a serializable format.

Converts each parameter's type and optional handler function into a serializable
format using type and callable serialization utilities.

:param schema: Dictionary mapping parameter names to their type and handler configs
:returns: Dictionary with serialized type and handler information
"""
serialized_schema = {}
for param, config in schema.items():
serialized_schema[param] = {"type": serialize_type(config["type"])}
if config.get("handler"):
serialized_schema[param]["handler"] = serialize_callable(config["handler"])

return serialized_schema


def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a serialized schema dictionary back to its original format.

Deserializes the type and optional handler function for each parameter from their
serialized format back into Python types and callables.

:param schema: Dictionary containing serialized schema information
:returns: Dictionary with deserialized type and handler configurations
"""
deserialized_schema = {}
for param, config in schema.items():
deserialized_schema[param] = {"type": deserialize_type(config["type"])}

if config.get("handler"):
deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])

return deserialized_schema


def _validate_schema(schema: Dict[str, Any]) -> None:
"""
Validate that a schema dictionary meets all required constraints.

Checks that each parameter definition has a valid type field and that any handler
specified is a callable function.

:param schema: Dictionary mapping parameter names to their type and handler configs
:raises ValueError: If schema validation fails due to missing or invalid fields
"""
for param, definition in schema.items():
if "type" not in definition:
raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
if not _is_valid_type(definition["type"]):
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
if definition.get("handler") is not None and not callable(definition["handler"]):
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
if param == "messages" and definition["type"] is not List[ChatMessage]:
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")


class State:
    """
    A class that wraps a StateSchema and maintains an internal _data dictionary.

    Each schema entry has:
      "parameter_name": {
        "type": SomeType,
        "handler": Optional[Callable[[Any, Any], Any]]
      }
    """

    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
        """
        Initialize a State object with a schema and optional data.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
                - For list types: `haystack.components.agents.state.state_utils.merge_lists`
                - For all other types: `haystack.components.agents.state.state_utils.replace_values`
            The schema is deep-copied; a `messages` entry of type `List[ChatMessage]` is added
            when the schema does not define one.
        :param data: Optional dictionary of initial data to populate the state. It is stored as-is
            (not copied and not checked against the schema); an empty or missing dictionary is
            replaced by a fresh empty one.
        :raises ValueError: If the schema does not pass `_validate_schema`.
        """
        _validate_schema(schema)
        # Work on a private copy so that filling in defaults below never touches the caller's schema.
        self.schema = deepcopy(schema)
        if self.schema.get("messages") is None:
            self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self._data = data or {}

        # Set default handlers if not provided in schema
        for definition in self.schema.values():
            # Skip if handler is already defined and not None
            if definition.get("handler") is not None:
                continue
            # Lists accumulate by default, everything else is overwritten
            definition["handler"] = merge_lists if _is_list_type(definition["type"]) else replace_values

    def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from the state by key.

        The result is a deep copy, so mutating it does not change the state.

        :param key: Key to look up in the state
        :param default: Value to return if key is not found
        :returns: Deep copy of the value associated with key, or of default if not found
        """
        return deepcopy(self._data.get(key, default))

    def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
        """
        Set or merge a value in the state according to schema rules.

        Value is merged or overwritten according to these rules:
          - if handler_override is given, use that
          - else use the handler defined in the schema for 'key'

        The chosen handler is called as `handler(current_value, value)`, where `current_value`
        is None when the key has not been set yet, and its result is stored under `key`.

        :param key: Key to store the value under
        :param value: Value to store or merge
        :param handler_override: Optional function to override the default merge behavior
        :raises ValueError: If `key` is not defined in the schema
        """
        # If key not in schema, we throw an error
        definition = self.schema.get(key)
        if definition is None:
            raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")

        # Get current value from state and apply handler
        current_value = self._data.get(key)
        handler = handler_override or definition["handler"]
        self._data[key] = handler(current_value, value)

    @property
    def data(self) -> Dict[str, Any]:
        """
        All current data of the state.

        Unlike `get`, this returns the internal dictionary itself rather than a copy.
        """
        return self._data

    def has(self, key: str) -> bool:
        """
        Check if a key exists in the state.

        :param key: Key to check for existence
        :returns: True if key exists in state, False otherwise
        """
        return key in self._data

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the State object to a dictionary.

        :returns: Dictionary with a `schema` entry (produced by `_schema_to_dict`) and a `data`
            entry (produced by `_serialize_value_with_schema`)
        """
        return {"schema": _schema_to_dict(self.schema), "data": _serialize_value_with_schema(self._data)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        """
        Convert a dictionary back to a State object.

        :param data: Dictionary with optional `schema` and `data` entries, as produced by `to_dict`.
            Missing entries are treated as empty dictionaries.
        :returns: A new instance of the class this method is called on
        """
        schema = _schema_from_dict(data.get("schema", {}))
        deserialized_data = _deserialize_value_with_schema(data.get("data", {}))
        # Build through `cls` so that subclasses deserialize to their own type.
        return cls(schema, deserialized_data)
77 changes: 77 additions & 0 deletions haystack/components/agents/state/state_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Any, List, TypeVar, Union, get_origin

T = TypeVar("T")


def _is_valid_type(obj: Any) -> bool:
"""
Check if an object is a valid type annotation.

Valid types include:
- Normal classes (str, dict, CustomClass)
- Generic types (List[str], Dict[str, int])
- Union types (Union[str, int], Optional[str])

:param obj: The object to check
:return: True if the object is a valid type annotation, False otherwise

Example usage:
>>> _is_valid_type(str)
True
>>> _is_valid_type(List[int])
True
>>> _is_valid_type(Union[str, int])
True
>>> _is_valid_type(42)
False
"""
# Handle Union types (including Optional)
if hasattr(obj, "__origin__") and obj.__origin__ is Union:
return True

# Handle normal classes and generic types
return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"}


def _is_list_type(type_hint: Any) -> bool:
"""
Check if a type hint represents a list type.

:param type_hint: The type hint to check
:return: True if the type hint represents a list, False otherwise
"""
return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list)


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
    """
    Concatenate two values into one list.

    Each argument that is not already a list is wrapped in a single-element list first.
    A `current` of None counts as an empty list; a `new` of None is wrapped like any other
    non-list value. The result is a new list, with the items of `current` first.

    :param current: The existing value(s), either a single item or a list.
    :param new: The new value(s) to merge, either a single item or a list.
    :return: A list containing elements from both `current` and `new`.
    """
    if current is None:
        existing = []
    elif isinstance(current, list):
        existing = current
    else:
        existing = [current]

    incoming = new if isinstance(new, list) else [new]
    return existing + incoming


def replace_values(current: Any, new: Any) -> Any:
    """
    Discard the `current` value and keep the `new` one.

    `current` is never read; `new` is returned as the same object, without copying.

    :param current: The existing value
    :param new: The new value to replace
    :return: The new value
    """
    replacement = new
    return replacement
3 changes: 2 additions & 1 deletion haystack/components/tools/tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.agents import State
from haystack.core.component.sockets import Sockets
from haystack.dataclasses import ChatMessage, State, ToolCall
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
from haystack.tools import (
ComponentTool,
Expand Down
Loading
Loading