diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 963be97..d702696 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ Changelog ========= +- Support matches-style callbacks on non-dictionary objects that are compatible with ``pydash.get`` in functions like ``pydash.find``. + v5.0.2 (2021-07-15) ------------------- diff --git a/src/pydash/helpers.py b/src/pydash/helpers.py index 0cc3138..0467c9c 100644 --- a/src/pydash/helpers.py +++ b/src/pydash/helpers.py @@ -102,7 +102,7 @@ def iteriteratee(obj, iteratee=None, reverse=False): def iterator(obj): """Return iterative based on object type.""" - if isinstance(obj, dict): + if isinstance(obj, Mapping): return obj.items() elif hasattr(obj, "iteritems"): return obj.iteritems() # noqa: B301 diff --git a/src/pydash/predicates.py b/src/pydash/predicates.py index 16ceda3..dd7b57c 100644 --- a/src/pydash/predicates.py +++ b/src/pydash/predicates.py @@ -4,6 +4,7 @@ .. versionadded:: 2.0.0 """ +from collections.abc import Iterable, Mapping import datetime from itertools import islice import json @@ -13,7 +14,7 @@ import pydash as pyd -from .helpers import BUILTINS, NUMBER_TYPES, UNSET, callit, iterator +from .helpers import BUILTINS, NUMBER_TYPES, UNSET, base_get, callit, iterator __all__ = ( @@ -914,14 +915,7 @@ def cbk(obj_value, src_value): else: cbk = customizer - if ( - isinstance(obj, dict) - and isinstance(source, dict) - or isinstance(obj, list) - and isinstance(source, list) - or isinstance(obj, tuple) - and isinstance(source, tuple) - ): + if isinstance(source, (Mapping, Iterable)) and not isinstance(source, str): # Set equal to True if source is empty, otherwise, False and then allow deep comparison to # determine equality. equal = not source @@ -929,7 +923,8 @@ def cbk(obj_value, src_value): # Walk a/b to determine equality. for key, value in iterator(source): try: - equal = is_match_with(obj[key], value, cbk, _key=key, _obj=_obj, _source=_source) + obj_value = base_get(obj, key) + equal = is_match_with(obj_value, value, cbk, _key=key, _obj=_obj, _source=_source) except Exception: equal = False diff --git a/tests/test_collections.py b/tests/test_collections.py index ef397e8..e4e71fa 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,3 +1,4 @@ +from collections import namedtuple import math from operator import itemgetter @@ -138,6 +139,17 @@ def test_find(case, expected): assert _.find(*case) == expected +def test_find_class_object(): + obj = fixtures.Object(a=1, b=2) + assert _.find([None, {}, obj], {"b": 2}) == obj + + +def test_find_namedtuple(): + User = namedtuple("User", ["first_name", "last_name"]) + obj = User(first_name="Bob", last_name="Smith") + assert _.find([None, {}, obj], {"first_name": "Bob"}) == obj + + @parametrize( "case,expected", [(({"abc": 1, "xyz": 2, "c": 3}.values(), fixtures.Filter(lambda x: x < 2)), 1)],