Skip to content

Commit

Permalink
Fixing #68 - Now querying by sub-properties of embedded documents in …
Browse files Browse the repository at this point in the history
…list fields is supported
  • Loading branch information
heynemann committed Jun 29, 2014
1 parent da5ce2a commit 958ebd6
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 2 deletions.
5 changes: 4 additions & 1 deletion motorengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def get_field_by_db_name(cls, name):

@classmethod
def get_fields(cls, name, fields=None):
from motorengine import EmbeddedDocumentField
from motorengine import EmbeddedDocumentField, ListField
from motorengine.fields.dynamic_field import DynamicField

if fields is None:
Expand All @@ -290,6 +290,9 @@ def get_fields(cls, name, fields=None):
if isinstance(obj, (EmbeddedDocumentField, )):
obj.embedded_type.get_fields(".".join(field_values[1:]), fields=fields)

if isinstance(obj, (ListField, )):
obj.item_type.get_fields(".".join(field_values[1:]), fields=fields)

return fields


Expand Down
10 changes: 10 additions & 0 deletions motorengine/fields/list_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,13 @@ def to_query(self, value):

def from_son(self, value):
return list(map(self._base_field.from_son, value))

@property
def item_type(self):
if hasattr(self._base_field, 'embedded_type'):
return self._base_field.embedded_type

if hasattr(self._base_field, 'reference_type'):
return self._base_field.reference_type

return self._base_field
6 changes: 5 additions & 1 deletion motorengine/query_builder/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def transform_query(document, **query):

def validate_fields(document, query):
from motorengine.fields.embedded_document_field import EmbeddedDocumentField
from motorengine.fields.list_field import ListField

for key, query in sorted(query.items()):
if '__' not in key:
Expand All @@ -97,7 +98,10 @@ def validate_fields(document, query):
fields = document.get_fields(field_reference_name)

is_none = (not fields) or (not all(fields))
if is_none or (not isinstance(fields[0], (EmbeddedDocumentField,)) and operator == ''):
is_embedded = isinstance(fields[0], (EmbeddedDocumentField,))
is_list = isinstance(fields[0], (ListField,))

if is_none or (not is_embedded and not is_list and operator == ''):
raise ValueError(
"Invalid filter '%s': Invalid operator (if this is a sub-property, "
"then it must be used in embedded document fields)." % key)
21 changes: 21 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,27 @@ class ElemMatchDocument(Document):

expect(loaded_document._id).to_equal(doc._id)

def test_can_query_by_elem_match_when_list_of_embedded(self):
class ElemMatchEmbeddedDocument(Document):
name = StringField()

class ElemMatchEmbeddedParentDocument(Document):
items = ListField(EmbeddedDocumentField(ElemMatchEmbeddedDocument))

self.drop_coll(ElemMatchEmbeddedDocument.__collection__)
self.drop_coll(ElemMatchEmbeddedParentDocument.__collection__)

ElemMatchEmbeddedParentDocument.objects.create(items=[ElemMatchEmbeddedDocument(name="a"), ElemMatchEmbeddedDocument(name="b")], callback=self.stop)
doc = self.wait()

ElemMatchEmbeddedParentDocument.objects.create(items=[ElemMatchEmbeddedDocument(name="c"), ElemMatchEmbeddedDocument(name="d")], callback=self.stop)
doc2 = self.wait()

ElemMatchEmbeddedParentDocument.objects.filter(items__name="b").find_all(callback=self.stop)
loaded_document = self.wait()

expect(loaded_document).to_length(1)

def test_raw_query(self):
class RawQueryEmbeddedDocument(Document):
name = StringField()
Expand Down

0 comments on commit 958ebd6

Please sign in to comment.