Browse files

Gut the internal value/dirty representation and replace it with an ob…

…ject
  • Loading branch information...
1 parent ef67f8c commit 44f0e4e25ef9364fdffa5ba8c656617d1002b5a5 @jeffjenkins committed Aug 6, 2012
View
109 mongoalchemy/document.py
@@ -96,13 +96,6 @@ def __new__(mcs, classname, bases, class_dict):
name = new_class.__name__
document_type_registry[new_class.config_namespace][name] = new_class
- # 5. Add proxies
- for name, field in new_class.get_fields().iteritems():
- if field.proxy is not None:
- setattr(new_class, field.proxy, Proxy(name))
- if field.iproxy is not None:
- setattr(new_class, field.iproxy, IProxy(name, field.ignore_missing))
-
# 4. Add subclasses
for b in bases:
@@ -181,25 +174,27 @@ def __init__(self, retrieved_fields=None, loading_from_db=False, **kwargs):
self.partial = retrieved_fields is not None
self.retrieved_fields = self.__normalize(retrieved_fields)
- self._dirty = {}
-
- self._field_values = {}
+ # Mapping from attribute names to values.
+ self._values = {}
self.__extra_fields = {}
cls = self.__class__
fields = self.get_fields()
for name, field in fields.iteritems():
+ print name
if self.partial and field.db_field not in self.retrieved_fields:
- continue
-
- if name in kwargs:
- getattr(cls, name).set_value(self, kwargs[name], from_db=loading_from_db)
- continue
- # DO I NEED THIS?
- # elif field.default != UNSET:
- # getattr(cls, name).set_value(self, field.default, from_db=loading_from_db)
-
+ self._values[name] = Value(field, self, retrieved=False)
+ elif name in kwargs:
+ field = getattr(cls, name)
+ value = kwargs[name]
+ self._values[name] = Value(field, self,
+ from_db=loading_from_db)
+ getattr(cls, name).set_value(self, kwargs[name])
+ elif field.auto:
+ self._values[name] = Value(field, self, from_db=False)
+ else:
+ self._values[name] = Value(field, self, from_db=False)
for k in kwargs:
if k not in fields:
@@ -255,6 +250,8 @@ def __eq__(self, other):
return self.mongo_id == other.mongo_id
except:
return False
+ def __ne__(self, other):
+ return not self.__eq__(other)
def get_dirty_ops(self, with_required=False):
''' Returns a dict with the update operations necessary to make the
@@ -272,7 +269,7 @@ def get_dirty_ops(self, with_required=False):
continue
dirty_ops = field.dirty_ops(self)
if not dirty_ops and with_required and field.required:
- dirty_ops = field.update_ops(self)
+ dirty_ops = field.update_ops(self, force=True)
if not dirty_ops:
raise MissingValueException(name)
@@ -435,7 +432,7 @@ def unwrap(cls, obj, fields=None, session=None):
if fields is not None:
params['retrieved_fields'] = fields
obj = cls(loading_from_db=True, **params)
- obj.__mark_clean()
+ obj._mark_clean()
obj._session = session
return obj
@@ -445,8 +442,10 @@ def _get_session(self):
def _set_session(self, session):
self._session = session
- def __mark_clean(self):
- self._dirty.clear()
+ def _mark_clean(self):
+ print 'CLEAR DIRTY'
+ for k, v in self._values.iteritems():
+ v.clear_dirty()
class DictDoc(object):
@@ -456,8 +455,8 @@ class DictDoc(object):
'''
def __getitem__(self, name):
''' Gets the field ``name`` from the document '''
- fields = self.get_fields()
- if name in fields:
+ # fields = self.get_fields()
+ if name in self._values:
return getattr(self, name)
raise KeyError(name)
@@ -587,37 +586,31 @@ def ensure(self, collection):
collection.ensure_index(self.components, unique=self.__unique,
drop_dups=self.__drop_dups, **extras)
return self
+
+class Value(object):
+ def __init__(self, field, document, from_db=False, extra=False,
+ retrieved=True):
+ # Stuff
+ self.field = field
+ self.doc = document
+ self.value = None
-class Proxy(object):
- def __init__(self, name):
- self.name = name
- def __get__(self, instance, owner):
- if instance is None:
- return getattr(owner, self.name)
- session = instance._get_session()
- ref = getattr(instance, self.name)
- if ref is None:
- return None
- return session.dereference(ref)
- def __set__(self, instance, value):
- assert instance is not None
- setattr(instance, self.name, value.to_ref())
-
-class IProxy(object):
- def __init__(self, name, ignore_missing):
- self.name = name
- self.ignore_missing = ignore_missing
- def __get__(self, instance, owner):
- if instance is None:
- return getattr(owner, self.name)
- session = instance._get_session()
- def iterator():
- for v in getattr(instance, self.name):
- if v is None:
- yield v
- continue
- value = session.dereference(v, allow_none=self.ignore_missing)
- if value is None and self.ignore_missing:
- continue
- yield value
- return iterator()
+ # Flags
+ self.from_db = from_db
+ self.set = False
+ self.extra = extra
+ self.dirty = False
+ self.retrieved = retrieved
+ self.update_op = None
+ def clear_dirty(self):
+ self.dirty = False
+ self.update_op = None
+
+ def delete(self):
+ self.value = None
+ self.set = False
+ self.dirty = True
+ self.from_db = False
+ self.update_op = '$unset'
+
+
View
61 mongoalchemy/fields/base.py
@@ -193,12 +193,19 @@ def __init__(self, required=True, default=UNSET, db_field=None, allow_none=False
def __get__(self, instance, owner):
if instance is None:
return QueryField(self)
- if self._name in instance._field_values:
- return instance._field_values[self._name]
+ obj_value = instance._values[self._name]
+
+ # if the value is set, just return it
+ if obj_value.set:
+ return instance._values[self._name].value
+
+ # if not, try the default
if self.default is not UNSET:
self.set_value(instance, self.default)
- return instance._field_values[self._name]
- if instance.partial and self.db_field not in instance.retrieved_fields:
+ return instance._values[self._name].value
+
+ # If this value wasn't retrieved, raise a specific exception
+ if not obj_value.retrieved:
raise FieldNotRetrieved(self._name)
raise AttributeError(self._name)
@@ -207,38 +214,48 @@ def __get__(self, instance, owner):
def __set__(self, instance, value):
self.set_value(instance, value)
- def set_value(self, instance, value, from_db=False):
+ def set_value(self, instance, value):
self.validate_wrap(value)
- instance._field_values[self._name] = value
+ obj_value = instance._values[self._name]
+ obj_value.value = value
+ obj_value.dirty = True
+ obj_value.set = True
+ obj_value.from_db = False
if self.on_update != 'ignore':
- instance._dirty[self._name] = self.on_update
+ obj_value.update_op = self.on_update
def dirty_ops(self, instance):
- op = instance._dirty.get(self._name)
- if op == '$unset':
+ obj_value = instance._values[self._name]
+ # op = instance._dirty.get(self._name)
+ if obj_value.update_op == '$unset':
return { '$unset' : { self._name : True } }
- if op is None:
+ if obj_value.update_op is None:
return {}
return {
- op : {
- self.db_field : self.wrap(instance._field_values[self._name])
+ obj_value.update_op : {
+ self.db_field : self.wrap(obj_value.value)
}
}
def __delete__(self, instance):
- if self._name not in instance._field_values:
+ obj_value = instance._values[self._name]
+ if not obj_value.set:
raise AttributeError(self._name)
- del instance._field_values[self._name]
- instance._dirty[self._name] = '$unset'
+ obj_value.delete()
+ # if self._name not in instance._field_values:
+ # raise AttributeError(self._name)
+ # del instance._field_values[self._name]
+ # instance._dirty[self._name] = '$unset'
- def update_ops(self, instance):
- if self._name not in instance._field_values:
- return {}
- return {
- self.on_update : {
- self._name : self.wrap(instance._field_values[self._name])
+ def update_ops(self, instance, force=False):
+ obj_value = instance._values[self._name]
+ if obj_value.set and (obj_value.dirty or force):
+ return {
+ self.on_update : {
+ self._name : self.wrap(obj_value.value)
+ }
}
- }
+ return {}
def localize(self, session, value):
return value
View
14 mongoalchemy/fields/document_field.py
@@ -47,22 +47,24 @@ def type(self):
def dirty_ops(self, instance):
''' Returns a dict of the operations needed to update this object.
See :func:`Document.get_dirty_ops` for more details.'''
- try:
- document = getattr(instance, self._name)
- except AttributeError:
+ print 'check dirty'
+ obj_value = instance._values[self._name]
+ if not obj_value.set:
+ print 'not set'
return {}
- if len(document._dirty) == 0 and \
- self.__type.config_extra_fields != 'ignore':
+
+ if not obj_value.dirty and self.__type.config_extra_fields != 'ignore':
return {}
- ops = document.get_dirty_ops()
+ ops = obj_value.value.get_dirty_ops()
ret = {}
for op, values in ops.iteritems():
ret[op] = {}
for key, value in values.iteritems():
name = '%s.%s' % (self._name, key)
ret[op][name] = value
+ print 'ret'
return ret
def subfields(self):
View
14 mongoalchemy/fields/fields.py
@@ -446,15 +446,17 @@ def __get__(self, instance, owner):
if instance is None:
return QueryField(self)
- if self._name in instance._field_values and self.one_time:
- return instance._field_values[self._name]
+ obj_value = instance._values[self._name]
+ if obj_value.set and self.one_time:
+ return obj_value.value
computed_value = self.compute_value(instance)
if self.one_time:
- instance._field_values[self._name] = computed_value
+ self.set_value(instance, computed_value)
return computed_value
def __set__(self, instance, value):
- if self._name in instance._field_values and self.one_time:
+ obj_value = instance._values[self._name]
+ if obj_value.set and self.one_time:
raise BadValueException(self._name, value, 'Cannot set a one-time field once it has been set')
super(ComputedField, self).__set__(instance, value)
@@ -464,7 +466,9 @@ def set_parent_on_subtypes(self, parent):
def dirty_ops(self, instance):
dirty = False
for dep in self.deps:
- if dep._name in instance._dirty:
+ dep_value = instance._values[dep._name]
+ if dep_value.dirty:
+ dirty = True
break
else:
if len(self.deps) > 0:
View
49 mongoalchemy/fields/ref.py
@@ -21,9 +21,12 @@
# THE SOFTWARE.
from mongoalchemy.fields.base import *
+from bson import DBRef
+class RefBase(Field):
+ pass
-class SRefField(Field):
+class SRefField(RefBase):
''' A Simple RefField (SRefField) looks like an ObjectIdField in the
database, but acts like a mongo DBRef. It uses the passed in type to
determine where to look for the object (and assumes the current
@@ -39,6 +42,12 @@ def __init__(self, type, **kwargs):
self.type = type
if not isinstance(type, DocumentField):
self.type = DocumentField(type)
+ def rel(self):
+ return Proxy(self)
+ def dereference(self, session, ref):
+ ref = DBRef(id=ref, collection=self.type.type.get_collection_name())
+ ref.type = self.type.type
+ return session.dereference(ref)
def wrap(self, value):
self.validate_wrap(value)
return value
@@ -51,7 +60,7 @@ def validate_wrap(self, value):
validate_unwrap = validate_wrap
-class RefField(Field):
+class RefField(RefBase):
''' A ref field wraps a mongo DBReference. It DOES NOT currently handle
saving the referenced object or updates to it, but it can handle
auto-loading.
@@ -97,6 +106,11 @@ def unwrap(self, value, fields=None, session=None):
value.type = self.type
return value
+ def rel(self):
+ return Proxy(self)
+ def dereference(self, session, ref):
+ return session.dereference(ref)
+
def validate_unwrap(self, value, session=None):
''' Validates that the DBRef is valid as well as can be done without
retrieving it.
@@ -113,3 +127,34 @@ def validate_unwrap(self, value, session=None):
self._fail_validation(value, '''Wrong database for reference: '''
''' got "%s" instead of "%s" ''' % (value.database, self.db) )
validate_wrap = validate_unwrap
+
+class Proxy(object):
+ def __init__(self, field):
+ self.field = field
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.field
+ session = instance._get_session()
+ ref = getattr(instance, self.field._name)
+ if ref is None:
+ return None
+ return self.field.dereference(session, ref)
+ def __set__(self, instance, value):
+ assert instance is not None
+ setattr(instance, self.field._name, value.to_ref())
+
+'''
+class Foo(Document):
+ blah_id = RefField('Bar')
+ blah = blah_id.rel()
+
+
+
+
+
+'''
+
+
+
+
+
View
59 mongoalchemy/fields/sequence.py
@@ -113,27 +113,23 @@ def validate_unwrap(self, value, session=None):
self._validate_child_unwrap(v)
- def set_value(self, instance, value, from_db=False):
- super(SequenceField, self).set_value(instance, value, from_db=from_db)
-
- if from_db:
- # loaded from db, stash it
- if 'orig_values' not in instance.__dict__:
- instance.__dict__['orig_values'] = {}
- instance.__dict__['orig_values'][self._name] = deepcopy(value)
+ def set_value(self, instance, value):
+ super(SequenceField, self).set_value(instance, value)
+ # TODO:2012
+ # value_obj = instance._values[self._name]
+ # if from_db:
+ # # loaded from db, stash it
+ # if 'orig_values' not in instance.__dict__:
+ # instance.__dict__['orig_values'] = {}
+ # instance.__dict__['orig_values'][self._name] = deepcopy(value)
def dirty_ops(self, instance):
+ obj_value = instance._values[self._name]
ops = super(SequenceField, self).dirty_ops(instance)
- if len(ops) == 0:
- # see if the underlying sequence has changed. Overwrite if so
- try:
- if instance._field_values[self._name] != instance.__dict__['orig_values'][self._name]:
- ops = {'$set': {
- self.db_field : self.wrap(instance._field_values[self._name])
- }}
- except KeyError:
- # required field is missing
- pass
+ if len(ops) == 0 and obj_value.set:
+ ops = {'$set': {
+ self.db_field : self.wrap(obj_value.value)
+ }}
return ops
@@ -158,6 +154,11 @@ def get_default(self):
return self._default
default = property(get_default, set_default)
+ def rel(self, ignore_missing=False):
+ from mongoalchemy.fields import RefBase
+ assert isinstance(self.item_type, RefBase)
+ return ListProxy(self, ignore_missing=ignore_missing)
+
def _validate_wrap_type(self, value):
import types
if not any([isinstance(value, list), isinstance(value, tuple),
@@ -193,6 +194,9 @@ def get_default(self):
return self._default
default = property(get_default, set_default)
+ def rel(self):
+ raise NotImplementedError()
+
def _validate_wrap_type(self, value):
if not isinstance(value, set):
self._fail_validation_type(value, set)
@@ -213,3 +217,22 @@ def unwrap(self, value, session=None):
returns them in a set'''
self.validate_unwrap(value)
return set([self.item_type.unwrap(v, session=session) for v in value])
+
+class ListProxy(object):
+ def __init__(self, field, ignore_missing=False):
+ self.field = field
+ self.ignore_missing = ignore_missing
+ def __get__(self, instance, owner):
+ if instance is None:
+ return getattr(owner, self.field._name)
+ session = instance._get_session()
+ def iterator():
+ for v in getattr(instance, self.field._name):
+ if v is None:
+ yield v
+ continue
+ value = session.dereference(v, allow_none=self.ignore_missing)
+ if value is None and self.ignore_missing:
+ continue
+ yield value
+ return iterator()
View
2 mongoalchemy/ops.py
@@ -50,6 +50,7 @@ def __init__(self, trans_id, session, document, safe, id_expression=None, upsert
del self.dirty_ops[current_op][key]
if len(self.dirty_ops[current_op]) == 0:
del self.dirty_ops[current_op]
+ document._mark_clean()
def execute(self):
self.ensure_indexes()
@@ -82,6 +83,7 @@ def __init__(self, trans_id, session, document, safe):
if '_id' not in self.data:
self.data['_id'] = ObjectId()
document.mongo_id = self.data['_id']
+ document._mark_clean()
def execute(self):
self.ensure_indexes()
View
34 test/test_documents.py
@@ -184,6 +184,19 @@ def loading_test():
break
assert td.int1 == t.int1
+def docfield_not_dirty_test():
+ class SuperDoc(Document):
+ int1 = IntField()
+ sub = DocumentField(TestDoc)
+ s = get_session()
+ s.clear_collection(TestDoc, SuperDoc)
+ doc = TestDoc(int1=3)
+ sup = SuperDoc(int1=4, sub=doc)
+ s.insert(sup)
+ s.update(sup)
+
+
+
def docfield_test():
class SuperDoc(Document):
int1 = IntField()
@@ -260,7 +273,7 @@ def wrong_wrap_type_test2():
@raises(ExtraValueException)
def wrong_unwrap_type_test():
- DocA.unwrap({ 'test_doc2' : { 'int1' : 1 } })
+ DocA.unwrap({ 'test_doc2' : { 'int1' : 1 }, 'testdoc' : {'int1' : 1 } })
@raises(MissingValueException)
def test_upsert_with_required():
@@ -285,14 +298,31 @@ class D(Document):
d = s.query(D).one()
s.update(d, upsert=True)
+def test_deepcopy():
+ import copy
+ a = TestDoc(int1=4)
+ b = copy.deepcopy(a)
+ assert id(a) != id(b)
+ assert a.int1 == b.int1
+
+def test_default_eq():
+ a = TestDoc(int1=4)
+ b = TestDoc(int1=4)
+ assert not (a == b)
+ a.mongo_id = ObjectId()
+ b.mongo_id = ObjectId()
+ assert a != b
+ b.mongo_id = a.mongo_id
+ assert a == b
+
def test_unwrapped_is_not_dirty():
class D(Document):
a = IntField()
s = get_session()
s.clear_collection(D)
s.insert(D(a=1))
d = s.query(D).one()
- assert len(d.get_dirty_ops()) == 0, len(d.get_dirty_ops())
+ assert len(d.get_dirty_ops()) == 0, d.get_dirty_ops()
def test_update_with_unset():
class D(Document, DictDoc):
View
31 test/test_ref_field.py
@@ -30,12 +30,32 @@ class C(Document):
# Field Tests
+def test_simple_dereference():
+ print 1111
+ class A(Document):
+ x = IntField()
+ class B(Document):
+ y_id = SRefField(DocumentField(A))
+ y = y_id.rel()
+
+ s = get_session()
+ a = A(x=4)
+ s.insert(a)
+
+ b = B(y_id=a.mongo_id)
+ s.add_to_session(b)
+ assert b.y.x == 4
+
+
+
def test_proxy():
class B(Document):
b = IntField(default=3)
class A(Document):
- x_ids = ListField(RefField(B, allow_none=True), iproxy='xs', default_empty=True, allow_none=True)
- x_id = RefField(B, proxy='x', allow_none=True)
+ x_ids = ListField(RefField(B, allow_none=True), default_empty=True, allow_none=True)
+ xs = x_ids.rel()
+ x_id = RefField(B, allow_none=True)
+ x = x_id.rel()
s = get_session()
a = A()
@@ -62,9 +82,10 @@ def test_proxy_ignore_missing():
class B(Document):
b = IntField(default=3)
class A(Document):
- x_ids = ListField(RefField(B), iproxy='xs', default_empty=True,
- ignore_missing=True)
- x_id = RefField(B, proxy='x')
+ x_ids = ListField(RefField(B), default_empty=True)
+ xs = x_ids.rel(ignore_missing=True)
+ x_id = RefField(B)
+ x = x_id.rel()
s = get_session()
a = A()
View
5 test/test_sequence_fields.py
@@ -58,6 +58,11 @@ def set_wrong_type_test_unwrap():
def set_wrong_child_type_test():
SetField(StringField()).wrap(set([4]))
+@raises(NotImplementedError)
+def set_no_rel_test():
+ SetField(StringField()).rel()
+
+
@raises(Exception)
def set_bad_child_type_test():
SetField(int).wrap(set([4]))
View
1 test/test_update_expressions.py
@@ -83,6 +83,7 @@ def nested_field_set_test():
s.query(T2).set('t.i', 3).upsert().execute()
assert s.query(T2).one().t.i == 3
+
def test_update_safe():
s = get_session()
s.clear_collection(TUnique)

0 comments on commit 44f0e4e

Please sign in to comment.