Permalink
Browse files

Allow various comparsions (not only equality) when filtering

  • Loading branch information...
1 parent 4f0861e commit c8714ae3a79ac82caf24a6aa25f6e8ff36d4cff2 @drdaeman committed Dec 9, 2012
Showing with 39 additions and 8 deletions.
  1. +1 −1 README.rst
  2. +20 −1 flask_toybox/sqlalchemy.py
  3. +18 −6 tests/test_sqlalchemy.py
View
@@ -27,13 +27,13 @@ What's implemented:
TastyPie-like hydration/dehydration methods.
- SQLAlchemy model and collection views support. Best used with Flask-SQLAlchemy.
- Simple pagination helper (pagination using "Range" request header).
+- Built-in helper for filtering SQLAlchemy collections.
What's missing:
- Better example.
- Documentation. There are some docstrings in source code, but not much.
- POST, PUT and DELETE requests (object creation and deletion).
-- Built-in helpers for filtering SQLAlchemy collections.
- Overriding negotiation using query string (i.e. ``?format=json``)
- Nested resources.
- Better test coverage.
View
@@ -21,6 +21,7 @@
from flask import g, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.datastructures import Range, ContentRange
+import operator
import json
def column_info(model, name, column):
@@ -180,6 +181,20 @@ class QueryFiltering(object):
Note, filtering is allowed only on class-level readable fields, as returned
by `check_class_permissions`. Other query arguments are silently ignored.
"""
+ def decode_filter(self, name, value):
+ OPERATOR_MAP = {"eq:": operator.eq, "ne:": operator.ne,
+ "lt:": operator.lt, "le:": operator.le,
+ "gt:": operator.gt, "ge:": operator.ge}
+
+ op = operator.eq
+ if len(value) >= 3 and value[:3] in OPERATOR_MAP:
+ op, value = OPERATOR_MAP[value[:3]], value[3:]
+ try:
+ value = json.loads(value)
+ except ValueError:
+ pass
+ return (op, value)
+
def get_query(self):
q = super(QueryFiltering, self).get_query()
columns = self.model.get_columns(only_permitted="readable")
@@ -188,7 +203,11 @@ def get_query(self):
for name, values in request.args.lists():
if name in columns:
c = getattr(self.model, name)
- f = [c == json.loads(value) for value in values]
+ f = []
+ for value in values:
+ op, value = self.decode_filter(name, value)
+ if op is not None:
+ f.append(op(c, value))
q = q.filter(*f)
return q
View
@@ -135,12 +135,24 @@ def test_collection_pagination(self):
self.assertEqual(usernames, ["ham", "spam"])
def test_collection_filtering(self):
- # This also tests whenever is_admin will be ignored, as it is not readable.
- response = self.app.get("/users/?is_staff=true&is_admin=true&spam=spam", headers={"Accept": "application/json"})
- self.assertEqual(response.status_code, 200, response.status)
- data = json.loads(response.data)
- usernames = set([data_item.get("username", None) for data_item in data])
- self.assertEqual(usernames, set(["spam", "eggs"]))
+ cases = [
+ # This also tests whenever is_admin will be ignored, as it is not readable.
+ ("is_staff=true&is_admin=true&spam=spam", set(["spam", "eggs"])),
+ ("badges=lt:2", set(["ham", "spam"])),
+ ("badges=eq:0&is_active=false", set(["ham"])),
+ ("badges=ne:0&is_active=false", set()),
+ ("is_staff=true&is_staff=false", set()),
+ ("badges=ne:null", set(["spam", "ham", "eggs"])),
+ ("is_staff=\"true\"", set()), # XXX: Should it return empty set or error?
+ ("is_staff=invalid", set())
+ ]
+
+ for query, expected in cases:
+ response = self.app.get("/users/?" + query, headers={"Accept": "application/json"})
+ self.assertEqual(response.status_code, 200, response.status)
+ data = json.loads(response.data)
+ usernames = set([data_item.get("username", None) for data_item in data])
+ self.assertEqual(usernames, expected)
def test_collection_is_readonly(self):
for method in ("put", "patch", "delete"):

0 comments on commit c8714ae

Please sign in to comment.