Skip to content

Commit

Permalink
Merge pull request #169 from nfvs/separate_csrf_protect
Browse files Browse the repository at this point in the history
Abstract _csrf_protect() into a separate method.
  • Loading branch information
lepture committed Feb 15, 2015
2 parents 16f8c9e + e85808b commit d0bb430
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 28 deletions.
67 changes: 39 additions & 28 deletions flask_wtf/csrf.py
Expand Up @@ -138,30 +138,16 @@ def __init__(self, app=None):
self.init_app(app)

def init_app(self, app):
self._app = app
app.jinja_env.globals['csrf_token'] = generate_csrf
app.config.setdefault(
'WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token']
)
app.config.setdefault('WTF_CSRF_SSL_STRICT', True)
app.config.setdefault('WTF_CSRF_ENABLED', True)
app.config.setdefault('WTF_CSRF_CHECK_DEFAULT', True)
app.config.setdefault('WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH'])

def _get_csrf_token():
# find the ``csrf_token`` field in the subitted form
# if the form had a prefix, the name will be
# ``{prefix}-csrf_token``
for key in request.form:
if key.endswith('csrf_token'):
csrf_token = request.form[key]
if csrf_token:
return csrf_token

for header_name in app.config['WTF_CSRF_HEADERS']:
csrf_token = request.headers.get(header_name)
if csrf_token:
return csrf_token
return None

# expose csrf_token as a helper in all templates
@app.context_processor
def csrf_token():
Expand All @@ -173,6 +159,9 @@ def _csrf_protect():
if not app.config['WTF_CSRF_ENABLED']:
return

if not app.config['WTF_CSRF_CHECK_DEFAULT']:
return

if request.method not in app.config['WTF_CSRF_METHODS']:
return

Expand All @@ -190,21 +179,43 @@ def _csrf_protect():
if request.blueprint in self._exempt_blueprints:
return

if not validate_csrf(_get_csrf_token()):
reason = 'CSRF token missing or incorrect.'
return self._error_response(reason)
self.protect()

def _get_csrf_token(self):
# find the ``csrf_token`` field in the subitted form
# if the form had a prefix, the name will be
# ``{prefix}-csrf_token``
for key in request.form:
if key.endswith('csrf_token'):
csrf_token = request.form[key]
if csrf_token:
return csrf_token

for header_name in self._app.config['WTF_CSRF_HEADERS']:
csrf_token = request.headers.get(header_name)
if csrf_token:
return csrf_token
return None

def protect(self):
if request.method not in self._app.config['WTF_CSRF_METHODS']:
return

if request.is_secure and app.config['WTF_CSRF_SSL_STRICT']:
if not request.referrer:
reason = 'Referrer checking failed - no Referrer.'
return self._error_response(reason)
if not validate_csrf(self._get_csrf_token()):
reason = 'CSRF token missing or incorrect.'
return self._error_response(reason)

good_referrer = 'https://%s/' % request.host
if not same_origin(request.referrer, good_referrer):
reason = 'Referrer checking failed - origin not match.'
return self._error_response(reason)
if request.is_secure and self._app.config['WTF_CSRF_SSL_STRICT']:
if not request.referrer:
reason = 'Referrer checking failed - no Referrer.'
return self._error_response(reason)

good_referrer = 'https://%s/' % request.host
if not same_origin(request.referrer, good_referrer):
reason = 'Referrer checking failed - origin does not match.'
return self._error_response(reason)

request.csrf_valid = True # mark this request is csrf valid
request.csrf_valid = True # mark this request is csrf valid

def exempt(self, view):
"""A decorator that can exclude a view from csrf protection.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_csrf.py
Expand Up @@ -38,6 +38,12 @@ def csrf_exempt():
"index.html", form=form, name=name
)

@csrf.exempt
@app.route('/csrf-protect-method', methods=['GET', 'POST'])
def csrf_protect_method():
csrf.protect()
return 'protected'

bp = Blueprint('csrf', __name__)

@bp.route('/foo', methods=['GET', 'POST'])
Expand Down Expand Up @@ -170,6 +176,27 @@ def test_valid_secure_csrf(self):
)
assert response.status_code == 200

def test_valid_csrf_method(self):
response = self.client.get("/")
csrf_token = get_csrf_token(response.data)

response = self.client.post("/csrf-protect-method", data={
"csrf_token": csrf_token
})
assert response.status_code == 200

def test_invalid_csrf_method(self):
response = self.client.post("/csrf-protect-method", data={"name": "danny"})
assert response.status_code == 400

@self.csrf.error_handler
def invalid(reason):
return reason

response = self.client.post("/", data={"name": "danny"})
assert response.status_code == 200
assert b'token missing' in response.data

def test_empty_csrf_headers(self):
response = self.client.get("/", base_url='https://localhost/')
csrf_token = get_csrf_token(response.data)
Expand Down

0 comments on commit d0bb430

Please sign in to comment.