Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Unicode decoding crash'n'burn style, default off, #53, #68.

  • Loading branch information...
commit 8c65b1eae1e0586f6fb36423d0d9f03ec9115c10 1 parent 20e71a4
@ib-lundgren ib-lundgren authored
View
18 oauthlib/common.py
@@ -236,6 +236,7 @@ def safe_string_equals(a, b):
result |= ord(x) ^ ord(y)
return result == 0
+
class Request(object):
"""A malleable representation of a signable HTTP request.
@@ -250,7 +251,22 @@ class Request(object):
unmolested.
"""
- def __init__(self, uri, http_method='GET', body=None, headers=None):
+ def __init__(self, uri, http_method='GET', body=None, headers=None,
+ convert_to_unicode=False, encoding='utf-8'):
+ if convert_to_unicode:
+ if isinstance(uri, bytes_type):
+ uri = uri.decode(encoding)
+ if isinstance(http_method, bytes_type):
+ http_method = http_method.decode(encoding)
+ if isinstance(body, bytes_type):
+ body = body.decode(encoding)
+ unicode_headers = {}
+ for k, v in headers.items():
+ k = k.decode(encoding) if isinstance(k, bytes_type) else k
+ v = v.decode(encoding) if isinstance(v, bytes_type) else v
+ unicode_headers[k] = v
+ headers = unicode_headers
+
self.uri = uri
self.http_method = http_method
self.headers = headers or {}
View
37 oauthlib/oauth1/rfc5849/__init__.py
@@ -10,12 +10,18 @@
"""
import logging
+import sys
import time
try:
import urlparse
except ImportError:
import urllib.parse as urlparse
+if sys.version_info[0] == 3:
+ bytes_type = bytes
+else:
+ bytes_type = str
+
from oauthlib.common import Request, urlencode, generate_nonce
from oauthlib.common import generate_timestamp
from . import parameters, signature, utils
@@ -41,7 +47,30 @@ def __init__(self, client_key,
callback_uri=None,
signature_method=SIGNATURE_HMAC,
signature_type=SIGNATURE_TYPE_AUTH_HEADER,
- rsa_key=None, verifier=None, realm=None):
+ rsa_key=None, verifier=None, realm=None,
+ convert_to_unicode=False, encoding='utf-8'):
+ if convert_to_unicode:
+ if isinstance(client_key, bytes_type):
+ client_key = client_key.decode(encoding)
+ if isinstance(client_secret, bytes_type):
+ client_secret = client_secret.decode(encoding)
+ if isinstance(resource_owner, bytes_type):
+ resource_owner = resource_owner.decode(encoding)
+ if isinstance(resource_owner_secret, bytes_type):
+ resource_owner_secret = resource_owner_secret.decode(encoding)
+ if isinstance(callback_uri, bytes_type):
+ callback_uri = callback_uri.decode(encoding)
+ if isinstance(signature_method, bytes_type):
+ signature_method = signature_method.decode(encoding)
+ if isinstance(signature_type, bytes_type):
+ signature_type = signature_type.decode(encoding)
+ if isinstance(rsa_key, bytes_type):
+ rsa_key = rsa_key.decode(encoding)
+ if isinstance(verifier, bytes_type):
+ verifier = verifier.decode(encoding)
+ if isinstance(realm, bytes_type):
+ realm = realm.decode(encoding)
+
self.client_key = client_key
self.client_secret = client_secret
self.resource_owner_key = resource_owner_key
@@ -52,6 +81,8 @@ def __init__(self, client_key,
self.rsa_key = rsa_key
self.verifier = verifier
self.realm = realm
+ self.convert_to_unicode = convert_to_unicode
+ self.encoding = encoding
if self.signature_method == SIGNATURE_RSA and self.rsa_key is None:
raise ValueError('rsa_key is required when using RSA signature method.')
@@ -172,7 +203,9 @@ def sign(self, uri, http_method='GET', body=None, headers=None, realm=None):
dicts, for example.
"""
# normalize request data
- request = Request(uri, http_method, body, headers)
+ request = Request(uri, http_method, body, headers,
+ convert_to_unicode=self.convert_to_unicode,
+ encoding=self.encoding)
# sanity check
content_type = request.headers.get('Content-Type', None)
View
17 tests/test_common.py
@@ -1,8 +1,13 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import sys
from oauthlib.common import *
from .unittest import TestCase
+if sys.version_info[0] == 3:
+ bytes_type = bytes
+else:
+ bytes_type = lambda s, e: str(s)
class CommonTests(TestCase):
params_dict = {'foo': 'bar', 'baz': '123', }
@@ -47,6 +52,18 @@ def test_extract_non_formencoded_string(self):
def test_extract_invalid(self):
self.assertEqual(extract_params(object()), None)
+ def test_non_unicode_params(self):
+ r = Request(bytes_type('http://a.b/path?query', 'utf-8'),
+ http_method=bytes_type('GET', 'utf-8'),
+ body=bytes_type('you=shall+pass', 'utf-8'),
+ headers={bytes_type('a', 'utf-8'): bytes_type('b', 'utf-8')},
+ convert_to_unicode=True)
+ self.assertEqual(r.uri, 'http://a.b/path?query')
+ self.assertEqual(r.http_method, 'GET')
+ self.assertEqual(r.body, 'you=shall+pass')
+ self.assertEqual(r.decoded_body, [('you', 'shall pass')])
+ self.assertEqual(r.headers, {'a': 'b'})
+
def test_none_body(self):
r = Request(self.uri)
self.assertEqual(r.decoded_body, None)
Please sign in to comment.
Something went wrong with that request. Please try again.