def _model_to_document(self, value):
    """Return ``value`` as the SON document of this field's related model.

    Accepts a value in any of the forms a caller may hand to ``to_mongo``
    and normalises it to SON.

    :parameters:
      - `value`: One of: a ``bson.SON`` (assumed to be already converted and
        returned untouched), an instance of ``self.related_model``
        (converted with ``to_son()``), or a ``dict`` (loaded with
        ``related_model.from_document`` and then converted with
        ``to_son()``).

    :returns: The SON form of ``value``.

    :raises: ``ValidationError`` if ``value`` is none of the accepted forms.
    """
    if isinstance(value, bson.SON):
        # Already in its final form; hand it back as is.
        document = value
    elif isinstance(value, self.related_model):
        document = value.to_son()
    elif isinstance(value, dict):
        # Round-trip through the model so the SON is generated by the model
        # itself rather than copied from the raw dict.
        instance = self.related_model.from_document(value)
        document = instance.to_son()
    else:
        # Nothing we know how to turn into SON.
        raise ValidationError(
            '%s is not a valid %s' % (value, self.related_model.__name__))
    return document
# pymodm/base/models.py
# NOTE(review): the decorator lies outside the visible hunk; the ``cls``
# first parameter and the ``cls()`` call imply ``@classmethod`` -- confirm.
@classmethod
def from_document(cls, document):
    """Build a model instance from a raw document.

    :parameters:
      - `document`: The document to load. It is passed through
        ``validate_mapping`` first. When it carries a ``_cls`` key, the
        class returned by ``get_document`` for that name is instantiated
        instead of ``cls``.

    :returns: A new instance whose attributes were set from the document
      with ``_set_attributes``.

    :raises: ``TypeError`` if the class named by ``_cls`` is not a subclass
      of ``cls``.
    """
    dct = validate_mapping('document', document)
    doc_cls = cls
    cls_name = dct.get('_cls')
    if cls_name is not None:
        doc_cls = get_document(cls_name)
        if not issubclass(doc_cls, cls):
            # The message names the required base class first and the
            # offending class second, so the arguments must be
            # (cls, doc_cls) -- not the other way round.
            raise TypeError('A document\'s _cls field must be '
                            'a subclass of the %s, but %s is not.'
                            % (cls, doc_cls))

    inst = doc_cls()
    inst._set_attributes(dct)
    return inst


# pymodm/dereference.py
def _resolve_references(database, reference_map):
    """Fetch every referenced document, grouped by collection.

    :parameters:
      - `database`: Object indexed by collection name to obtain a
        collection.
      - `reference_map`: Mapping of collection name to the list of ``_id``
        values to fetch from that collection.

    :returns: A mapping ``{collection_name: {_id: document}}``. Grouping by
      collection keeps documents from different collections apart even when
      they share the same ``_id``.
    """
    document_map = defaultdict(_ObjectMap)
    for collection_name in reference_map:
        collection = database[collection_name]
        # One ``$in`` query per collection rather than one query per id.
        query = {'_id': {'$in': reference_map[collection_name]}}
        documents = collection.find(query)
        for document in documents:
            document_map[collection_name][document['_id']] = document

    return document_map
