Skip to content

Commit

Permalink
Unicode decoding crash'n'burn style, default off, #53, #68.
Browse files Browse the repository at this point in the history
  • Loading branch information
ib-lundgren committed Nov 19, 2012
1 parent 20e71a4 commit 8c65b1e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
18 changes: 17 additions & 1 deletion oauthlib/common.py
Expand Up @@ -236,6 +236,7 @@ def safe_string_equals(a, b):
result |= ord(x) ^ ord(y)
return result == 0


class Request(object):
"""A malleable representation of a signable HTTP request.

Expand All @@ -250,7 +251,22 @@ class Request(object):
unmolested.
"""

def __init__(self, uri, http_method='GET', body=None, headers=None):
def __init__(self, uri, http_method='GET', body=None, headers=None,
convert_to_unicode=False, encoding='utf-8'):
if convert_to_unicode:
if isinstance(uri, bytes_type):
uri = uri.decode(encoding)
if isinstance(http_method, bytes_type):
http_method = http_method.decode(encoding)
if isinstance(body, bytes_type):
body = body.decode(encoding)
unicode_headers = {}
for k, v in headers.items():
k = k.decode(encoding) if isinstance(k, bytes_type) else k
v = v.decode(encoding) if isinstance(v, bytes_type) else v
unicode_headers[k] = v
headers = unicode_headers

self.uri = uri
self.http_method = http_method
self.headers = headers or {}
Expand Down
37 changes: 35 additions & 2 deletions oauthlib/oauth1/rfc5849/__init__.py
Expand Up @@ -10,12 +10,18 @@
"""

import logging
import sys
import time
try:
import urlparse
except ImportError:
import urllib.parse as urlparse

if sys.version_info[0] == 3:
bytes_type = bytes
else:
bytes_type = str

from oauthlib.common import Request, urlencode, generate_nonce
from oauthlib.common import generate_timestamp
from . import parameters, signature, utils
Expand All @@ -41,7 +47,30 @@ def __init__(self, client_key,
callback_uri=None,
signature_method=SIGNATURE_HMAC,
signature_type=SIGNATURE_TYPE_AUTH_HEADER,
rsa_key=None, verifier=None, realm=None):
rsa_key=None, verifier=None, realm=None,
convert_to_unicode=False, encoding='utf-8'):
if convert_to_unicode:
if isinstance(client_key, bytes_type):
client_key = client_key.decode(encoding)
if isinstance(client_secret, bytes_type):
client_secret = client_secret.decode(encoding)
if isinstance(resource_owner, bytes_type):
resource_owner = resource_owner.decode(encoding)
if isinstance(resource_owner_secret, bytes_type):
resource_owner_secret = resource_owner_secret.decode(encoding)
if isinstance(callback_uri, bytes_type):
callback_uri = callback_uri.decode(encoding)
if isinstance(signature_method, bytes_type):
signature_method = signature_method.decode(encoding)
if isinstance(signature_type, bytes_type):
signature_type = signature_type.decode(encoding)
if isinstance(rsa_key, bytes_type):
rsa_key = rsa_key.decode(encoding)
if isinstance(verifier, bytes_type):
verifier = verifier.decode(encoding)
if isinstance(realm, bytes_type):
realm = realm.decode(encoding)

self.client_key = client_key
self.client_secret = client_secret
self.resource_owner_key = resource_owner_key
Expand All @@ -52,6 +81,8 @@ def __init__(self, client_key,
self.rsa_key = rsa_key
self.verifier = verifier
self.realm = realm
self.convert_to_unicode = convert_to_unicode
self.encoding = encoding

if self.signature_method == SIGNATURE_RSA and self.rsa_key is None:
raise ValueError('rsa_key is required when using RSA signature method.')
Expand Down Expand Up @@ -172,7 +203,9 @@ def sign(self, uri, http_method='GET', body=None, headers=None, realm=None):
dicts, for example.
"""
# normalize request data
request = Request(uri, http_method, body, headers)
request = Request(uri, http_method, body, headers,
convert_to_unicode=self.convert_to_unicode,
encoding=self.encoding)

# sanity check
content_type = request.headers.get('Content-Type', None)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_common.py
@@ -1,8 +1,13 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
import sys
from oauthlib.common import *
from .unittest import TestCase

if sys.version_info[0] == 3:
bytes_type = bytes
else:
bytes_type = lambda s, e: str(s)

class CommonTests(TestCase):
params_dict = {'foo': 'bar', 'baz': '123', }
Expand Down Expand Up @@ -47,6 +52,18 @@ def test_extract_non_formencoded_string(self):
def test_extract_invalid(self):
self.assertEqual(extract_params(object()), None)

def test_non_unicode_params(self):
r = Request(bytes_type('http://a.b/path?query', 'utf-8'),
http_method=bytes_type('GET', 'utf-8'),
body=bytes_type('you=shall+pass', 'utf-8'),
headers={bytes_type('a', 'utf-8'): bytes_type('b', 'utf-8')},
convert_to_unicode=True)
self.assertEqual(r.uri, 'http://a.b/path?query')
self.assertEqual(r.http_method, 'GET')
self.assertEqual(r.body, 'you=shall+pass')
self.assertEqual(r.decoded_body, [('you', 'shall pass')])
self.assertEqual(r.headers, {'a': 'b'})

def test_none_body(self):
r = Request(self.uri)
self.assertEqual(r.decoded_body, None)
Expand Down

0 comments on commit 8c65b1e

Please sign in to comment.