diff --git a/flask_seasurf.py b/flask_seasurf.py index cb10382..40b344b 100755 --- a/flask_seasurf.py +++ b/flask_seasurf.py @@ -115,6 +115,7 @@ def __init__(self, app=None): self._include_views = set() self._exempt_urls = tuple() self._disable_cookie = None + self._skip_validation = None if app is not None: self.init_app(app) @@ -217,6 +218,29 @@ def disable_cookie(response): self._disable_cookie = callback return callback + def skip_validation(self, callback): + ''' + A decorator to programmatically disable validating the CSRF token + cookie on the request. The function will be passed a Flask Request + object for the current request. + + The decorated function must return :class:`True` or :class:`False`. + + Example usage of :class:`skip_validation` might look something + like:: + + csrf = SeaSurf(app) + + @csrf.skip_validation + def skip_validation(request): + if is_api_request(): + return False + return True + ''' + + self._skip_validation = callback + return callback + def validate(self): ''' Validates a CSRF token for the current request. @@ -357,6 +381,9 @@ def _before_request(self): if not self._should_use_token(_app_ctx_stack.top._view_func): return + if self._skip_validation and self._skip_validation(request): + return + self.validate() def _after_request(self, response): diff --git a/test_seasurf.py b/test_seasurf.py index 9d7c6cb..c751dbc 100644 --- a/test_seasurf.py +++ b/test_seasurf.py @@ -453,6 +453,81 @@ def getCookie(self, response, cookie_name): return None +class SeaSurfTestCaseSkipValidation(unittest.TestCase): + def setUp(self): + app = Flask(__name__) + app.debug = True + app.secret_key = '1234' + + self.app = app + + csrf = SeaSurf() + csrf._csrf_disable = False + self.csrf = csrf + + # Initialize CSRF protection. + self.csrf.init_app(app) + + @self.csrf.skip_validation + def skip_validation(request): + if request.path == '/foo/quz': + return True + if request.path == '/manual': + return True + return False + + @app.route('/foo/baz', methods=['GET']) + def get_foobaz(): + return 'bar' + + @app.route('/foo/baz', methods=['DELETE']) + def foobaz(): + return 'bar' + + @app.route('/foo/quz', methods=['POST']) + def fooquz(): + return 'bar' + + @app.route('/manual', methods=['POST']) + def manual(): + csrf.validate() + return 'bar' + + def test_skips_validation(self): + with self.app.test_client() as c: + rv = c.post('/foo/quz') + self.assertIn(b('bar'), rv.data) + cookie = self.getCookie(rv, self.csrf._csrf_name) + token = self.csrf._get_token() + self.assertEqual(cookie, token) + + def test_enforces_validation_reject(self): + with self.app.test_client() as c: + rv = c.delete('/foo/baz') + self.assertIn(b('403 Forbidden'), rv.data) + + def test_enforces_validation_accept(self): + with self.app.test_client() as c: + # GET generates CSRF token + c.get('/foo/baz') + rv = c.delete('/foo/baz', + headers={'X-CSRFToken': self.csrf._get_token()}) + self.assertIn(b('bar'), rv.data) + + def test_manual_validation(self): + with self.app.test_client() as c: + rv = c.post('/manual') + self.assertIn(b('403 Forbidden'), rv.data) + + def getCookie(self, response, cookie_name): + cookies = response.headers.getlist('Set-Cookie') + for cookie in cookies: + key, value = list(parse_cookie(cookie).items())[0] + if key == cookie_name: + return value + return None + + class SeaSurfTestManualValidation(unittest.TestCase): def setUp(self): app = Flask(__name__) @@ -674,6 +749,7 @@ def suite(): suite.addTest(unittest.makeSuite(SeaSurfTestCaseIncludeViews)) suite.addTest(unittest.makeSuite(SeaSurfTestCaseExemptUrls)) suite.addTest(unittest.makeSuite(SeaSurfTestCaseDisableCookie)) + suite.addTest(unittest.makeSuite(SeaSurfTestCaseSkipValidation)) suite.addTest(unittest.makeSuite(SeaSurfTestCaseSave)) suite.addTest(unittest.makeSuite(SeaSurfTestCaseSetCookie)) suite.addTest(unittest.makeSuite(SeaSurfTestCaseReferer))