diff --git a/mongomock/collection.py b/mongomock/collection.py index f18b71d8ec..d5868896ec 100644 --- a/mongomock/collection.py +++ b/mongomock/collection.py @@ -526,6 +526,51 @@ def _copy_field(self, obj, container): else: return copy.copy(obj) + def _extract_projection_operators(self, fields): + """Removes and returns fields with projection operators.""" + result = {} + allowed_projection_operators = set(['$elemMatch']) + for key, value in iteritems(fields): + if isinstance(value, dict): + for op in value: + if op not in allowed_projection_operators: + raise ValueError('Unsupported projection option: {}'.format(op)) + result[key] = value + + for key in result: + del fields[key] + + return result + + def _apply_projection_operators(self, ops, doc, doc_copy): + """Applies projection operators to copied document.""" + for field, op in iteritems(ops): + if field not in doc_copy: + if field in doc: + # field was not copied yet (since we are in include mode) + doc_copy[field] = doc[field] + else: + # field doesn't exist in original document, no work to do + continue + + if '$elemMatch' in op: + if isinstance(doc_copy[field], list): + # find the first item that matches + matched = False + for item in doc_copy[field]: + if filter_applies(op['$elemMatch'], item): + matched = True + doc_copy[field] = [item] + break + + # nothing have matched + if not matched: + del doc_copy[field] + + else: + # remove the field since there is nothing to iterate + del doc_copy[field] + def _copy_only_fields(self, doc, fields, container): """Copy only the specified fields.""" @@ -541,6 +586,9 @@ def _copy_only_fields(self, doc, fields, container): # value out and hang on to it until later id_value = fields.pop('_id', 1) + # filter out fields with projection operators, we will take care of them later + projection_operators = self._extract_projection_operators(fields) + # other than the _id field, all fields must be either includes or # excludes, this can evaluate to 0 if len(set(list(fields.values()))) > 1: @@ -595,6 +643,12 @@ def _copy_only_fields(self, doc, fields, container): doc_copy['_id'] = doc['_id'] fields['_id'] = id_value # put _id back in fields + + # time to apply the projection operators and put back their fields + self._apply_projection_operators(projection_operators, doc, doc_copy) + for field, op in iteritems(projection_operators): + fields[field] = op + return doc_copy def _update_document_fields(self, doc, fields, updater): diff --git a/tests/test__mongomock.py b/tests/test__mongomock.py index e572e08240..30216e1b21 100644 --- a/tests/test__mongomock.py +++ b/tests/test__mongomock.py @@ -252,6 +252,31 @@ def test__find_by_attributes_return_fields(self): # test no _id, otherProp:1 self.cmp.compare.find({"_id": id1}, {"someOtherProp": 1}) + def test__find_by_attributes_return_fields_elemMatch(self): + id = ObjectId() + self.cmp.do.insert({ + '_id': id, + 'owns': [ + {'type': 'hat', 'color': 'black'}, + {'type': 'hat', 'color': 'green'}, + {'type': 't-shirt', 'color': 'black', 'size': 'small'}, + {'type': 't-shirt', 'color': 'black'}, + {'type': 't-shirt', 'color': 'white'} + ], + 'hat': 'red' + }) + elem = {'$elemMatch': {'type': 't-shirt', 'color': 'black'}} + # test filtering on array field only + self.cmp.compare.find({'_id': id}, {'owns': elem}) + # test filtering on array field with inclusion + self.cmp.compare.find({'_id': id}, {'owns': elem, 'hat': 1}) + # test filtering on array field with exclusion + self.cmp.compare.find({'_id': id}, {'owns': elem, 'hat': 0}) + # test filtering on non array field + self.cmp.compare.find({'_id': id}, {'hat': elem}) + # test no match + self.cmp.compare.find({'_id': id}, {'owns': {'$elemMatch': {'type': 'cap'}}}) + def test__find_by_dotted_attributes(self): """Test seaching with dot notation.""" green_bowler = {