Skip to content

Commit

Permalink
Recursively call .to_dict() on objects in Search.extras/**kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Dec 2, 2020
1 parent 5487df0 commit 99b787c
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 7 deletions.
6 changes: 3 additions & 3 deletions elasticsearch_dsl/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .exceptions import IllegalOperation
from .query import Bool, Q
from .response import Hit, Response
from .utils import AttrDict, DslBase
from .utils import AttrDict, DslBase, recursive_to_dict


class QueryProxy(object):
Expand Down Expand Up @@ -668,7 +668,7 @@ def to_dict(self, count=False, **kwargs):
if self._sort:
d["sort"] = self._sort

d.update(self._extra)
d.update(recursive_to_dict(self._extra))

if self._source not in (None, {}):
d["_source"] = self._source
Expand All @@ -683,7 +683,7 @@ def to_dict(self, count=False, **kwargs):
if self._script_fields:
d["script_fields"] = self._script_fields

d.update(kwargs)
d.update(recursive_to_dict(kwargs))
return d

def count(self):
Expand Down
6 changes: 3 additions & 3 deletions elasticsearch_dsl/update_by_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .query import Bool, Q
from .response import UpdateByQueryResponse
from .search import ProxyDescriptor, QueryProxy, Request
from .utils import recursive_to_dict


class UpdateByQuery(Request):
Expand Down Expand Up @@ -141,9 +142,8 @@ def to_dict(self, **kwargs):
if self._script:
d["script"] = self._script

d.update(self._extra)

d.update(kwargs)
d.update(recursive_to_dict(self._extra))
d.update(recursive_to_dict(kwargs))
return d

def execute(self):
Expand Down
16 changes: 16 additions & 0 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,19 @@ def merge(data, new_data, raise_on_conflict=False):
raise ValueError("Incompatible data for key %r, cannot be merged." % key)
else:
data[key] = value


def recursive_to_dict(data):
"""Recursively transform objects that potentially have .to_dict()
into dictionary literals by traversing AttrList, AttrDict, list,
tuple, and Mapping types.
"""
if isinstance(data, AttrList):
data = list(data._l_)
elif hasattr(data, "to_dict"):
data = data.to_dict()
if isinstance(data, (list, tuple)):
return type(data)(recursive_to_dict(inner) for inner in data)
elif isinstance(data, collections_abc.Mapping):
return {key: recursive_to_dict(val) for key, val in data.items()}
return data
62 changes: 62 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,65 @@ def test_update_from_dict():
"indices_boost": [{"important-documents": 2}],
"_source": ["id", "name"],
} == s.to_dict()


def test_rescore_query_to_dict():
s = search.Search(index="index-name")

positive_query = Q(
"function_score",
query=Q("term", tags="a"),
script_score={"script": "_score * 1"},
)

negative_query = Q(
"function_score",
query=Q("term", tags="b"),
script_score={"script": "_score * -100"},
)

s = s.query(positive_query)
s = s.extra(
rescore={"window_size": 100, "query": {"rescore_query": negative_query}}
)
assert s.to_dict() == {
"query": {
"function_score": {
"query": {"term": {"tags": "a"}},
"functions": [{"script_score": {"script": "_score * 1"}}],
}
},
"rescore": {
"window_size": 100,
"query": {
"rescore_query": {
"function_score": {
"query": {"term": {"tags": "b"}},
"functions": [{"script_score": {"script": "_score * -100"}}],
}
}
},
},
}

assert s.to_dict(
rescore={"window_size": 10, "query": {"rescore_query": positive_query}}
) == {
"query": {
"function_score": {
"query": {"term": {"tags": "a"}},
"functions": [{"script_score": {"script": "_score * 1"}}],
}
},
"rescore": {
"window_size": 10,
"query": {
"rescore_query": {
"function_score": {
"query": {"term": {"tags": "a"}},
"functions": [{"script_score": {"script": "_score * 1"}}],
}
}
},
},
}
3 changes: 3 additions & 0 deletions tests/test_update_by_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def test_ubq_to_dict():
ubq = UpdateByQuery(extra={"size": 5})
assert {"size": 5} == ubq.to_dict()

ubq = UpdateByQuery(extra={"extra_q": Q("term", category="conference")})
assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict()


def test_complex_example():
ubq = UpdateByQuery()
Expand Down
8 changes: 7 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pytest import raises

from elasticsearch_dsl import serializer, utils
from elasticsearch_dsl import Q, serializer, utils


def test_attrdict_pickle():
Expand Down Expand Up @@ -94,3 +94,9 @@ def to_dict(self):
return 42

assert serializer.serializer.dumps(MyClass()) == "42"


def test_recursive_to_dict():
assert utils.recursive_to_dict({"k": [1, (1.0, {"v": Q("match", key="val")})]}) == {
"k": [1, (1.0, {"v": {"match": {"key": "val"}}})]
}

0 comments on commit 99b787c

Please sign in to comment.