path, key, value): - if not path: - if isinstance(value, DBRef) and value.id in document_map: - _set_value(object, key, document_map[value.id]) - # elif isinstance(value, ObjectId) and value in document_map: - elif value in document_map: - _set_value(object, key, document_map[value]) - else: - # path is empty, but value could be a list. - _attach_objects_in_path(value, document_map, path) - else: - _attach_objects_in_path(value, document_map, path) - - -def _attach_objects_in_path(container, document_map, path=None): - # Paths can't name indexes in a list, so just recurse on the list. - if isinstance(container, list): - for index, item in enumerate(container): - _set_or_recurse(container, document_map, path, index, item) - # Retrieve and recurse on the next field, if we have a path. - elif path: - part = path.popleft() - value = _get_value(container, part) - _set_or_recurse(container, document_map, path, part, value) - # Recurse on every field, if there's no path given. - elif hasattr(container, 'items') or isinstance(container, MongoModelBase): - # Container is a dict or a MongoModel of some kind. - # Both iterate their keys. 
- for key in container: - value = _get_value(container, key) - _set_or_recurse(container, document_map, path, key, value) - - -def _attach_objects(container, document_map, fields=None): +def _get_reference_document(document_map, collection_name, ref_id): + try: + return document_map[collection_name][ref_id] + except KeyError: + return None + + +def _attach_objects_in_path(container, document_map, fields, key, field): + try: + value = container[key] + except KeyError: + # there is no value for given key + return + + if (isinstance(field, ReferenceField) and + not isinstance(value, field.related_model)): + # value is reference id + meta = field.related_model._mongometa + container[key] = _get_reference_document(document_map, + meta.collection_name, + meta.pk.to_mongo(value)) + elif isinstance(field, ListField): + # value is list + for idx, item in enumerate(value): + _attach_objects_in_path(value, document_map, fields, + idx, field._field) + elif isinstance(field, EmbeddedDocumentListField): + # value is list of embedded models instances + for emb_model_inst in value: + _attach_objects(emb_model_inst, document_map, fields) + elif isinstance(value, MongoModelBase): + # value is embedded model instance or reference is + # already dereferenced + _attach_objects(value, document_map, fields) + + +def _attach_objects(model_instance, document_map, fields=None): + container = model_instance._data + field_names_map = {} + if fields: + for idx, field in enumerate(fields): + if field: + field_names_map[idx] = field.popleft() + field_names = set(field_names_map.values()) + + for field in model_instance._mongometa.get_fields(): + # Skip any fields we don't care about. 
+ if fields and field.attname not in field_names: + continue + + _attach_objects_in_path(container, document_map, fields, + field.attname, field) + if fields: - for field in fields: - _attach_objects_in_path(container, document_map, field) - else: - _attach_objects_in_path(container, document_map) + # Restore parts of field names that we took off while scanning. + for field_idx, field_name in field_names_map.items(): + fields[field_idx].appendleft(field_name) def dereference(model_instance, fields=None): """Dereference ReferenceFields on a MongoModel instance. + This function is handy for dereferencing many fields at once and is more efficient than dereferencing one field at a time. @@ -178,11 +182,12 @@ def dereference(model_instance, fields=None): db = _get_db(model_instance._mongometa.connection_alias) # Resolve all references, one collection at a time. - # This will give us a mapping of id --> resolved object. + # This will give us a mapping of + # {collection_name --> {id --> resolved object}} document_map = _resolve_references(db, reference_map) # Traverse the object and attach resolved references where needed. - _attach_objects(model_instance._data, document_map, fields) + _attach_objects(model_instance, document_map, fields) return model_instance @@ -194,7 +199,7 @@ def dereference_id(model_class, model_id): - `model_class`: The class of a model to be dereferenced. - `model_id`: The id of the model to be dereferenced. 
def to_python(self, value):
    """Convert a stored reference into its Python form.

    :parameters:
      - `value`: A raw document (``dict``), an instance of the related
        model, or the id of a referenced document.

    :returns: The model built from ``value`` when it is a dict that
      ``from_document`` accepts; ``value`` itself when it is already an
      instance of the related model; the result of ``dereference_id`` when
      auto-dereferencing is enabled on the owning model; otherwise the id
      converted by the related model's primary key field.
    """
    related = self.related_model

    if isinstance(value, dict):
        # Try to convert the value into our document type.
        try:
            return related.from_document(value)
        except (ValueError, TypeError):
            # Deliberate fall-through: a dict that is not a valid document
            # is handled below like any other non-model value.
            pass

    if isinstance(value, related):
        return value

    if not self.model._mongometa._auto_dereference:
        # Keep the reference unresolved; only normalise the id.
        return related._mongometa.pk.to_python(value)

    # Attempt to dereference the value as an id.
    resolve = _import('pymodm.dereference.dereference_id')
    return resolve(related, value)
class EmbeddedDocument(EmbeddedMongoModel):
    # Minimal embedded model used as the related model of the field under
    # test.
    name = CharField()

    class Meta:
        # NOTE(review): presumably what keeps the SON down to the declared
        # fields only (the tests compare against SON({'name': 'Bob'})) --
        # confirm against the Meta documentation.
        final = True


