Permalink
Browse files

updates to Session.dereference and RefField

  • Loading branch information...
1 parent 5d42923 commit d7820bf009952fe1a1a7702dd55ac7aea81645f8 @jeffjenkins committed Sep 10, 2012
Showing with 79 additions and 14 deletions.
  1. +0 −1 mongoalchemy/document.py
  2. +3 −0 mongoalchemy/exceptions.py
  3. +10 −8 mongoalchemy/fields/ref.py
  4. +12 −3 mongoalchemy/session.py
  5. +1 −1 setup.py
  6. +53 −1 test/test_ref_field.py
View
@@ -393,7 +393,6 @@ def wrap(self):
value = getattr(self, name)
res[field.db_field] = field.wrap(value)
except AttributeError, e:
- print e
if field.required:
raise MissingValueException(name)
return res
@@ -11,6 +11,9 @@ def __init__(self, name, value, reason, cause=None):
message = '%s Cause: %s' % (message, cause)
Exception.__init__(self, message)
+class BadReferenceException(Exception):
+ pass
+
class InvalidConfigException(Exception):
''' Raised when a bad value is passed in for a configuration that expects
its values to obey certain constraints.'''
View
@@ -35,18 +35,20 @@ class SRefField(RefBase):
'''
has_subfields = True
has_autoload = True
- def __init__(self, type, **kwargs):
+ def __init__(self, type, db=None, **kwargs):
from mongoalchemy.fields import DocumentField
super(SRefField, self).__init__(**kwargs)
self.type = type
if not isinstance(type, DocumentField):
self.type = DocumentField(type)
+ self.db = db
def _to_ref(self, doc):
return doc.mongo_id
def dereference(self, session, ref, allow_none=False):
- ref = DBRef(id=ref, collection=self.type.type.get_collection_name())
+ ref = DBRef(id=ref, collection=self.type.type.get_collection_name(),
+ database=self.db)
ref.type = self.type.type
return session.dereference(ref, allow_none=allow_none)
def set_parent_on_subtypes(self, parent):
@@ -72,7 +74,7 @@ class RefField(RefBase):
has_subfields = True
has_autoload = True
- def __init__(self, type=None, db=None, namespace='global', **kwargs):
+ def __init__(self, type=None, db=None, db_required=False, namespace='global', **kwargs):
''' :param type: (optional) the Field type to use for the values. It
must be a DocumentField. If you want to save refs to raw mongo
objects, you can leave this field out
@@ -88,6 +90,7 @@ def __init__(self, type=None, db=None, namespace='global', **kwargs):
type = DocumentField(type)
super(RefField, self).__init__(**kwargs)
+ self.db_required = db_required
self.type = type
self.namespace = namespace
self.db = db
@@ -102,7 +105,7 @@ def wrap(self, value):
return value
def _to_ref(self, doc):
- return doc.to_ref()
+ return doc.to_ref(db=self.db)
def unwrap(self, value, fields=None, session=None):
''' If ``autoload`` is False, return a DBRef object. Otherwise load
@@ -118,10 +121,7 @@ def dereference(self, session, ref, allow_none=False):
from mongoalchemy.document import collection_registry
# TODO: namespace support
ref.type = collection_registry['global'][ref.collection]
- # print ref.type, ref
- # print '--' * 30
obj = session.dereference(ref, allow_none=allow_none)
- # print '--' * 30
return obj
def set_parent_on_subtypes(self, parent):
if self.type:
@@ -139,7 +139,9 @@ def validate_unwrap(self, value, session=None):
if expected != got:
self._fail_validation(value, '''Wrong collection for reference: '''
'''got "%s" instead of "%s" ''' % (got, expected))
- if self.db and self.db != value.database:
+ if self.db_required and not value.database:
+ self._fail_validation(value, 'db_required=True, but not database specified')
+ if self.db and value.database and self.db != value.database:
self._fail_validation(value, '''Wrong database for reference: '''
''' got "%s" instead of "%s" ''' % (value.database, self.db) )
validate_wrap = validate_unwrap
View
@@ -45,9 +45,9 @@
from pymongo.connection import Connection
from bson import DBRef, ObjectId
from mongoalchemy.query import Query, QueryResult, RemoveQuery
-from mongoalchemy.document import FieldNotRetrieved, Document
+from mongoalchemy.document import FieldNotRetrieved, Document, collection_registry
from mongoalchemy.query_expression import FreeFormDoc
-from mongoalchemy.exceptions import TransactionException
+from mongoalchemy.exceptions import TransactionException, BadReferenceException
from mongoalchemy.ops import *
class Session(object):
@@ -382,15 +382,24 @@ def flush(self, safe=None):
def dereference(self, ref, allow_none=False):
if isinstance(ref, Document):
return ref
+ if not hasattr(ref, 'type'):
+ if ref.collection in collection_registry['global']:
+ ref.type = collection_registry['global'][ref.collection]
assert hasattr(ref, 'type')
obj = self.cache_read(ref.id)
if obj is not None:
return obj
- value = self.db.dereference(ref)
+ if ref.database and self.db.name != ref.database:
+ db = self.db.connection[ref.database]
+ else:
+ db = self.db
+ value = db.dereference(ref)
if value is None and allow_none:
obj = None
self.cache_write(obj, mongo_id=ref.id)
+ elif value is None:
+ raise BadReferenceException('Bad reference: %r' % ref)
else:
obj = ref.type.unwrap(value, session=self)
self.cache_write(obj)
View
@@ -2,7 +2,7 @@
from distutils.core import setup
-VERSION = '0.12.2'
+VERSION = '0.13.0'
DESCRIPTION = 'Document-Object Mapper/Toolkit for Mongo Databases'
setup(
View
@@ -2,7 +2,8 @@
from mongoalchemy.fields import *
from mongoalchemy.exceptions import DocumentException, MissingValueException, \
- ExtraValueException, FieldNotRetrieved, BadFieldSpecification
+ ExtraValueException, FieldNotRetrieved, BadFieldSpecification, \
+ BadReferenceException
from mongoalchemy.document import Document, document_type_registry
from mongoalchemy.session import Session
from test.util import known_failure
@@ -184,6 +185,57 @@ class A(Document):
ret = RefField(DocumentField(A)).unwrap(5)
+@raises(BadValueException)
+def test_unwrap_missing_db():
+ class A(Document):
+ x = IntField()
+ s = get_session()
+
+ a = A(x=5)
+ s.insert(a)
+
+ aref = {'$id':a.mongo_id, '$ref':'A'}
+ dbaref = DBRef(collection='A', id=a.mongo_id)
+
+ ret = RefField(DocumentField(A), db_required=True).unwrap(dbaref)
+
+def test_dereference_doc():
+ class A(Document):
+ x = IntField()
+
+ s = Session.connect('unit-testing', cache_size=0)
+ s.clear_collection(A)
+
+ a = A(x=5)
+ s.insert(a)
+ dbaref = DBRef(collection='A', id=a.mongo_id, database='unit-testing')
+ s2 = Session.connect('unit-testing2', cache_size=0)
+ assert s2.dereference(a).x == 5
+
+def test_dereference():
+ class A(Document):
+ x = IntField()
+
+ s = Session.connect('unit-testing', cache_size=0)
+ s.clear_collection(A)
+
+ a = A(x=5)
+ s.insert(a)
+ dbaref = DBRef(collection='A', id=a.mongo_id, database='unit-testing')
+ s2 = Session.connect('unit-testing2', cache_size=0)
+ assert s2.dereference(dbaref).x == 5
+
+@raises(BadReferenceException)
+def test_bad_dereference():
+ class A(Document):
+ x = IntField()
+
+ s = Session.connect('unit-testing', cache_size=0)
+ s.clear_collection(A)
+ dbaref = DBRef(collection='A', id=ObjectId(), database='unit-testing')
+ s.dereference(dbaref)
+
+
def test_simple():
class A(Document):
x = IntField()

0 comments on commit d7820bf

Please sign in to comment.