Skip to content

Commit

Permalink
if config_extra_fields == 'ignore', store extra fields in the object …
Browse files Browse the repository at this point in the history
…so they can be viewed and saved later if the object is saved
  • Loading branch information
Jeff Jenkins committed Oct 8, 2010
1 parent b3de1bd commit 881ec01
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
18 changes: 15 additions & 3 deletions mongoalchemy/document.py
Expand Up @@ -110,6 +110,8 @@ def __init__(self, retrieved_fields=None, **kwargs):
self.partial = retrieved_fields != None
self.retrieved_fields = self.__normalize(retrieved_fields)

self.__extra_fields = {}

cls = self.__class__

fields = self.get_fields()
Expand All @@ -130,11 +132,13 @@ def __init__(self, retrieved_fields=None, **kwargs):

if hasattr(field, 'default'):
setattr(self, name, field.default)

if self.config_extra_fields != 'ignore':

for k in kwargs:
if k not in fields:
raise ExtraValueException(k)
if self.config_extra_fields == 'ignore':
self.__extra_fields[k] = kwargs[k]
else:
raise ExtraValueException(k)

def __setattr__(self, name, value):
cls = self.__class__
Expand Down Expand Up @@ -172,6 +176,9 @@ def __getattribute__(self, name):
'''

def get_extra_fields(self):
return self.__extra_fields

@classmethod
def get_fields(cls):
'''Returns a dict mapping the names of the fields in a document
Expand Down Expand Up @@ -247,6 +254,8 @@ def wrap(self):
be saved into a mongo database. This is done by using the ``wrap()``
methods of the underlying fields to set values.'''
res = {}
for k, v in self.__extra_fields.iteritems():
res[k] = v
cls = self.__class__
for name in dir(cls):
field = getattr(cls, name)
Expand Down Expand Up @@ -287,6 +296,9 @@ def unwrap(cls, obj, fields=None):
params = {}
for k, v in obj.iteritems():
k = name_reverse.get(k, k)
if not hasattr(cls, k) and cls.config_extra_fields:
params[str(k)] = v
continue
field = getattr(cls, k)
if fields != None and isinstance(field, DocumentField):
normalized_fields = cls.__normalize(fields)
Expand Down
10 changes: 10 additions & 0 deletions test/test_documents.py
Expand Up @@ -53,6 +53,16 @@ def bad_extra_fields_param_test():
class BadDoc(Document):
config_extra_fields = 'blah'

def extra_fields_test():
class BadDoc(Document):
config_extra_fields = 'ignore'
doc_with_extra = {'foo' : [1]}

unwrapped = BadDoc.unwrap(doc_with_extra)
assert unwrapped.get_extra_fields() == doc_with_extra

assert BadDoc.wrap(unwrapped) == doc_with_extra


@raises(MissingValueException)
def test_required_fields():
Expand Down

0 comments on commit 881ec01

Please sign in to comment.