class EmbeddedDocumentFieldTestCase(FieldTestCase):
    # Conversion tests for EmbeddedDocumentField.

    field = EmbeddedDocumentField(EmbeddedDocument)

    def _assert_bob_son(self, converted):
        # Shared check: the value is a SON holding exactly Bob's name.
        self.assertIsInstance(converted, SON)
        self.assertEqual(converted, SON({'name': 'Bob'}))

    def test_to_python(self):
        # A raw dict is turned into a model instance.
        from_dict = self.field.to_python({'name': 'Bob'})
        self.assertIsInstance(from_dict, EmbeddedDocument)

        # A model instance comes back equal to what was passed in.
        model = EmbeddedDocument(name='Bob')
        from_model = self.field.to_python(model)
        self.assertIsInstance(from_model, EmbeddedDocument)
        self.assertEqual(from_model, model)

    def test_to_mongo(self):
        # Model instance -> SON.
        from_model = self.field.to_mongo(EmbeddedDocument(name='Bob'))
        self._assert_bob_son(from_model)

        # SON -> SON (value that was already converted).
        self._assert_bob_son(self.field.to_mongo(from_model))

        # Plain dict -> SON.
        self._assert_bob_son(self.field.to_mongo({'name': 'Bob'}))

    def test_to_mongo_wrong_model(self):
        with self.assertRaises(ValidationError) as caught:
            self.field.to_mongo(1234)
        # Read the exception after the block has exited; nothing placed
        # after the raising call inside the block would run.
        error = caught.exception
        self.assertEqual(error.message,
                         '1234 is not a valid EmbeddedDocument')
class EmbeddedDocument(EmbeddedMongoModel):
    # Minimal embedded model used as the related model of the field under
    # test.
    name = CharField()

    class Meta:
        final = True


# NOTE(review): this class exercises EmbeddedDocumentListField but has the
# same name as the test case for the non-list field; the name is kept
# unchanged here.
class EmbeddedDocumentFieldTestCase(FieldTestCase):
    # Conversion tests for EmbeddedDocumentListField.

    field = EmbeddedDocumentListField(EmbeddedDocument)

    def _assert_bob_and_alice_models(self, converted):
        # The value is a list whose first two items are the Bob and Alice
        # models, in that order.
        self.assertIsInstance(converted, list)
        for idx, expected_name in enumerate(('Bob', 'Alice')):
            self.assertIsInstance(converted[idx], EmbeddedDocument)
            self.assertEqual(converted[idx].name, expected_name)

    def _assert_bob_and_alice_sons(self, converted):
        # The value is a list whose first two items are the SON documents
        # for Bob and Alice, in that order.
        self.assertIsInstance(converted, list)
        for idx, expected_name in enumerate(('Bob', 'Alice')):
            self.assertIsInstance(converted[idx], SON)
            self.assertEqual(converted[idx], SON({'name': expected_name}))

    def test_to_python(self):
        # pass a raw list
        raw = [{'name': 'Bob'}, {'name': 'Alice'}]
        self._assert_bob_and_alice_models(self.field.to_python(raw))

        # pass a list of models
        bob = EmbeddedDocument(name='Bob')
        alice = EmbeddedDocument(name='Alice')
        self._assert_bob_and_alice_models(
            self.field.to_python([bob, alice]))

    def test_to_mongo(self):
        bob = EmbeddedDocument(name='Bob')
        alice = EmbeddedDocument(name='Alice')

        # A list of models.
        from_models = self.field.to_mongo([bob, alice])
        self._assert_bob_and_alice_sons(from_models)

        # A list that was already converted.
        self._assert_bob_and_alice_sons(self.field.to_mongo(from_models))

        # A list mixing a raw dict with a model.
        mixed = [{'name': 'Bob'}, alice]
        self._assert_bob_and_alice_sons(self.field.to_mongo(mixed))
