Permalink
Browse files

smarter dereferencing

  • Loading branch information...
1 parent 44f0e4e commit 5d42923b9172029b099a6b9541567bdc3597212b @jeffjenkins committed Aug 9, 2012
View
@@ -57,7 +57,7 @@ Interesting Features
* **Drop into raw Mongo** — Most functions will accept raw pymongo instead of the mongoalchemy objects. \
For example::
- session.query('SomeClass').filter(SomeClass.name == foo).limit(5)``
+ session.query('SomeClass').filter(SomeClass.name == foo).limit(5)
is perfectly valid, as is::
View
@@ -49,6 +49,7 @@
from mongoalchemy.exceptions import DocumentException, MissingValueException, ExtraValueException, FieldNotRetrieved, BadFieldSpecification
document_type_registry = defaultdict(dict)
+collection_registry = defaultdict(dict)
class DocumentMeta(type):
def __new__(mcs, classname, bases, class_dict):
@@ -89,22 +90,33 @@ def __new__(mcs, classname, bases, class_dict):
continue
new_class._fields[name] = maybefield
- # 3. register type
- if new_class.config_namespace is not None:
- name = new_class.config_full_name
- if name == None:
- name = new_class.__name__
- document_type_registry[new_class.config_namespace][name] = new_class
-
- # 4. Add subclasses
-
+ # 3. Add subclasses
for b in bases:
if 'Document' in globals() and issubclass(b, Document):
b.add_subclass(new_class)
if not hasattr(b, 'config_polymorphic_collection'):
continue
- if b.config_polymorphic_collection and 'config_collection' not in class_dict:
+ if b.config_polymorphic_collection and 'config_collection_name' not in class_dict:
new_class.config_collection_name = b.get_collection_name()
+
+ # 4. register type
+ if new_class.config_namespace is not None:
+ name = new_class.config_full_name
+ if name == None:
+ name = new_class.__name__
+
+ ns = new_class.config_namespace
+ document_type_registry[ns][name] = new_class
+
+ # if the new class uses a polymorphic collection we should only
+ # set up the collection name to refer to the base class
+ # TODO: if non-polymorphic classes use the collection registry they
+ # will just overwrite for now.
+ collection = new_class.get_collection_name()
+ current = collection_registry[ns].get(collection)
+ if current is None or issubclass(current, new_class):
+ collection_registry[ns][collection] = new_class
+
return new_class
@@ -182,7 +194,7 @@ def __init__(self, retrieved_fields=None, loading_from_db=False, **kwargs):
fields = self.get_fields()
for name, field in fields.iteritems():
- print name
+ # print name
if self.partial and field.db_field not in self.retrieved_fields:
self._values[name] = Value(field, self, retrieved=False)
elif name in kwargs:
@@ -443,7 +455,7 @@ def _set_session(self, session):
self._session = session
def _mark_clean(self):
- print 'CLEAR DIRTY'
+ # print 'CLEAR DIRTY'
for k, v in self._values.iteritems():
v.clear_dirty()
@@ -37,7 +37,7 @@ def type(self):
from mongoalchemy.document import Document, document_type_registry
if not isinstance(self.__type, basestring) and issubclass(self.__type, Document):
return self.__type
- if self.parent.config_namespace == None:
+ if self.parent and self.parent.config_namespace == None:
raise BadFieldSpecification('Document namespace is None. Strings are not allowed for DocumentFields')
type = document_type_registry[self.parent.config_namespace].get(self.__type)
if type == None or not issubclass(type, Document):
@@ -47,10 +47,10 @@ def type(self):
def dirty_ops(self, instance):
''' Returns a dict of the operations needed to update this object.
See :func:`Document.get_dirty_ops` for more details.'''
- print 'check dirty'
+ # print 'check dirty'
obj_value = instance._values[self._name]
if not obj_value.set:
- print 'not set'
+ # print 'not set'
return {}
if not obj_value.dirty and self.__type.config_extra_fields != 'ignore':
@@ -64,7 +64,7 @@ def dirty_ops(self, instance):
for key, value in values.iteritems():
name = '%s.%s' % (self._name, key)
ret[op][name] = value
- print 'ret'
+ # print 'ret'
return ret
def subfields(self):
View
@@ -24,7 +24,8 @@
from bson import DBRef
class RefBase(Field):
- pass
+ def rel(self):
+ return Proxy(self)
class SRefField(RefBase):
''' A Simple RefField (SRefField) looks like an ObjectIdField in the
@@ -42,22 +43,24 @@ def __init__(self, type, **kwargs):
self.type = type
if not isinstance(type, DocumentField):
self.type = DocumentField(type)
- def rel(self):
- return Proxy(self)
- def dereference(self, session, ref):
+ def _to_ref(self, doc):
+ return doc.mongo_id
+ def dereference(self, session, ref, allow_none=False):
ref = DBRef(id=ref, collection=self.type.type.get_collection_name())
ref.type = self.type.type
- return session.dereference(ref)
+ return session.dereference(ref, allow_none=allow_none)
+ def set_parent_on_subtypes(self, parent):
+ self.type.parent = parent
def wrap(self, value):
self.validate_wrap(value)
return value
def unwrap(self, value, fields=None, session=None):
self.validate_unwrap(value)
return value
- def validate_wrap(self, value):
+ def validate_unwrap(self, value, session=None):
if not isinstance(value, ObjectId):
self._fail_validation_type(value, ObjectId)
- validate_unwrap = validate_wrap
+ validate_wrap = validate_unwrap
class RefField(RefBase):
@@ -88,6 +91,7 @@ def __init__(self, type=None, db=None, namespace='global', **kwargs):
self.type = type
self.namespace = namespace
self.db = db
+ self.parent = None
def wrap(self, value):
''' Validate ``value`` and then use the value_type to wrap the
@@ -97,7 +101,9 @@ def wrap(self, value):
value.type = self.type
return value
-
+ def _to_ref(self, doc):
+ return doc.to_ref()
+
def unwrap(self, value, fields=None, session=None):
''' If ``autoload`` is False, return a DBRef object. Otherwise load
the object.
@@ -108,8 +114,18 @@ def unwrap(self, value, fields=None, session=None):
def rel(self):
return Proxy(self)
- def dereference(self, session, ref):
- return session.dereference(ref)
+ def dereference(self, session, ref, allow_none=False):
+ from mongoalchemy.document import collection_registry
+ # TODO: namespace support
+ ref.type = collection_registry['global'][ref.collection]
+ # print ref.type, ref
+ # print '--' * 30
+ obj = session.dereference(ref, allow_none=allow_none)
+ # print '--' * 30
+ return obj
+ def set_parent_on_subtypes(self, parent):
+ if self.type:
+ self.type.parent = parent
def validate_unwrap(self, value, session=None):
''' Validates that the DBRef is valid as well as can be done without
@@ -141,20 +157,5 @@ def __get__(self, instance, owner):
return self.field.dereference(session, ref)
def __set__(self, instance, value):
assert instance is not None
- setattr(instance, self.field._name, value.to_ref())
-
-'''
-class Foo(Document):
- blah_id = RefField('Bar')
- blah = blah_id.rel()
-
-
-
-
-
-'''
-
-
-
-
+ setattr(instance, self.field._name, self.field._to_ref(value))
@@ -60,6 +60,9 @@ def subfields(self):
''' Returns the names of the value type's sub-fields'''
return self.item_type.subfields()
+ def dereference(self, session, ref, allow_none=False):
+ return self.item_type.dereference(session, ref, allow_none=allow_none)
+
def wrap_value(self, value):
''' A function used to wrap a value used in a comparison. It will
first try to wrap as the sequence's sub-type, and then as the
@@ -194,9 +197,9 @@ def get_default(self):
return self._default
default = property(get_default, set_default)
- def rel(self):
- raise NotImplementedError()
-
+ def rel(self, ignore_missing=False):
+ return ListProxy(self, ignore_missing=ignore_missing)
+
def _validate_wrap_type(self, value):
if not isinstance(value, set):
self._fail_validation_type(value, set)
@@ -231,7 +234,8 @@ def iterator():
if v is None:
yield v
continue
- value = session.dereference(v, allow_none=self.ignore_missing)
+ value = self.field.dereference(session, v,
+ allow_none=self.ignore_missing)
if value is None and self.ignore_missing:
continue
yield value
View
@@ -351,9 +351,9 @@ def clear_queue(self, trans_id=None):
for index, op in enumerate(self.queue):
if op.trans_id == trans_id:
break
- print 'GOT INDEX', index
+ # print 'GOT INDEX', index
self.queue = self.queue[:index]
- print '\t', self.queue
+ # print '\t', self.queue
def clear_cache(self):
self.cache = {}
@@ -383,7 +383,7 @@ def dereference(self, ref, allow_none=False):
if isinstance(ref, Document):
return ref
assert hasattr(ref, 'type')
-
+
obj = self.cache_read(ref.id)
if obj is not None:
return obj
View
@@ -32,65 +32,84 @@ class C(Document):
def test_simple_dereference():
print 1111
- class A(Document):
+ class ASD(Document):
x = IntField()
- class B(Document):
- y_id = SRefField(DocumentField(A))
+ class BSD(Document):
+ y_id = SRefField(DocumentField(ASD))
y = y_id.rel()
s = get_session()
- a = A(x=4)
+ s.clear_collection(ASD)
+ s.clear_collection(BSD)
+ a = ASD(x=4)
s.insert(a)
- b = B(y_id=a.mongo_id)
+ b = BSD()
+ b.y = a
s.add_to_session(b)
assert b.y.x == 4
+def test_poly_ref():
+ class PRef(Document):
+ config_polymorphic_collection = True
+ x = IntField()
+
+ class PRef2(PRef):
+ y = IntField()
+ r2 = PRef2()
+ r2.mongo_id = ObjectId()
+ assert RefField()._to_ref(r2).collection == 'PRef'
def test_proxy():
- class B(Document):
+ class TPB(Document):
b = IntField(default=3)
- class A(Document):
- x_ids = ListField(RefField(B, allow_none=True), default_empty=True, allow_none=True)
+ class TPA(Document):
+ x_ids = ListField(RefField(TPB, allow_none=True), default_empty=True, allow_none=True)
xs = x_ids.rel()
- x_id = RefField(B, allow_none=True)
+ x_id = RefField(TPB, allow_none=True)
x = x_id.rel()
s = get_session()
- a = A()
+ s.clear_collection(TPA)
+ s.clear_collection(TPB)
+
+ a = TPA()
for i in range(0, 3):
- b = B(b=i)
+ b = TPB(b=i)
s.insert(b)
a.x_id = b.to_ref()
a.x_ids.append(b.to_ref())
s.insert(a)
- aa = s.query(A).one()
+ aa = s.query(TPA).one()
assert aa.x.b == 2, aa.x.b
assert [z.b for z in aa.xs] == range(0, 3)
- a_none = A(x_id=None, x_ids=[None])
+ a_none = TPA(x_id=None, x_ids=[None])
a_none._set_session(s)
assert a_none.x == None
assert list(a_none.xs) == [None]
- a_set = A()
+ a_set = TPA()
a_set.x = b
def test_proxy_ignore_missing():
- class B(Document):
+ class TPIMB(Document):
b = IntField(default=3)
- class A(Document):
- x_ids = ListField(RefField(B), default_empty=True)
+ class TPIMA(Document):
+ x_ids = ListField(RefField(TPIMB), default_empty=True)
xs = x_ids.rel(ignore_missing=True)
- x_id = RefField(B)
+ x_id = RefField(TPIMB)
x = x_id.rel()
s = get_session()
- a = A()
+ s.clear_collection(TPIMA)
+ s.clear_collection(TPIMB)
+
+ a = TPIMA()
for i in range(0, 3):
- b = B(b=i)
+ b = TPIMB(b=i)
b.mongo_id = ObjectId()
if i > 0:
s.insert(b)
@@ -99,7 +118,7 @@ class A(Document):
a.x_ids.append(b.to_ref())
s.insert(a)
- aa = s.query(A).one()
+ aa = s.query(TPIMA).one()
assert len(list(aa.xs)) == 2, len(list(aa.xs))
@@ -58,7 +58,6 @@ def set_wrong_type_test_unwrap():
def set_wrong_child_type_test():
SetField(StringField()).wrap(set([4]))
-@raises(NotImplementedError)
def set_no_rel_test():
SetField(StringField()).rel()

0 comments on commit 5d42923

Please sign in to comment.