Skip to content

Commit f68539e

Browse files
Support values of inner docs given as AttrDict instances (#3080)
* Support values of inner docs given as AttrDict instances * one more unit test
1 parent 5a8e2a7 commit f68539e

File tree

5 files changed

+87
-5
lines changed

5 files changed

+87
-5
lines changed

elasticsearch/dsl/field.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,11 @@ def _serialize(
572572
if isinstance(data, collections.abc.Mapping):
573573
return data
574574

575-
return data.to_dict(skip_empty=skip_empty)
575+
try:
576+
return data.to_dict(skip_empty=skip_empty)
577+
except TypeError:
578+
# this would only happen if an AttrDict was given instead of an InnerDoc
579+
return data.to_dict()
576580

577581
def clean(self, data: Any) -> Any:
578582
data = super().clean(data)

test_elasticsearch/test_dsl/test_field.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@
2424
from dateutil import tz
2525

2626
from elasticsearch import dsl
27-
from elasticsearch.dsl import InnerDoc, Range, ValidationException, field
27+
from elasticsearch.dsl import (
28+
AttrDict,
29+
AttrList,
30+
InnerDoc,
31+
Range,
32+
ValidationException,
33+
field,
34+
)
2835

2936

3037
def test_date_range_deserialization() -> None:
@@ -235,6 +242,33 @@ class Inner(InnerDoc):
235242
field.Object(doc_class=Inner, dynamic=False)
236243

237244

245+
def test_dynamic_object() -> None:
246+
f = field.Object(dynamic=True)
247+
assert f.deserialize({"a": "b"}).to_dict() == {"a": "b"}
248+
assert f.deserialize(AttrDict({"a": "b"})).to_dict() == {"a": "b"}
249+
assert f.serialize({"a": "b"}) == {"a": "b"}
250+
assert f.serialize(AttrDict({"a": "b"})) == {"a": "b"}
251+
252+
253+
def test_dynamic_nested() -> None:
254+
f = field.Nested(dynamic=True)
255+
assert f.deserialize([{"a": "b"}, {"c": "d"}]) == [{"a": "b"}, {"c": "d"}]
256+
assert f.deserialize([AttrDict({"a": "b"}), {"c": "d"}]) == [
257+
{"a": "b"},
258+
{"c": "d"},
259+
]
260+
assert f.deserialize(AttrList([AttrDict({"a": "b"}), {"c": "d"}])) == [
261+
{"a": "b"},
262+
{"c": "d"},
263+
]
264+
assert f.serialize([{"a": "b"}, {"c": "d"}]) == [{"a": "b"}, {"c": "d"}]
265+
assert f.serialize([AttrDict({"a": "b"}), {"c": "d"}]) == [{"a": "b"}, {"c": "d"}]
266+
assert f.serialize(AttrList([AttrDict({"a": "b"}), {"c": "d"}])) == [
267+
{"a": "b"},
268+
{"c": "d"},
269+
]
270+
271+
238272
def test_all_fields_exported() -> None:
239273
"""Make sure that all the generated field classes are exported at the top-level"""
240274
fields = [

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from elasticsearch.dsl import (
3434
AsyncDocument,
3535
AsyncSearch,
36+
AttrDict,
3637
Binary,
3738
Boolean,
3839
Date,
@@ -627,13 +628,17 @@ async def test_can_save_to_different_index(
627628

628629

629630
@pytest.mark.asyncio
631+
@pytest.mark.parametrize("validate", (True, False))
630632
async def test_save_without_skip_empty_will_include_empty_fields(
631633
async_write_client: AsyncElasticsearch,
634+
validate: bool,
632635
) -> None:
633636
test_repo = Repository(
634637
field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42}
635638
)
636-
assert await test_repo.save(index="test-document", skip_empty=False)
639+
assert await test_repo.save(
640+
index="test-document", skip_empty=False, validate=validate
641+
)
637642

638643
assert_doc_equals(
639644
{
@@ -650,6 +655,23 @@ async def test_save_without_skip_empty_will_include_empty_fields(
650655
await async_write_client.get(index="test-document", id=42),
651656
)
652657

658+
test_repo = Repository(owner=AttrDict({"name": None}), meta={"id": 43})
659+
assert await test_repo.save(
660+
index="test-document", skip_empty=False, validate=validate
661+
)
662+
663+
assert_doc_equals(
664+
{
665+
"found": True,
666+
"_index": "test-document",
667+
"_id": "43",
668+
"_source": {
669+
"owner": {"name": None},
670+
},
671+
},
672+
await async_write_client.get(index="test-document", id=43),
673+
)
674+
653675

654676
@pytest.mark.asyncio
655677
async def test_delete(async_write_client: AsyncElasticsearch) -> None:

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from elasticsearch import ConflictError, Elasticsearch, NotFoundError
3333
from elasticsearch.dsl import (
34+
AttrDict,
3435
Binary,
3536
Boolean,
3637
Date,
@@ -621,13 +622,15 @@ def test_can_save_to_different_index(
621622

622623

623624
@pytest.mark.sync
625+
@pytest.mark.parametrize("validate", (True, False))
624626
def test_save_without_skip_empty_will_include_empty_fields(
625627
write_client: Elasticsearch,
628+
validate: bool,
626629
) -> None:
627630
test_repo = Repository(
628631
field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42}
629632
)
630-
assert test_repo.save(index="test-document", skip_empty=False)
633+
assert test_repo.save(index="test-document", skip_empty=False, validate=validate)
631634

632635
assert_doc_equals(
633636
{
@@ -644,6 +647,21 @@ def test_save_without_skip_empty_will_include_empty_fields(
644647
write_client.get(index="test-document", id=42),
645648
)
646649

650+
test_repo = Repository(owner=AttrDict({"name": None}), meta={"id": 43})
651+
assert test_repo.save(index="test-document", skip_empty=False, validate=validate)
652+
653+
assert_doc_equals(
654+
{
655+
"found": True,
656+
"_index": "test-document",
657+
"_id": "43",
658+
"_source": {
659+
"owner": {"name": None},
660+
},
661+
},
662+
write_client.get(index="test-document", id=43),
663+
)
664+
647665

648666
@pytest.mark.sync
649667
def test_delete(write_client: Elasticsearch) -> None:

utils/templates/field.py.tpl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,11 @@ class {{ k.name }}({{ k.parent }}):
334334
if isinstance(data, collections.abc.Mapping):
335335
return data
336336

337-
return data.to_dict(skip_empty=skip_empty)
337+
try:
338+
return data.to_dict(skip_empty=skip_empty)
339+
except TypeError:
340+
# this would only happen if an AttrDict was given instead of an InnerDoc
341+
return data.to_dict()
338342

339343
def clean(self, data: Any) -> Any:
340344
data = super().clean(data)

0 commit comments

Comments
 (0)