Permalink
Browse files

Adding a better way of building auth decorators, then converting the

login_required to use the new stuff... plus an "admin_required"
decorator
  • Loading branch information...
1 parent 56d7ee6 commit 253eaa2aa9477ab9b8d070d71baf8992500e45e8 @coleifer committed May 27, 2012
Showing with 43 additions and 10 deletions.
  1. +17 −10 flask_peewee/auth.py
  2. +21 −0 flask_peewee/tests/auth.py
  3. +5 −0 flask_peewee/tests/test_app.py
View
@@ -99,17 +99,24 @@ def get_urls(self):
def get_login_form(self):
return LoginForm
- def login_required(self, func):
- @functools.wraps(func)
- def inner(*args, **kwargs):
- user = self.get_logged_in_user()
-
- if not user:
- login_url = url_for('%s.login' % self.blueprint.name, next=get_next())
- return redirect(login_url)
+ def test_user(self, test_fn):
+ def decorator(fn):
+ @functools.wraps(fn)
+ def inner(*args, **kwargs):
+ user = self.get_logged_in_user()
+
+ if not user or not test_fn(user):
+ login_url = url_for('%s.login' % self.blueprint.name, next=get_next())
+ return redirect(login_url)
+ return fn(*args, **kwargs)
+ return inner
+ return decorator
- return func(*args, **kwargs)
- return inner
+ def login_required(self, func):
+ return self.test_user(lambda u: True)(func)
+
+ def admin_required(self, func):
+ return self.test_user(lambda u: u.admin)(func)
def authenticate(self, username, password):
active = self.User.select().where(active=True)
View
@@ -162,3 +162,24 @@ def test_login_required(self):
self.assertEqual(resp.status_code, 200)
self.assertEqual(auth.get_logged_in_user(), self.admin)
+
+ def test_admin_required(self):
+ self.create_users()
+
+ with self.flask_app.test_client() as c:
+ resp = c.get('/secret/')
+ self.assertEqual(resp.status_code, 302)
+ self.assertTrue(resp.headers['location'].endswith('/accounts/login/?next=%2Fsecret%2F'))
+
+ self.login('normal', 'normal', c)
+
+ resp = c.get('/secret/')
+ self.assertEqual(resp.status_code, 302)
+ self.assertTrue(resp.headers['location'].endswith('/accounts/login/?next=%2Fsecret%2F'))
+ self.assertEqual(auth.get_logged_in_user(), self.normal)
+
+ self.login('admin', 'admin', c)
+ resp = c.get('/secret/')
+ self.assertEqual(resp.status_code, 200)
+
+ self.assertEqual(auth.get_logged_in_user(), self.admin)
@@ -211,6 +211,11 @@ def homepage():
def private_timeline():
return Response()
+@app.route('/secret/')
+@auth.admin_required
+def secret_area():
+ return Response()
+
admin.setup()
api.setup()

0 comments on commit 253eaa2

Please sign in to comment.