Skip to content

Commit

Permalink
Ported fix for Circular Reference bug to Master
Browse files Browse the repository at this point in the history
Ready for a 0.5.2 release
  • Loading branch information
rozza committed Oct 12, 2011
1 parent 591149b commit 452bbcc
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ that much better:
* Gareth Lloyd
* Albert Choi
* John Arnfield
* Julien Rebetez

5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Changelog
=========

Changes in v0.5.2
=================

- A Robust Circular reference bugfix

Changes in v0.5.1
=================

Expand Down
25 changes: 19 additions & 6 deletions mongoengine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,18 +724,32 @@ def _mark_as_changed(self, key):
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
self._changed_fields.append(key)

def _get_changed_fields(self, key=''):
def _get_changed_fields(self, key='', inspected=None):
"""Returns a list of all fields that have explicitly been changed.
"""
from mongoengine import EmbeddedDocument
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
for field_name in self._fields:

inspected = inspected or set()
if hasattr(self, 'id'):
if self.id in inspected:
return _changed_fields
inspected.add(self.id)

field_list = self._fields.copy()

for field_name in field_list:
db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name
field = getattr(self, field_name, None)
if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
if hasattr(field, 'id'):
if field.id in inspected:
continue
inspected.add(field.id)

if isinstance(field, (EmbeddedDocument,)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(field, 'items'):
Expand All @@ -746,8 +760,7 @@ def _get_changed_fields(self, key=''):
if not hasattr(value, '_get_changed_fields'):
continue
list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]

_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
return _changed_fields

def _delta(self):
Expand Down
45 changes: 45 additions & 0 deletions tests/dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,51 @@ def __repr__(self):

self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())

def test_circular_tree_reference(self):
"""Ensure you can handle circular references with more than one level
"""
class Other(EmbeddedDocument):
name = StringField()
friends = ListField(ReferenceField('Person'))

class Person(Document):
name = StringField()
other = EmbeddedDocumentField(Other, default=lambda: Other())

def __repr__(self):
return "<Person: %s>" % self.name

Person.drop_collection()
paul = Person(name="Paul")
paul.save()
maria = Person(name="Maria")
maria.save()
julia = Person(name='Julia')
julia.save()
anna = Person(name='Anna')
anna.save()

paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends"
paul.save()

maria.other.friends = [paul, julia, anna]
maria.other.name = "Maria's friends"
maria.save()

julia.other.friends = [paul, maria, anna]
julia.other.name = "Julia's friends"
julia.save()

anna.other.friends = [paul, maria, julia]
anna.other.name = "Anna's friends"
anna.save()

self.assertEquals(
"[<Person: Paul>, <Person: Maria>, <Person: Julia>, <Person: Anna>]",
"%s" % Person.objects()
)

def test_generic_reference(self):

class UserA(Document):
Expand Down

0 comments on commit 452bbcc

Please sign in to comment.