From 6d5256c7ba385400fe980c26b8b4acd4eb608a68 Mon Sep 17 00:00:00 2001 From: Joe Stump Date: Mon, 12 Oct 2009 17:45:39 -0600 Subject: [PATCH] Added a multitude of tests. --- oauth/__init__.py | 94 ++++++++-------- tests/test_oauth.py | 261 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 307 insertions(+), 48 deletions(-) diff --git a/oauth/__init__.py b/oauth/__init__.py index 3425011..a055f44 100644 --- a/oauth/__init__.py +++ b/oauth/__init__.py @@ -58,14 +58,6 @@ def escape(s): return urllib.quote(s, safe='~') -def _utf8_str(s): - """Convert unicode to utf-8.""" - if isinstance(s, unicode): - return s.encode("utf-8") - else: - return str(s) - - def generate_timestamp(): """Get seconds since epoch (UTC).""" return int(time.time()) @@ -280,7 +272,7 @@ def get_nonoauth_parameters(self): def to_header(self, realm=''): """Serialize as a header for an HTTPAuth request.""" - oauth_params = ((k, v) for k, v in self.iteritems() + oauth_params = ((k, v) for k, v in self.items() if k.startswith('oauth_')) stringy_params = ((k, escape(str(v))) for k, v in oauth_params) header_params = ('%s="%s"' % (k, v) for k, v in stringy_params) @@ -288,7 +280,7 @@ def to_header(self, realm=''): auth_header = 'OAuth realm="%s"' % realm if params_header: - auth_header += params_header + auth_header = "%s, %s" % (auth_header, params_header) return {'Authorization': auth_header} @@ -306,7 +298,7 @@ def get_normalized_parameters(self): return urllib.urlencode(sorted(items)) def sign_request(self, signature_method, consumer, token): - """Set the signature parameter to the result of build_signature.""" + """Set the signature parameter to the result of sign.""" self['oauth_signature_method'] = signature_method.name self['oauth_signature'] = signature_method.sign(self, consumer, token) @@ -338,7 +330,7 @@ def from_request(cls, http_method, http_url, headers=None, parameters=None, header_params = cls._split_header(auth_header) parameters.update(header_params) except: - raise OAuthError('Unable to parse OAuth parameters from ' + raise Error('Unable to parse OAuth parameters from ' 'Authorization header.') # GET or POST query string. @@ -375,12 +367,12 @@ def from_consumer_and_token(cls, oauth_consumer, token=None, if token: parameters['oauth_token'] = token.key - return OAuthRequest(http_method, http_url, parameters) + return Request(http_method, http_url, parameters) @classmethod def from_token_and_callback(cls, token, callback=None, - http_method=HTTP_METHOD, - http_url=None, parameters=None): + http_method=HTTP_METHOD, http_url=None, parameters=None): + if not parameters: parameters = {} @@ -448,7 +440,7 @@ def get_data_store(self): return self.data_store def add_signature_method(self, signature_method): - self.signature_methods[signature_method.get_name()] = signature_method + self.signature_methods[signature_method.name] = signature_method return self.signature_methods def fetch_request_token(self, oauth_request): @@ -573,13 +565,13 @@ def _check_signature(self, oauth_request, consumer, token): token, signature) if not valid_sig: - key, base = signature_method.build_signature_base_string( + key, base = signature_method.signing_base( oauth_request, consumer, token) raise Error('Invalid signature. Expected signature base ' 'string: %s' % base) - built = signature_method.build_signature(oauth_request, + built = signature_method.sign(oauth_request, consumer, token) def _check_timestamp(self, timestamp): @@ -672,34 +664,45 @@ class SignatureMethod(object): provide a new way to sign requests. """ - def get_name(self): - """-> str.""" - raise NotImplementedError + def signing_base(self, request, consumer, token): + """Calculates the string that needs to be signed. + + This method returns a 2-tuple containing the starting key for the + signing and the message to be signed. The latter may be used in error + messages to help clients debug their software. - def build_signature_base_string(self, oauth_request, - oauth_consumer, oauth_token): - """-> str key, str raw.""" + """ raise NotImplementedError - def build_signature(self, oauth_request, oauth_consumer, oauth_token): - """-> str.""" + def sign(self, request, consumer, token): + """Returns the signature for the given request, based on the consumer + and token also provided. + + You should use your implementation of `signing_base()` to build the + message to sign. Otherwise it may be less useful for debugging. + + """ raise NotImplementedError - def check_signature(self, oauth_request, consumer, token, signature): - built = self.build_signature(oauth_request, consumer, token) + def check(self, request, consumer, token, signature): + """Returns whether the given signature is the correct signature for + the given consumer and token signing the given request.""" + built = self.sign(request, consumer, token) return built == signature + build_signature_base_string = signing_base + build_signature = sign + check_signature = check + class SignatureMethod_HMAC_SHA1(SignatureMethod): - - def get_name(self): - return 'HMAC-SHA1' + name = 'HMAC-SHA1' - def build_signature_base_string(self, oauth_request, consumer, token): + def signing_base(self, request, consumer, token): sig = ( - escape(oauth_request.get_normalized_http_method()), - escape(oauth_request.get_normalized_http_url()), - escape(oauth_request.get_normalized_parameters()), + escape(request.method), + escape(request.url), + escape(request.get_normalized_parameters()), ) key = '%s&' % escape(consumer.secret) @@ -708,10 +711,9 @@ def build_signature_base_string(self, oauth_request, consumer, token): raw = '&'.join(sig) return key, raw - def build_signature(self, oauth_request, consumer, token): + def sign(self, request, consumer, token): """Builds the base signature string.""" - key, raw = self.build_signature_base_string(oauth_request, consumer, - token) + key, raw = self.signing_base(request, consumer, token) # HMAC object. try: @@ -724,23 +726,21 @@ def build_signature(self, oauth_request, consumer, token): # Calculate the digest base 64. return binascii.b2a_base64(hashed.digest())[:-1] - class SignatureMethod_PLAINTEXT(SignatureMethod): - def get_name(self): - return 'PLAINTEXT' + name = 'PLAINTEXT' - def build_signature_base_string(self, oauth_request, consumer, token): - """Concatenates the consumer key and secret.""" + def signing_base(self, request, consumer, token): + """Concatenates the consumer key and secret with the token's + secret.""" sig = '%s&' % escape(consumer.secret) if token: sig = sig + escape(token.secret) return sig, sig - def build_signature(self, oauth_request, consumer, token): - key, raw = self.build_signature_base_string(oauth_request, consumer, - token) - return key + def sign(self, request, consumer, token): + key, raw = self.signing_base(request, consumer, token) + return raw # Backwards compatibility OAuthError = Error diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 50b1feb..d8fc7df 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -24,6 +24,10 @@ import unittest import oauth +import time +import urllib +import urlparse +import cgi class TestError(unittest.TestCase): def test_message(self): @@ -66,6 +70,11 @@ def test_gen_verifier(self): verifier = oauth.generate_verifier(16) self.assertEqual(len(verifier), 16) + def test_gen_timestamp(self): + exp = int(time.time()) + now = oauth.generate_timestamp() + self.assertEqual(exp, now) + class TestConsumer(unittest.TestCase): def test_init(self): key = 'my-key' @@ -188,7 +197,257 @@ def test_from_string(self): self._compare_tokens(new) class TestRequest(unittest.TestCase): - pass + def test_setter(self): + url = "http://example.com" + method = "GET" + req = oauth.Request(method) + + try: + url = req.url + self.fail("AttributeError should have been raised on empty url.") + except AttributeError: + pass + except Exception, e: + self.fail(str(e)) + + def test_deleter(self): + url = "http://example.com" + method = "GET" + req = oauth.Request(method, url) + + try: + del req.url + url = req.url + self.fail("AttributeError should have been raised on empty url.") + except AttributeError: + pass + except Exception, e: + self.fail(str(e)) + + def test_url(self): + url1 = "http://example.com:80/foo.php" + url2 = "https://example.com:443/foo.php" + exp1 = "http://example.com/foo.php" + exp2 = "https://example.com/foo.php" + method = "GET" + + req = oauth.Request(method, url1) + self.assertEquals(req.url, exp1) + + req = oauth.Request(method, url2) + self.assertEquals(req.url, exp2) + + def test_get_nonoauth_parameters(self): + + oauth_params = { + 'oauth_consumer': 'asdfasdfasdf' + } + + other_params = { + 'foo': 'baz', + 'bar': 'foo' + } + + params = oauth_params + params.update(other_params) + + req = oauth.Request("GET", "http://example.com", params) + self.assertEquals(other_params, req.get_nonoauth_parameters()) + + def test_to_header(self): + realm = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + req = oauth.Request("GET", realm, params) + header, value = req.to_header(realm).items()[0] + + parts = value.split('OAuth ') + vars = parts[1].split(', ') + self.assertTrue(len(vars), (len(params) + 1)) + + res = {} + for v in vars: + var, val = v.split('=') + res[var] = urllib.unquote(val.strip('"')) + + self.assertEquals(realm, res['realm']) + del res['realm'] + + self.assertTrue(len(res), len(params)) + + for key, val in res.items(): + self.assertEquals(val, params.get(key)) + + def test_to_postdata(self): + realm = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + req = oauth.Request("GET", realm, params) + + self.assertEquals(params, dict(urlparse.parse_qsl(req.to_postdata()))) + + def test_to_url(self): + url = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + req = oauth.Request("GET", url, params) + exp = urlparse.urlparse("%s?%s" % (url, urllib.urlencode(params))) + res = urlparse.urlparse(req.to_url()) + self.assertEquals(exp.scheme, res.scheme) + self.assertEquals(exp.netloc, res.netloc) + self.assertEquals(exp.path, res.path) + + a = urlparse.parse_qs(exp.query) + b = urlparse.parse_qs(res.query) + self.assertEquals(a, b) + + def test_get_normalized_parameters(self): + url = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + req = oauth.Request("GET", url, params) + + res = dict(urlparse.parse_qsl(req.get_normalized_parameters())) + + foo = params.copy() + del foo['oauth_signature'] + self.assertEquals(foo, res) + + def test_sign_request(self): + url = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + req = oauth.Request("GET", url, params) + tok = oauth.Token(key="tok-test-key", secret="tok-test-secret") + con = oauth.Consumer(key="con-test-key", secret="con-test-secret") + + methods = { + 'broken': oauth.SignatureMethod_HMAC_SHA1(), + 'another': oauth.SignatureMethod_PLAINTEXT() + } + + for exp, method in methods.items(): + req.sign_request(method, con, tok) + self.assertEquals(req['oauth_signature_method'], method.name) + self.assertEquals(req['oauth_signature'], exp) + + def test_from_request(self): + url = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + req = oauth.Request("GET", url, params) + headers = req.to_header() + + # Test from the headers + req = oauth.Request.from_request("GET", url, headers) + self.assertEquals(req.method, "GET") + self.assertEquals(req.url, url) + + self.assertEquals(params, req.copy()) + + # Test with bad OAuth headers + bad_headers = { + 'Authorization' : 'OAuth this is a bad header' + } + + self.assertRaises(oauth.Error, oauth.Request.from_request, "GET", + url, bad_headers) + + # Test getting from query string + qs = urllib.urlencode(params) + req = oauth.Request.from_request("GET", url, query_string=qs) + + exp = cgi.parse_qs(qs, keep_blank_values=False) + for k, v in exp.iteritems(): + exp[k] = urllib.unquote(v[0]) + + self.assertEquals(exp, req.copy()) + + # Test that a boned from_request() call returns None + req = oauth.Request.from_request("GET", url) + self.assertEquals(None, req) + + def test_from_consumer_and_token(self): + url = "http://sp.example.com/" + + params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_consumer_key': "0685bd9184jfhq22", + 'oauth_signature_method': "HMAC-SHA1", + 'oauth_token': "ad180jjd733klru7", + 'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D", + } + + tok = oauth.Token(key="tok-test-key", secret="tok-test-secret") + con = oauth.Consumer(key="con-test-key", secret="con-test-secret") + + req = oauth.Request.from_consumer_and_token(con) + self.assertTrue(len(req.copy()) == 4) + + req = oauth.Request.from_consumer_and_token(con, token=tok, + http_method="GET", http_url=url) + + self.assertTrue('oauth_timestamp' in req) + self.assertTrue('oauth_nonce' in req) + self.assertEquals(req['oauth_version'], '1.0') + self.assertEquals(req['oauth_consumer_key'], con.key) + self.assertEquals(req['oauth_token'], tok.key) class TestServer(unittest.TestCase): pass