diff --git a/docs/customization.rst b/docs/customization.rst index db7ce30..d3e24d5 100644 --- a/docs/customization.rst +++ b/docs/customization.rst @@ -12,7 +12,7 @@ checks before logging them. The easier way to do see is by overriding the login view. The default view is defined like this:: - @view_config(route_name='login') + @view_config(route_name='login', check_csrf=True) def login(request): email = verify_login(request) headers = remember(request, email) @@ -25,7 +25,7 @@ which the button was clicked ; by default we redirect the user back there after So, if you want to check that an email is on a whitelist and create a profile and redirect new users, you can define a new login view like this one:: - @view_config(route_name='login') + @view_config(route_name='login', check_csrf=True) def login(request): email = verify_login('email') if email not in whitelist: @@ -39,9 +39,8 @@ redirect new users, you can define a new login view like this one:: Some goes if you want to do extra stuff at logout. The default logout view looks like this:: - @view_config(route_name='logout') + @view_config(route_name='logout', check_csrf=True) def logout(request): - check_csrf_token(request) headers = forget(request) return HTTPFound(request.POST['came_from'], headers=headers) diff --git a/pyramid_persona/__init__.py b/pyramid_persona/__init__.py index ad3dace..dac5aca 100644 --- a/pyramid_persona/__init__.py +++ b/pyramid_persona/__init__.py @@ -75,19 +75,19 @@ def check(): config.registry['persona.login_route'] = login_route login_path = settings.get('persona.login_path', '/login') config.add_route(login_route, login_path) - config.add_view(login, route_name=login_route) + config.add_view(login, route_name=login_route, check_csrf=True) logout_route = settings.get('persona.logout_route', 'logout') config.registry['persona.logout_route'] = logout_route logout_path = settings.get('persona.logout_path', '/logout') config.add_route(logout_route, logout_path) - config.add_view(logout, route_name=logout_route) + config.add_view(logout, route_name=logout_route, check_csrf=True) # A simple 403 view, with a login button. config.add_forbidden_view(forbidden) # A quick access to the login button - config.set_request_property(button, 'persona_button', reify=True) + config.add_request_method(button, 'persona_button', reify=True) # The javascript needed by persona - config.set_request_property(js, 'persona_js', reify=True) + config.add_request_method(js, 'persona_js', reify=True) diff --git a/pyramid_persona/tests.py b/pyramid_persona/tests.py index b1aca7e..0c7ae5e 100644 --- a/pyramid_persona/tests.py +++ b/pyramid_persona/tests.py @@ -1,30 +1,16 @@ import unittest -from pyramid.interfaces import IAuthorizationPolicy, IAuthenticationPolicy -from pyramid.testing import DummySecurityPolicy from pyramid.httpexceptions import HTTPBadRequest import requests from pyramid import testing -class SecurityPolicy(DummySecurityPolicy): - remembered = None - forgotten = None - def remember(self, request, principal, **kw): - self.remembered = principal - return [] - - def forget(self, request): - self.forgotten = True - return [] - - class ViewTests(unittest.TestCase): def setUp(self): self.config = testing.setUp(autocommit=False) self.config.add_settings({'persona.audiences': 'http://someaudience'}) self.config.include('pyramid_persona') - self.security_policy = SecurityPolicy() + self.security_policy = self.config.testing_securitypolicy() self.config.set_authorization_policy(self.security_policy) self.config.set_authentication_policy(self.security_policy) self.config.commit() @@ -59,7 +45,7 @@ def test_login_fails_with_bad_audience(self): request.params['came_from'] = '/' self.assertRaises(HTTPBadRequest, login, request) - self.assertEqual(self.security_policy.remembered, None) + self.assertFalse(hasattr(self.security_policy, 'remembered')) def test_logout(self): from .views import logout diff --git a/pyramid_persona/views.py b/pyramid_persona/views.py index 67ccef5..71bbf98 100644 --- a/pyramid_persona/views.py +++ b/pyramid_persona/views.py @@ -2,24 +2,14 @@ from pyramid.httpexceptions import HTTPBadRequest, HTTPFound from pyramid.response import Response from pyramid.security import remember, forget - import browserid.errors -def _check_csrf_token(request): - """Check the CSRF token from the request. Raises if invalid. - - Copied from pyramid.session.check_csrf_token in pyramid==1.4a2.""" - if request.params.get('csrf_token') != request.session.get_csrf_token(): - raise HTTPBadRequest('incorrect CSRF token') - - def verify_login(request): """Verifies the assertion and the csrf token in the given request. Returns the email of the user if everything is valid, otherwise raises a HTTPBadRequest""" - _check_csrf_token(request) verifier = request.registry['persona.verifier'] try: data = verifier.verify(request.POST['assertion']) @@ -37,7 +27,6 @@ def login(request): def logout(request): """View to forget the user""" - _check_csrf_token(request) headers = forget(request) return HTTPFound(request.POST['came_from'], headers=headers) diff --git a/setup.py b/setup.py index f425a9d..4148d29 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ CHANGES = open(os.path.join(here, 'CHANGES.rst')).read() requires = [ - 'pyramid', + 'pyramid>=1.4', 'PyBrowserID', ]