Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions src/mcp_agent/llm/augmented_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Expand Down Expand Up @@ -59,7 +60,36 @@
HUMAN_INPUT_TOOL_NAME = "__human_input__"


class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
def deep_merge(dict1: Dict[Any, Any], dict2: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Recursively merges `dict2` into `dict1` in place.

If a key exists in both dictionaries and their values are dictionaries,
the function merges them recursively. Otherwise, the value from `dict2`
overwrites or is added to `dict1`.

Args:
dict1 (Dict): The dictionary to be updated.
dict2 (Dict): The dictionary to merge into `dict1`.

Returns:
Dict: The updated `dict1`.
"""
for key in dict2:
if (
key in dict1
and isinstance(dict1[key], dict)
and isinstance(dict2[key], dict)
):
deep_merge(dict1[key], dict2[key])
else:
dict1[key] = dict2[key]
return dict1


class AugmentedLLM(
ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]
):
# Common parameter names used across providers
PARAM_MESSAGES = "messages"
PARAM_MODEL = "model"
Expand Down Expand Up @@ -357,8 +387,10 @@ def _merge_request_params(
) -> RequestParams:
"""Merge default and provided request parameters"""

merged = default_params.model_dump()
merged.update(provided_params.model_dump(exclude_unset=True))
merged = deep_merge(
default_params.model_dump(),
provided_params.model_dump(exclude_unset=True),
)
final_params = RequestParams(**merged)

return final_params
Expand Down
Loading