diff --git a/flask_oauthlib/utils.py b/flask_oauthlib/utils.py index 23cc14d7..f5921309 100644 --- a/flask_oauthlib/utils.py +++ b/flask_oauthlib/utils.py @@ -5,9 +5,22 @@ from oauthlib.common import to_unicode, bytes_type +def _get_uri_from_request(request): + """ + The uri returned from request.uri is not properly urlencoded + (sometimes it's partially urldecoded) This is a weird hack to get + werkzeug to return the proper urlencoded string uri + """ + uri = request.base_url + if request.query_string: + uri += '?' + request.query_string.decode('utf-8') + return uri + + def extract_params(): """Extract request params.""" - uri = request.url + + uri = _get_uri_from_request(request) http_method = request.method headers = dict(request.headers) if 'wsgi.input' in headers: diff --git a/tests/oauth1/test_oauth1.py b/tests/oauth1/test_oauth1.py index d2ab0bf8..ed7d0fe0 100644 --- a/tests/oauth1/test_oauth1.py +++ b/tests/oauth1/test_oauth1.py @@ -95,12 +95,6 @@ def test_invalid_request_token(self): }) assert 'error' in rv.location - def test_invalid_urlencoded(self): - rv = self.client.get('/oauth/request_token?query=tam%20q') - assert b'non+urlencoded' in rv.data - rv = self.client.get('/oauth/access_token?query=tam%20q') - assert b'non+urlencoded' in rv.data - auth_header = ( u'OAuth realm="%(realm)s",' diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..68299ebd --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,45 @@ +import unittest +import wsgiref.util +from contextlib import contextmanager +import mock +import werkzeug.wrappers +from flask_oauthlib.utils import extract_params +from oauthlib.common import Request +from flask import request + + +@contextmanager +def set_flask_request(wsgi_environ): + """ + Test helper context manager that mocks the flask request global I didn't + need the whole request context just to test the functions in helpers and I + wanted to be able to set the raw WSGI environment + """ + environ = {} + environ.update(wsgi_environ) + wsgiref.util.setup_testing_defaults(environ) + r = werkzeug.wrappers.Request(environ) + + with mock.patch.dict(extract_params.__globals__, {'request': r}): + yield + + +class UtilsTestSuite(unittest.TestCase): + + def test_extract_params(self): + with set_flask_request({'QUERY_STRING': 'test=foo&foo=bar'}): + uri, http_method, body, headers = extract_params() + self.assertEquals(uri, 'http://127.0.0.1/?test=foo&foo=bar') + self.assertEquals(http_method, 'GET') + self.assertEquals(body, {}) + self.assertEquals(headers, {'Host': '127.0.0.1'}) + + def test_extract_params_with_urlencoded_json(self): + wsgi_environ = { + 'QUERY_STRING': 'state=%7B%22t%22%3A%22a%22%2C%22i%22%3A%22l%22%7D' + } + with set_flask_request(wsgi_environ): + uri, http_method, body, headers = extract_params() + # Request constructor will try to urldecode the querystring, make + # sure this doesn't fail. + Request(uri, http_method, body, headers)