Permalink
Browse files

Add "ignore_missing" to make a proxy iterator ignore missing values. …

…Later this will power a feature which lazily deletes old references.
  • Loading branch information...
1 parent 86bb9b6 commit 2e773c2361ce17507de2d9b0044fc9414ba36eb8 @jeffjenkins committed Aug 3, 2012
Showing with 64 additions and 17 deletions.
  1. +7 −3 mongoalchemy/document.py
  2. +7 −1 mongoalchemy/fields.py
  3. +2 −1 mongoalchemy/query.py
  4. +23 −12 mongoalchemy/session.py
  5. +25 −0 test/test_ref_field.py
View
@@ -99,7 +99,7 @@ def __new__(mcs, classname, bases, class_dict):
if field.proxy is not None:
setattr(new_class, field.proxy, Proxy(name))
if field.iproxy is not None:
- setattr(new_class, field.iproxy, IProxy(name))
+ setattr(new_class, field.iproxy, IProxy(name, field.ignore_missing))
# 4. Add subclasses
@@ -705,8 +705,9 @@ def __set__(self, instance, value):
setattr(instance, self.name, value)
class IProxy(object):
- def __init__(self, name):
+ def __init__(self, name, ignore_missing):
self.name = name
+ self.ignore_missing = ignore_missing
def __get__(self, instance, owner):
if instance is None:
return getattr(owner, self.name)
@@ -716,5 +717,8 @@ def iterator():
if v is None:
yield v
continue
- yield session.dereference(v)
+ value = session.dereference(v, allow_none=self.ignore_missing)
+ if value is None and self.ignore_missing:
+ continue
+ yield value
return iterator()
View
@@ -138,7 +138,7 @@ class Field(object):
def __init__(self, required=True, default=UNSET, db_field=None, allow_none=False, on_update='$set',
validator=None, unwrap_validator=None, wrap_validator=None, _id=False,
- proxy=None, iproxy=None):
+ proxy=None, iproxy=None, ignore_missing=False):
'''
:param required: The field must be passed when constructing a document (optional. default: ``True``)
:param default: Default value to use if one is not given (optional.)
@@ -170,6 +170,7 @@ def __init__(self, required=True, default=UNSET, db_field=None, allow_none=False
self.proxy = proxy
self.iproxy = iproxy
+ self.ignore_missing = ignore_missing
self.validator = validator
self.unwrap_validator = unwrap_validator
@@ -448,6 +449,9 @@ def validate_wrap(self, value):
class DateTimeField(PrimitiveField):
''' Field for datetime objects. '''
+
+ has_autoload = True
+
def __init__(self, min_date=None, max_date=None, use_tz=False, **kwargs):
''' :param max_date: maximum date
:param min_date: minimum date
@@ -478,6 +482,8 @@ def unwrap(self, value, session=None):
if value.tzinfo is not None:
import pytz
value = value.replace(tzinfo=pytz.utc)
+ if session and session.timezone:
+ value = value.astimezone(session.timezone)
return value
def localize(self, session, value):
View
@@ -391,7 +391,8 @@ def next(self):
if obj:
return obj
value = self.type.unwrap(value, fields=self.fields, session=self.session)
- self.session.cache_write(value)
+ if not isinstance(value, dict):
+ self.session.cache_write(value)
return value
def __getitem__(self, index):
View
@@ -52,7 +52,7 @@
class Session(object):
- def __init__(self, database, timezone=None, safe=False, cache_size=0):
+ def __init__(self, database, tz_aware=False, timezone=None, safe=False, cache_size=0):
'''
Create a session connecting to `database`.
@@ -65,11 +65,18 @@ def __init__(self, database, timezone=None, safe=False, cache_size=0):
* db: the underlying pymongo database object
* queue: the queue of unflushed database commands (currently useless \
since there aren't any operations which defer flushing)
+ * cache_size: The size of the identity map to keep. When objects \
+ are pulled from the DB they are checked against this \
+ map and if present, the existing object is used. \
+ Defaults to 0, use None to only clear at session end.
+
'''
self.db = database
self.queue = []
self.safe = safe
self.timezone = timezone
+ self.tz_aware = bool(tz_aware or timezone)
+
self.cache_size = cache_size
self.cache = {}
self.transactions = []
@@ -79,9 +86,6 @@ def autoflush(self):
@property
def in_transaction(self):
return len(self.transactions) > 0
- @property
- def tz_aware(self):
- return self.timezone is not None
@classmethod
def connect(self, database, timezone=None, cache_size=0, *args, **kwds):
@@ -104,17 +108,20 @@ def connect(self, database, timezone=None, cache_size=0, *args, **kwds):
db = conn[database]
return Session(db, timezone=timezone, safe=safe, cache_size=cache_size)
- def cache_write(self, obj):
+ def cache_write(self, obj, mongo_id=None):
+ if mongo_id is None:
+ mongo_id = obj.mongo_id
+
if self.cache_size == 0:
return
- if obj.mongo_id in self.cache:
+ if mongo_id in self.cache:
return
- if len(self.cache) >= self.cache_size:
+ if self.cache_size is not None and len(self.cache) >= self.cache_size:
for key in self.cache:
break
del self.cache[key]
- assert isinstance(obj.mongo_id, ObjectId), 'Currently, cached objects must use mongo_id as an ObjectId'
- self.cache[obj.mongo_id] = obj
+ assert isinstance(mongo_id, ObjectId), 'Currently, cached objects must use mongo_id as an ObjectId. Got: %s' % type(mongo_id)
+ self.cache[mongo_id] = obj
def cache_read(self, id):
if self.cache_size == 0:
@@ -372,7 +379,7 @@ def flush(self, safe=None):
self.clear_queue()
return result
- def dereference(self, ref):
+ def dereference(self, ref, allow_none=False):
if isinstance(ref, Document):
return ref
assert hasattr(ref, 'type')
@@ -381,8 +388,12 @@ def dereference(self, ref):
if obj is not None:
return obj
value = self.db.dereference(ref)
- obj = ref.type.unwrap(value, session=self)
- self.cache_write(obj)
+ if value is None and allow_none:
+ obj = None
+ self.cache_write(obj, mongo_id=ref.id)
+ else:
+ obj = ref.type.unwrap(value, session=self)
+ self.cache_write(obj)
return obj
def refresh(self, document):
View
@@ -62,6 +62,31 @@ class A(Document):
a_set = A()
a_set.x = b
+def test_proxy_ignore_missing():
+ class B(Document):
+ b = IntField(default=3)
+ class A(Document):
+ x_ids = ListField(RefField(B), iproxy='xs', default_empty=True,
+ ignore_missing=True)
+ x_id = RefField(B, proxy='x')
+
+ s = get_session()
+ a = A()
+ for i in range(0, 3):
+ b = B(b=i)
+ b.mongo_id = ObjectId()
+ if i > 0:
+ s.insert(b)
+
+ a.x_id = b
+ a.x_ids.append(b)
+
+ s.insert(a)
+ aa = s.query(A).one()
+
+ assert len(list(aa.xs)) == 2, len(list(aa.xs))
+
+
def test_reffield():
class A(Document):
x = IntField()

0 comments on commit 2e773c2

Please sign in to comment.