Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Sort via mongo-style params using query.sort()

  • Loading branch information...
commit 147b2fbd07a17ce692b163bd345f34ad8de14657 1 parent 1955fc1
Jeff Jenkins authored
View
24 mongoalchemy/query.py
@@ -48,7 +48,7 @@ def __init__(self, type, session):
self.session = session
self.type = type
self.__query = {}
- self.sort = []
+ self._sort = []
self._fields = None
self.hints = []
self._limit = None
@@ -101,7 +101,7 @@ def clone(self):
'''
qclone = Query(self.type, self.session)
qclone.__query = deepcopy(self.__query)
- qclone.sort = deepcopy(self.sort)
+ qclone._sort = deepcopy(self._sort)
qclone._fields = deepcopy(self._fields)
qclone._hints = deepcopy(self.hints)
qclone._limit = deepcopy(self._limit)
@@ -257,13 +257,29 @@ def descending(self, qfield):
'''
return self.__sort(qfield, DESCENDING)
+ def sort(self, *sort_tuples):
+ ''' pymongo-style sorting. Accepts a list of tuples.
+
+ :param sort_tuples: varargs of sort tuples.
+ '''
+ query = self
+ for name, direction in sort_tuples:
+ field = resolve_name(self.type, name)
+ if direction in (ASCENDING, 1):
+ query = query.ascending(field)
+ elif direction in (DESCENDING, -1):
+ query = query.descending(field)
+ else:
+ raise BadQueryException('Bad sort direction: %s' % direction)
+ return query
+
def __sort(self, qfield, direction):
qfield = resolve_name(self.type, qfield)
name = str(qfield)
- for n, _ in self.sort:
+ for n, _ in self._sort:
if n == name:
raise BadQueryException('Already sorting by %s' % name)
- self.sort.append((name, direction))
+ self._sort.append((name, direction))
return self
def not_(self, *query_expressions):
View
8 mongoalchemy/session.py
@@ -227,8 +227,8 @@ def execute_query(self, query, session):
collection = self.db[query.type.get_collection_name()]
cursor = collection.find(query.query, **kwargs)
- if query.sort:
- cursor.sort(query.sort)
+ if query._sort:
+ cursor.sort(query._sort)
elif query.type.config_default_sort:
cursor.sort(query.type.config_default_sort)
if query.hints:
@@ -306,8 +306,8 @@ def execute_find_and_modify(self, fm_exp):
kwargs['fields'] = {}
for f in fm_exp.query.get_fields():
kwargs['fields'][str(f)] = True
- if fm_exp.query.sort:
- kwargs['sort'] = fm_exp.query.sort
+ if fm_exp.query._sort:
+ kwargs['sort'] = fm_exp.query._sort
if fm_exp.get_new():
kwargs['new'] = fm_exp.get_new()
if fm_exp.get_remove():
View
15 test/test_query.py
@@ -285,10 +285,23 @@ def test_sort():
from pymongo import ASCENDING, DESCENDING
s = get_session()
sorted_query = s.query(T).ascending(T.i).descending(T.j)
- assert sorted_query.sort == [('i', ASCENDING),('j', DESCENDING)], sorted_query.sort
+ assert sorted_query._sort == [('i', ASCENDING),('j', DESCENDING)], sorted_query._sort
for obj in sorted_query:
pass
+def test_sort2():
+ from pymongo import ASCENDING, DESCENDING
+ s = get_session()
+ sorted_query = s.query(T).sort((T.i, ASCENDING), ('j', DESCENDING))
+ assert sorted_query._sort == [('i', ASCENDING),('j', DESCENDING)], sorted_query._sort
+
+@raises(BadQueryException)
+def test_sort_bad_dir():
+ from pymongo import ASCENDING, DESCENDING
+ s = get_session()
+ sorted_query = s.query(T).sort((T.i, ASCENDING), ('j', 4))
+ assert sorted_query._sort == [('i', ASCENDING),('j', DESCENDING)], sorted_query._sort
+
@raises(BadQueryException)
def test_sort_by_same_key():
s = get_session()
Please sign in to comment.
Something went wrong with that request. Please try again.