class CardIdentity(EmbeddedMongoModel): @@ -150,6 +150,9 @@ class CardIdentity(EmbeddedMongoModel): suit = fields.IntegerField( choices=(HEARTS, DIAMONDS, SPADES, CLUBS)) + class Meta: + final = final_value + class Card(MongoModel): id = fields.EmbeddedDocumentField(CardIdentity, primary_key=True) flavor = fields.CharField() @@ -162,8 +165,27 @@ class Hand(MongoModel): Card(CardIdentity(12, CardIdentity.SPADES)).save() ] hand = Hand(cards).save() + + # test auto dereferencing hand.refresh_from_db() - dereference(hand) + self.assertIsInstance(hand.cards[0], Card) + self.assertEqual(hand.cards[0].id.rank, 4) + self.assertIsInstance(hand.cards[1], Card) + self.assertEqual(hand.cards[1].id.rank, 12) + + with no_auto_dereference(hand): + hand.refresh_from_db() + dereference(hand) + self.assertIsInstance(hand.cards[0], Card) + self.assertEqual(hand.cards[0].id.rank, 4) + self.assertIsInstance(hand.cards[1], Card) + self.assertEqual(hand.cards[1].id.rank, 12) + + def test_unhashable_id_final_true(self): + self._test_unhashable_id(final_value=True) + + def test_unhashable_id_final_false(self): + self._test_unhashable_id(final_value=False) def test_reference_not_found(self): post = Post(title='title').save() @@ -204,3 +226,90 @@ class Container(MongoModel): 'Aaron') self.assertEqual(container.lst[0].ref.name, 'Aaron') + + def test_embedded_reference_dereference(self): + # Test dereferencing items stored in a + # EmbeddedDocument(ReferenceField(X)) + class OtherModel(MongoModel): + name = fields.CharField() + + class OtherRefModel(EmbeddedMongoModel): + ref = fields.ReferenceField(OtherModel) + + class Container(MongoModel): + emb = fields.EmbeddedDocumentField(OtherRefModel) + + m1 = OtherModel('Aaron').save() + + container = Container(emb=OtherRefModel(ref=m1)) + container.save() + + # Force ObjectIds. 
+ with no_auto_dereference(container): + container.refresh_from_db() + self.assertIsInstance(container.emb.ref, ObjectId) + dereference(container) + self.assertIsInstance(container.emb.ref, OtherModel) + self.assertEqual(container.emb.ref.name, 'Aaron') + + def test_dereference_reference_not_found(self): + post = Post(title='title').save() + comment = Comment(body='this is a comment', post=post).save() + post.delete() + self.assertEqual(Post.objects.count(), 0) + comment.refresh_from_db() + with no_auto_dereference(comment): + self.assertEqual(comment.post, 'title') + dereference(comment) + self.assertIsNone(comment.post) + + def test_dereference_models_with_same_id(self): + class User(MongoModel): + name = fields.CharField(primary_key=True) + + class CommentWithUser(MongoModel): + body = fields.CharField() + post = fields.ReferenceField(Post) + user = fields.ReferenceField(User) + + post = Post(title='Bob').save() + user = User(name='Bob').save() + + comment = CommentWithUser( + body='this is a comment', + post=post, + user=user).save() + + comment.refresh_from_db() + with no_auto_dereference(CommentWithUser): + dereference(comment) + self.assertIsInstance(comment.post, Post) + self.assertIsInstance(comment.user, User) + + def test_dereference_missed_reference_field(self): + comment = Comment(body='Body Comment').save() + with no_auto_dereference(comment): + comment.refresh_from_db() + dereference(comment) + self.assertIsNone(comment.post) + + def test_dereference_dereferenced_reference(self): + class CommentContainer(MongoModel): + ref = fields.ReferenceField(Comment) + + post = Post(title='title').save() + comment = Comment(body='Comment Body', post=post).save() + + container = CommentContainer(ref=comment).save() + + with no_auto_dereference(comment), no_auto_dereference(container): + comment.refresh_from_db() + container.refresh_from_db() + container.ref = comment + self.assertEqual(container.ref.post, 'title') + dereference(container) + 
self.assertIsInstance(container.ref.post, Post) + self.assertEqual(container.ref.post.title, 'title') + dereference(container) + self.assertIsInstance(container.ref.post, Post) + self.assertEqual(container.ref.post.title, 'title') diff --git a/test/test_related_fields.py b/test/test_related_fields.py index 1d229c5..c28fdbf 100644 --- a/test/test_related_fields.py +++ b/test/test_related_fields.py @@ -57,7 +57,7 @@ def test_basic_reference(self): def test_assign_id_to_reference_field(self): # No ValidationError raised. - Comment(post=1234).full_clean() + Comment(post="58b477046e32ab215dca2b57").full_clean() def test_validate_embedded_document(self): with self.assertRaisesRegex(ValidationError, 'field is required'): @@ -85,7 +85,7 @@ def test_reference_errors(self): message = cm.exception.message self.assertIn('post', message) self.assertEqual( - ['Referenced documents must be saved to the database first.'], + ['Referenced Models must be saved to the database first.'], message['post']) # Cannot save document when reference is unresolved. @@ -93,7 +93,7 @@ def test_reference_errors(self): comment.save() self.assertIn('post', message) self.assertEqual( - ['Referenced documents must be saved to the database first.'], + ['Referenced Models must be saved to the database first.'], message['post']) def test_embedded_document(self):