diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index e21fdd96621fe1..13b91270c70d8a 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -19,11 +19,9 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: for k, v in right.items(): if k not in merged: merged[k] = v - elif merged[k] is None and v: + elif v is not None and merged[k] is None: merged[k] = v - elif v is None: - continue - elif merged[k] == v: + elif v is None or merged[k] == v: continue elif type(merged[k]) != type(v): raise TypeError( diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index e6dedb2a5a86bf..a69ea2510b41fa 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -1,9 +1,12 @@ -from typing import Dict, Optional, Tuple, Type +import re +from contextlib import AbstractContextManager, nullcontext +from typing import Dict, Optional, Tuple, Type, Union from unittest.mock import patch import pytest from langchain_core.utils import check_package_version +from langchain_core.utils._merge import merge_dicts @pytest.mark.parametrize( @@ -28,3 +31,74 @@ def test_check_package_version( else: with pytest.raises(expected[0], match=expected[1]): check_package_version(package, **check_kwargs) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + ( + # Merge `None` and `1`. + ({"a": None}, {"a": 1}, {"a": 1}), + # Merge `1` and `None`. + ({"a": 1}, {"a": None}, {"a": 1}), + # Merge `None` and a value. + ({"a": None}, {"a": 0}, {"a": 0}), + ({"a": None}, {"a": "txt"}, {"a": "txt"}), + # Merge equal values. + ({"a": 1}, {"a": 1}, {"a": 1}), + ({"a": 1.5}, {"a": 1.5}, {"a": 1.5}), + ({"a": True}, {"a": True}, {"a": True}), + ({"a": False}, {"a": False}, {"a": False}), + ({"a": "txt"}, {"a": "txt"}, {"a": "txt"}), + ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}), + ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}), + # Merge strings. + ({"a": "one"}, {"a": "two"}, {"a": "onetwo"}), + # Merge dicts. + ({"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 1, "c": 2}}), + ( + {"function_call": {"arguments": None}}, + {"function_call": {"arguments": "{\n"}}, + {"function_call": {"arguments": "{\n"}}, + ), + # Merge lists. + ({"a": [1, 2]}, {"a": [3]}, {"a": [1, 2, 3]}), + ({"a": 1, "b": 2}, {"a": 1}, {"a": 1, "b": 2}), + ({"a": 1, "b": 2}, {"c": None}, {"a": 1, "b": 2, "c": None}), + # + # Invalid inputs. + # + ( + {"a": 1}, + {"a": "1"}, + pytest.raises( + TypeError, + match=re.escape( + 'additional_kwargs["a"] already exists in this message, ' + "but with a different type." + ), + ), + ), + ( + {"a": (1, 2)}, + {"a": (3,)}, + pytest.raises( + TypeError, + match=( + "Additional kwargs key a already exists in left dict and value " + "has unsupported type .+tuple.+." + ), + ), + ), + ), +) +def test_merge_dicts( + left: dict, right: dict, expected: Union[dict, AbstractContextManager] +) -> None: + if isinstance(expected, AbstractContextManager): + err = expected + else: + err = nullcontext() + + with err: + actual = merge_dicts(left, right) + assert actual == expected