Skip to content

Commit

Permalink
All tests now pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
hades committed Feb 10, 2011
1 parent 3896f6b commit b1c0d46
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 45 deletions.
17 changes: 8 additions & 9 deletions oauth2/__init__.py
Expand Up @@ -26,7 +26,6 @@
import urllib.request, urllib.parse, urllib.error
import time
import random
import urllib.parse
import hmac
import binascii
import httplib2
Expand Down Expand Up @@ -98,7 +97,7 @@ def to_unicode(s):
""" Convert to unicode, raise exception with instructive error
message if s is not unicode, ascii, or utf-8. """
if not isinstance(s, str):
if not isinstance(s, str):
if not isinstance(s, bytes):
raise TypeError('You are required to pass either unicode or string here, not: %r (%s)' % (type(s), s))
try:
s = s.decode('utf-8')
Expand All @@ -110,13 +109,13 @@ def to_utf8(s):
return to_unicode(s).encode('utf-8')

def to_unicode_if_string(s):
if isinstance(s, str):
if isinstance(s, str) or isinstance(s, bytes):
return to_unicode(s)
else:
return s

def to_utf8_if_string(s):
if isinstance(s, str):
if isinstance(s, str) or isinstance(s, bytes):
return to_utf8(s)
else:
return s
Expand All @@ -126,7 +125,7 @@ def to_unicode_optional_iterator(x):
Raise TypeError if x is a str containing non-utf8 bytes or if x is
an iterable which contains such a str.
"""
if isinstance(x, str):
if isinstance(x, str) or isinstance(x, bytes):
return to_unicode(x)

try:
Expand Down Expand Up @@ -436,7 +435,7 @@ def get_normalized_parameters(self):
continue
# 1.0a/9.1.1 states that kvp must be sorted by key, then by value,
# so we unpack sequence values into multiple items for sorting.
if isinstance(value, str):
if isinstance(value, str) or isinstance(value, bytes):
items.append((to_utf8_if_string(key), to_utf8(value)))
else:
try:
Expand Down Expand Up @@ -471,7 +470,7 @@ def sign_request(self, signature_method, consumer, token):
# section 4.1.1 "OAuth Consumers MUST NOT include an
# oauth_body_hash parameter on requests with form-encoded
# request bodies."
self['oauth_body_hash'] = base64.b64encode(sha(self.body).digest())
self['oauth_body_hash'] = base64.b64encode(sha(self.body.encode('utf8')).digest())

if 'oauth_consumer_key' not in self:
self['oauth_consumer_key'] = consumer.key
Expand Down Expand Up @@ -587,7 +586,7 @@ def _split_header(header):
@staticmethod
def _split_url_string(param_str):
"""Turn URL string into parameters."""
parameters = parse_qs(param_str.encode('utf-8'), keep_blank_values=True)
parameters = parse_qs(param_str, keep_blank_values=True)
for k, v in parameters.items():
parameters[k] = urllib.parse.unquote(v[0])
return parameters
Expand Down Expand Up @@ -818,7 +817,7 @@ def sign(self, request, consumer, token):
"""Builds the base signature string."""
key, raw = self.signing_base(request, consumer, token)

hashed = hmac.new(key, raw, sha)
hashed = hmac.new(key.encode('utf-8'), raw.encode('utf-8'), sha)

# Calculate the digest base 64.
return binascii.b2a_base64(hashed.digest())[:-1]
Expand Down
71 changes: 35 additions & 36 deletions tests/test_oauth.py
Expand Up @@ -31,7 +31,6 @@
import time
import urllib.request, urllib.parse, urllib.error
import urllib.parse
from types import ListType
import mock
import httplib2

Expand Down Expand Up @@ -257,14 +256,14 @@ def failUnlessReallyEqual(self, a, b, msg=None):

class TestFuncs(unittest.TestCase):
def test_to_unicode(self):
self.failUnlessRaises(TypeError, oauth.to_unicode, '\xae')
self.failUnlessRaises(TypeError, oauth.to_unicode_optional_iterator, '\xae')
self.failUnlessRaises(TypeError, oauth.to_unicode_optional_iterator, ['\xae'])
self.failUnlessRaises(TypeError, oauth.to_unicode, b'\xae')
self.failUnlessRaises(TypeError, oauth.to_unicode_optional_iterator, b'\xae')
self.failUnlessRaises(TypeError, oauth.to_unicode_optional_iterator, [b'\xae'])

self.failUnlessEqual(oauth.to_unicode(':-)'), ':-)')
self.failUnlessEqual(oauth.to_unicode(b':-)'), ':-)')
self.failUnlessEqual(oauth.to_unicode('\u00ae'), '\u00ae')
self.failUnlessEqual(oauth.to_unicode('\xc2\xae'), '\u00ae')
self.failUnlessEqual(oauth.to_unicode_optional_iterator([':-)']), [':-)'])
self.failUnlessEqual(oauth.to_unicode(b'\xc2\xae'), '\u00ae')
self.failUnlessEqual(oauth.to_unicode_optional_iterator([b':-)']), [':-)'])
self.failUnlessEqual(oauth.to_unicode_optional_iterator(['\u00ae']), ['\u00ae'])

class TestRequest(unittest.TestCase, ReallyEqualMixin):
Expand Down Expand Up @@ -490,25 +489,25 @@ def test_signature_base_string_nonascii_nonutf8(self):
req = oauth.Request("GET", url)
self.failUnlessReallyEqual(req.normalized_url, 'http://api.simplegeo.com/1.0/places/address.json')
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'WhufgeZKyYpKsI70GZaiDaYwl6g=')
self.failUnlessReallyEqual(req['oauth_signature'], b'WhufgeZKyYpKsI70GZaiDaYwl6g=')

url = 'http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St,+San+Francisc\xe2\x9d\xa6,+CA'
url = b'http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St,+San+Francisc\xe2\x9d\xa6,+CA'
req = oauth.Request("GET", url)
self.failUnlessReallyEqual(req.normalized_url, 'http://api.simplegeo.com/1.0/places/address.json')
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'WhufgeZKyYpKsI70GZaiDaYwl6g=')
self.failUnlessReallyEqual(req['oauth_signature'], b'WhufgeZKyYpKsI70GZaiDaYwl6g=')

url = 'http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St,+San+Francisc%E2%9D%A6,+CA'
req = oauth.Request("GET", url)
self.failUnlessReallyEqual(req.normalized_url, 'http://api.simplegeo.com/1.0/places/address.json')
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'WhufgeZKyYpKsI70GZaiDaYwl6g=')
self.failUnlessReallyEqual(req['oauth_signature'], b'WhufgeZKyYpKsI70GZaiDaYwl6g=')

url = 'http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St,+San+Francisc%E2%9D%A6,+CA'
req = oauth.Request("GET", url)
self.failUnlessReallyEqual(req.normalized_url, 'http://api.simplegeo.com/1.0/places/address.json')
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'WhufgeZKyYpKsI70GZaiDaYwl6g=')
self.failUnlessReallyEqual(req['oauth_signature'], b'WhufgeZKyYpKsI70GZaiDaYwl6g=')

def test_signature_base_string_with_query(self):
url = "https://www.google.com/m8/feeds/contacts/default/full/?alt=json&max-contacts=10"
Expand Down Expand Up @@ -598,9 +597,9 @@ def test_get_normalized_parameters(self):
'oauth_consumer_key': "0685bd9184jfhq22",
'oauth_signature_method': "HMAC-SHA1",
'oauth_token': "ad180jjd733klru7",
'multi': ['FOO','BAR', '\u00ae', '\xc2\xae'],
'multi': ['FOO','BAR', '\u00ae', b'\xc2\xae'],
'multi_same': ['FOO','FOO'],
'uni_utf8_bytes': '\xc2\xae',
'uni_utf8_bytes': b'\xc2\xae',
'uni_unicode_object': '\u00ae'
}

Expand Down Expand Up @@ -682,46 +681,46 @@ def test_request_nonutf8_bytes(self, mock_make_nonce, mock_make_timestamp):

# If someone passes a sequence of bytes which is not ascii for
# url, we'll raise an exception as early as possible.
url = "http://sp.example.com/\x92" # It's actually cp1252-encoding...
url = b"http://sp.example.com/\x92" # It's actually cp1252-encoding...
self.assertRaises(TypeError, oauth.Request, method="GET", url=url, parameters=params)

# And if they pass an unicode, then we'll use it.
url = 'http://sp.example.com/\u2019'
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'cMzvCkhvLL57+sTIxLITTHfkqZk=')
self.failUnlessReallyEqual(req['oauth_signature'], b'cMzvCkhvLL57+sTIxLITTHfkqZk=')

# And if it is a utf-8-encoded-then-percent-encoded non-ascii
# thing, we'll decode it and use it.
url = "http://sp.example.com/%E2%80%99"
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'yMLKOyNKC/DkyhUOb8DLSvceEWE=')
self.failUnlessReallyEqual(req['oauth_signature'], b'yMLKOyNKC/DkyhUOb8DLSvceEWE=')

# Same thing with the params.
url = "http://sp.example.com/"

# If someone passes a sequence of bytes which is not ascii in
# params, we'll raise an exception as early as possible.
params['non_oauth_thing'] = '\xae', # It's actually cp1252-encoding...
params['non_oauth_thing'] = b'\xae', # It's actually cp1252-encoding...
self.assertRaises(TypeError, oauth.Request, method="GET", url=url, parameters=params)

# And if they pass a unicode, then we'll use it.
params['non_oauth_thing'] = '\u2019'
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_signature'], '0GU50m0v60CVDB5JnoBXnvvvKx4=')
self.failUnlessReallyEqual(req['oauth_signature'], b'0GU50m0v60CVDB5JnoBXnvvvKx4=')

# And if it is a utf-8-encoded non-ascii thing, we'll decode
# it and use it.
params['non_oauth_thing'] = '\xc2\xae'
params['non_oauth_thing'] = b'\xc2\xae'
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_signature'], 'pqOCu4qvRTiGiXB8Z61Jsey0pMM=')
self.failUnlessReallyEqual(req['oauth_signature'], b'pqOCu4qvRTiGiXB8Z61Jsey0pMM=')


# Also if there are non-utf8 bytes in the query args.
url = "http://sp.example.com/?q=\x92" # cp1252
url = b"http://sp.example.com/?q=\x92" # cp1252
self.assertRaises(TypeError, oauth.Request, method="GET", url=url, parameters=params)

def test_request_hash_of_body(self):
Expand All @@ -743,8 +742,8 @@ def test_request_hash_of_body(self):
url = "http://www.example.com/resource"
req = oauth.Request(method="PUT", url=url, parameters=params, body="Hello World!", is_form_encoded=False)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_body_hash'], 'Lve95gjOVATpfV8EL5X4nxwjKHE=')
self.failUnlessReallyEqual(req['oauth_signature'], 't+MX8l/0S8hdbVQL99nD0X1fPnM=')
self.failUnlessReallyEqual(req['oauth_body_hash'], b'Lve95gjOVATpfV8EL5X4nxwjKHE=')
self.failUnlessReallyEqual(req['oauth_signature'], b't+MX8l/0S8hdbVQL99nD0X1fPnM=')
# oauth-bodyhash.html A.1 has
# '08bUFF%2Fjmp59mWB7cSgCYBUpJ0U%3D', but I don't see how that
# is possible.
Expand All @@ -760,8 +759,8 @@ def test_request_hash_of_body(self):

req = oauth.Request(method="PUT", url=url, parameters=params, body="Hello World!", is_form_encoded=False)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_body_hash'], 'Lve95gjOVATpfV8EL5X4nxwjKHE=')
self.failUnlessReallyEqual(req['oauth_signature'], 'CTFmrqJIGT7NsWJ42OrujahTtTc=')
self.failUnlessReallyEqual(req['oauth_body_hash'], b'Lve95gjOVATpfV8EL5X4nxwjKHE=')
self.failUnlessReallyEqual(req['oauth_signature'], b'CTFmrqJIGT7NsWJ42OrujahTtTc=')

# Appendix A.2
params = {
Expand All @@ -774,8 +773,8 @@ def test_request_hash_of_body(self):

req = oauth.Request(method="GET", url=url, parameters=params, is_form_encoded=False)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
self.failUnlessReallyEqual(req['oauth_body_hash'], '2jmj7l5rSw0yVb/vlWAYkK/YBwk=')
self.failUnlessReallyEqual(req['oauth_signature'], 'Zhl++aWSP0O3/hYQ0CuBc7jv38I=')
self.failUnlessReallyEqual(req['oauth_body_hash'], b'2jmj7l5rSw0yVb/vlWAYkK/YBwk=')
self.failUnlessReallyEqual(req['oauth_signature'], b'Zhl++aWSP0O3/hYQ0CuBc7jv38I=')


def test_sign_request(self):
Expand All @@ -795,7 +794,7 @@ def test_sign_request(self):
req = oauth.Request(method="GET", url=url, parameters=params)

methods = {
'DX01TdHws7OninCLK9VztNTH1M4=': oauth.SignatureMethod_HMAC_SHA1(),
b'DX01TdHws7OninCLK9VztNTH1M4=': oauth.SignatureMethod_HMAC_SHA1(),
'con-test-secret&tok-test-secret': oauth.SignatureMethod_PLAINTEXT()
}

Expand All @@ -805,26 +804,26 @@ def test_sign_request(self):
self.assertEquals(req['oauth_signature'], exp)

# Also if there are non-ascii chars in the URL.
url = "http://sp.example.com/\xe2\x80\x99" # utf-8 bytes
url = b"http://sp.example.com/\xe2\x80\x99" # utf-8 bytes
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, tok)
self.assertEquals(req['oauth_signature'], 'loFvp5xC7YbOgd9exIO6TxB7H4s=')
self.assertEquals(req['oauth_signature'], b'loFvp5xC7YbOgd9exIO6TxB7H4s=')

url = 'http://sp.example.com/\u2019' # Python unicode object
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, tok)
self.assertEquals(req['oauth_signature'], 'loFvp5xC7YbOgd9exIO6TxB7H4s=')
self.assertEquals(req['oauth_signature'], b'loFvp5xC7YbOgd9exIO6TxB7H4s=')

# Also if there are non-ascii chars in the query args.
url = "http://sp.example.com/?q=\xe2\x80\x99" # utf-8 bytes
url = b"http://sp.example.com/?q=\xe2\x80\x99" # utf-8 bytes
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, tok)
self.assertEquals(req['oauth_signature'], 'IBw5mfvoCsDjgpcsVKbyvsDqQaU=')
self.assertEquals(req['oauth_signature'], b'IBw5mfvoCsDjgpcsVKbyvsDqQaU=')

url = 'http://sp.example.com/?q=\u2019' # Python unicode object
req = oauth.Request(method="GET", url=url, parameters=params)
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, tok)
self.assertEquals(req['oauth_signature'], 'IBw5mfvoCsDjgpcsVKbyvsDqQaU=')
self.assertEquals(req['oauth_signature'], b'IBw5mfvoCsDjgpcsVKbyvsDqQaU=')

def test_from_request(self):
url = "http://sp.example.com/"
Expand Down Expand Up @@ -1189,7 +1188,7 @@ def test_access_token_post(self):

self.assertEquals(int(resp['status']), 200)

res = dict(parse_qsl(content))
res = dict(parse_qsl(content.decode('utf-8')))
self.assertTrue('oauth_token' in res)
self.assertTrue('oauth_token_secret' in res)

Expand Down

0 comments on commit b1c0d46

Please sign in to comment.