Permalink
Browse files

Basic filtering using query string arguments. Presently, equality-only.

For this, check_permissions is split into class- and instance-level methods.
  • Loading branch information...
1 parent c54bca5 commit 4f0861ec8f98eb8f72f66859f54f16770bfbf378 @drdaeman committed Dec 9, 2012
Showing with 121 additions and 35 deletions.
  1. +10 −4 example/example/models.py
  2. +59 −15 flask_toybox/sqlalchemy.py
  3. +26 −0 flask_toybox/utils.py
  4. +4 −4 flask_toybox/views.py
  5. +0 −3 tests/test_model.py
  6. +22 −9 tests/test_sqlalchemy.py
View
@@ -16,16 +16,22 @@ class User(db.Model, SAModelMixin):
first_name = db.Column(db.String(64), nullable=False, info=I("r:authenticated+,w:owner+"))
last_name = db.Column(db.String(64), nullable=False, info=I("rw:owner+"))
- def check_permissions(self, user=None):
- user = getattr(g, "user", None)
+ @classmethod
+ def check_class_permissions(self, user=None):
+ user = user or getattr(g, "user", None)
if user is not None:
p = {"authenticated"}
- if user.id == self.id:
- p.add("owner")
return p
else:
return {"anonymous"}
+ def check_instance_permissions(self, user=None):
+ user = user or getattr(g, "user", None)
+ p = self.check_class_permissions(user=user)
+ if user is not None and user.id == self.id:
+ p.add("owner")
+ return p
+
class Post(db.Model, SAModelMixin):
__tablename__ = "posts"
View
@@ -17,9 +17,11 @@
from sqlalchemy.schema import Column
from .views import ModelView, BaseModelView
from .permissions import ModelColumnInfo
+from .utils import mixedmethod
from flask import g, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.datastructures import Range, ContentRange
+import json
def column_info(model, name, column):
return ModelColumnInfo(model, name,
@@ -39,31 +41,41 @@ def _get_permissions(cls, column, what="readable"):
permissions = column.permissions
return permissions.get(what, frozenset(["system"]))
- def get_columns(self, only_db_columns=False):
+ @mixedmethod
+ def get_columns(self, cls, only_db_columns=False, only_permitted=None):
+ # cls = self.__class__
columns = [
column_info(self, prop.key, prop.columns[0])
- for prop in class_mapper(self.__class__).iterate_properties
+ for prop in class_mapper(cls).iterate_properties
if isinstance(prop, ColumnProperty) and len(prop.columns) == 1\
and not prop.key.startswith("_") and\
(not only_db_columns or isinstance(prop.columns[0], Column))
]
if not only_db_columns:
# If there's a mix, "real" DB columns should go first
columns.sort(key=lambda c: c.db_column, reverse=True)
- return columns
-
- def check_permissions(self):
- return frozenset(["system"])
-
- def as_dict(self, check_permissions=True):
- columns = self.get_columns()
- if check_permissions:
- levels = self.check_permissions()
+ if only_permitted is not None:
+ if self is not None:
+ levels = self.check_instance_permissions()
+ else:
+ levels = cls.check_class_permissions()
+ get_perms = cls._get_permissions
columns = [c for c in columns
- if any(l in self.__class__._get_permissions(c)
+ if any(l in get_perms(c, what=only_permitted)
for l in levels)]
+ return columns
+ @classmethod
+ def check_class_permissions(cls, **kwargs):
+ return set(["system"])
+
+ def check_instance_permissions(self, **kwargs):
+ return self.check_class_permissions(**kwargs)
+
+ def as_dict(self, check_permissions=True):
+ check = "readable" if check_permissions else None
+ columns = self.get_columns(only_permitted=check)
return OrderedDict((c.name, getattr(self, c.name)) for c in columns)
@staticmethod
@@ -85,20 +97,29 @@ def hasUserMixin(owner_id_field):
for models referencing Django-like user objects.
"""
class HasUserMixin(object):
- def check_permissions(self, user=None):
+ @classmethod
+ def check_class_permissions(cls, user=None):
if user is None and hasattr(g, "user"):
user = g.user
p = set()
if user is not None:
p.add("authenticated")
- if owner_id_field is not None:
- if user.id == getattr(self, owner_id_field): p.add("owner")
if getattr(user, "is_superuser", False): p.add("admin")
if getattr(user, "is_staff", False): p.add("staff")
else:
p.add("anonymous")
return p
+
+ def check_instance_permissions(self, user=None):
+ if user is None and hasattr(g, "user"):
+ user = g.user
+
+ p = self.check_class_permissions(user=user)
+ if user is not None:
+ if owner_id_field is not None:
+ if user.id == getattr(self, owner_id_field): p.add("owner")
+ return p
return HasUserMixin
class SAModelViewBase(object):
@@ -148,6 +169,29 @@ def fetch_object(self, *args, **kwargs):
g.etagger.set_object(objs)
return objs
+class QueryFiltering(object):
+ """
+ Mixin class, adding support for filtering using query string.
+ Append this class from the left (i.e. `class Foo(QueryFiltering, ...)` to hook in.
+
+ Multiple filters for a same name are joined together by AND logic, exactly as
+ passing multiple filters to SQLAlchemy `filter` method.
+
+ Note, filtering is allowed only on class-level readable fields, as returned
+ by `check_class_permissions`. Other query arguments are silently ignored.
+ """
+ def get_query(self):
+ q = super(QueryFiltering, self).get_query()
+ columns = self.model.get_columns(only_permitted="readable")
+ columns = dict([(c.name, c) for c in columns])
+
+ for name, values in request.args.lists():
+ if name in columns:
+ c = getattr(self.model, name)
+ f = [c == json.loads(value) for value in values]
+ q = q.filter(*f)
+ return q
+
class PaginableByNumber(object):
"""
Mixin class, adding support for pagination by item number. Append this class
View
@@ -1,4 +1,5 @@
import string
+from functools import partial
def is_printable(value):
"""
@@ -7,3 +8,28 @@ def is_printable(value):
"""
return isinstance(value, basestring) \
and all(c in string.printable for c in value)
+
+# Taken from http://www.daniweb.com/software-development/python/code/406393/
+class mixedmethod(object):
+ """
+ This decorator mutates a function defined in a class into a 'mixed' class and instance method.
+
+ Usage:
+
+ class Spam:
+ @mixedmethod
+ def egg(self, cls, *args, **kwargs):
+ if self is None:
+ pass # executed if egg was called as a class method (eg. Spam.egg())
+ else:
+ pass # executed if egg was called as an instance method (eg. instance.egg())
+
+ The decorated methods need 2 implicit arguments: self and cls, the former being None when
+ there is no instance in the call. This follows the same rule as __get__ methods in python's
+ descriptor protocol.
+ """
+ def __init__(self, func):
+ self.func = func
+
+ def __get__(self, instance, cls):
+ return partial(self.func, instance, cls)
View
@@ -206,8 +206,8 @@ def get(self, *args, **kwargs):
obj = self.get_object(*args, **kwargs)
headers = {}
- if hasattr(obj, "check_permissions"):
- access = obj.check_permissions()
+ if hasattr(obj, "check_instance_permissions"):
+ access = obj.check_instance_permissions()
if len(access) > 0 and access != frozenset(["system"]):
headers["X-Access-Classes"] = ", ".join(sorted(access))
@@ -228,8 +228,8 @@ def get_columns(self, *args, **kwargs):
def patch(self, *args, **kwargs):
obj = self.get_object(**kwargs)
- if hasattr(obj, "check_permissions"):
- access = obj.check_permissions()
+ if hasattr(obj, "check_instance_permissions"):
+ access = obj.check_instance_permissions()
columns = dict([(c.name, c.permissions.get("writeable", set()))
for c in self.get_columns(only_db_columns=True)])
else:
View
@@ -24,9 +24,6 @@ def __init__(self, data):
def as_dict(self):
return self.data
- #def check_permissions(self):
- # return frozenset([])
-
def get_columns(self, **kwargs):
return set([ModelColumnInfo(self, "spam"),
ModelColumnInfo(self, "eggs")])
View
@@ -1,13 +1,13 @@
import unittest
-from flask.ext.toybox.sqlalchemy import SAModelMixin, SAModelView, SACollectionView, PaginableByNumber
+from flask.ext.toybox.sqlalchemy import SAModelMixin, SAModelView, SACollectionView, PaginableByNumber, QueryFiltering
from flask.ext.toybox.permissions import make_I
from flask.ext.toybox import ToyBox
from flask import Flask, g, request
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session, Session
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy import Column, Integer, String
+from sqlalchemy import Column, Integer, String, Boolean
import json
Base = declarative_base()
@@ -19,14 +19,19 @@ class User(Base, SAModelMixin):
id = Column(Integer, primary_key=True)
username = Column(String, info=I("r:all,w:none"))
fullname = Column(String, info=I("rw:all"))
- email = Column(String, info=I("rw:owner"))
+ email = Column(String, info=I("rw:owner+"))
+ badges = Column(Integer, default=0, info=I("r:all,w:staff+"))
+ is_active = Column(Boolean, default=True, info=I("r:all,w:staff+"))
+ is_staff = Column(Boolean, default=False, info=I("r:staff+,w:admin+"))
- def __init__(self, username, fullname, email):
+ def __init__(self, username, fullname, email, **kwargs):
self.username = username
self.fullname = fullname
self.email = email
+ for name, value in kwargs.items():
+ setattr(self, name, value)
- def check_permissions(self, user=None):
+ def check_instance_permissions(self, user=None):
# Very silly a12n.
auth = request.args.get("auth", "")
if auth != "":
@@ -48,9 +53,9 @@ def setUp(self):
# Create some models
db_session = ScopedSession()
- db_session.add(User("spam", "Spam", "spam@users.example.org"))
- db_session.add(User("ham", "Ham", "ham@users.example.org"))
- db_session.add(User("eggs", "Eggs", "eggs@users.example.org"))
+ db_session.add(User("spam", "Spam", "spam@users.example.org", badges=1, is_staff=True))
+ db_session.add(User("ham", "Ham", "ham@users.example.org", is_active=False))
+ db_session.add(User("eggs", "Eggs", "eggs@users.example.org", badges=2, is_staff=True))
db_session.commit()
self.db_session = db_session
@@ -73,7 +78,7 @@ def save_object(self, obj):
db_session.commit()
app.add_url_rule("/users/<username>", view_func=UserView.as_view("user"))
- class UsersView(PaginableByNumber, SACollectionView):
+ class UsersView(PaginableByNumber, QueryFiltering, SACollectionView):
model = User
query_class = db_session.query
order_by = "username"
@@ -129,6 +134,14 @@ def test_collection_pagination(self):
usernames = [data_item.get("username", None) for data_item in data]
self.assertEqual(usernames, ["ham", "spam"])
+ def test_collection_filtering(self):
+ # This also tests whenever is_admin will be ignored, as it is not readable.
+ response = self.app.get("/users/?is_staff=true&is_admin=true&spam=spam", headers={"Accept": "application/json"})
+ self.assertEqual(response.status_code, 200, response.status)
+ data = json.loads(response.data)
+ usernames = set([data_item.get("username", None) for data_item in data])
+ self.assertEqual(usernames, set(["spam", "eggs"]))
+
def test_collection_is_readonly(self):
for method in ("put", "patch", "delete"):
response = getattr(self.app, method)("/users/", headers={"Accept": "application/json"})

0 comments on commit 4f0861e

Please sign in to comment.