Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merge pull request #70 from litl/0.5.1

0.5.1
  • Loading branch information...
commit 31f38abcbf336d234ecd05d177395a7223710fa5 2 parents 682a91d + d83ef04
@maxcountryman maxcountryman authored
View
1  .gitignore
@@ -1,4 +1,5 @@
docs/_build/*
+tests_output/*
dist/*
*.egg-info
*.pyc
View
4 CHANGELOG
@@ -1,3 +1,7 @@
+Changes in Version 0.5.1
+
+ * BUGFIX Added CaseInsensitiveDict to ensure headers are properly updated
+
Changes in Version 0.5.0
* Added CHANGELOG
View
6 rauth/__init__.py
@@ -10,14 +10,14 @@
...
- >>> service.get_access_token(code='...')
- >>> r = service.get('resource')
+ >>> session = service.get_auth_session(...)
+ >>> r = session.get('resource')
>>> print r.json
'''
__title__ = 'rauth'
-__version_info__ = ('0', '5', '0')
+__version_info__ = ('0', '5', '1')
__version__ = '.'.join(__version_info__)
__author__ = 'Max Countryman'
__license__ = 'MIT'
View
6 rauth/session.py
@@ -14,8 +14,9 @@
from urlparse import parse_qsl, urljoin, urlsplit
from rauth.oauth import HmacSha1Signature
-from rauth.utils import (absolute_url, ENTITY_METHODS, FORM_URLENCODED,
- get_sorted_params, OPTIONAL_OAUTH_PARAMS)
+from rauth.utils import (absolute_url, CaseInsensitiveDict, ENTITY_METHODS,
+ FORM_URLENCODED, get_sorted_params,
+ OPTIONAL_OAUTH_PARAMS)
from requests.sessions import Session
@@ -132,6 +133,7 @@ def request(self,
:type \*\*req_kwargs: dict
'''
req_kwargs.setdefault('headers', {})
+ req_kwargs['headers'] = CaseInsensitiveDict(req_kwargs['headers'])
url = self._set_url(url)
View
35 rauth/utils.py
@@ -8,6 +8,8 @@
from urlparse import parse_qsl
+from requests.structures import CaseInsensitiveDict as cidict
+
FORM_URLENCODED = 'application/x-www-form-urlencoded'
ENTITY_METHODS = ('POST', 'PUT', 'PATCH')
OPTIONAL_OAUTH_PARAMS = ('oauth_callback', 'oauth_verifier', 'oauth_version')
@@ -38,3 +40,36 @@ def sorting_gen():
for k in sorted(params.keys()):
yield '='.join((k, params[k]))
return '&'.join(sorting_gen())
+
+
+class CaseInsensitiveDict(cidict):
+ def __init__(self, d=None):
+ lowered_d = {}
+
+ if d is not None:
+ if isinstance(d, dict):
+ lowered_d = self._get_lowered_d(d)
+ elif isinstance(d, list):
+ return self.__init__(dict(d))
+
+ return super(CaseInsensitiveDict, self).__init__(lowered_d)
+
+ def _get_lowered_d(self, d):
+ lowered_d = {}
+ for key in d:
+ if isinstance(key, basestring):
+ lowered_d[key.lower()] = d[key]
+ else: # pragma: no cover
+ lowered_d[key] = d[key]
+ return lowered_d
+
+ def setdefault(self, key, default):
+ if isinstance(key, basestring):
+ key = key.lower()
+
+ super(CaseInsensitiveDict, self).setdefault(key, default)
+ self._clear_lower_keys()
+
+ def update(self, d):
+ super(CaseInsensitiveDict, self).update(self._get_lowered_d(d))
+ self._clear_lower_keys()
View
3  tests/test_service_oauth1.py
@@ -11,7 +11,7 @@
from rauth.service import OAuth1Service
from rauth.session import OAUTH1_DEFAULT_TIMEOUT, OAuth1Session
-from rauth.utils import ENTITY_METHODS, FORM_URLENCODED
+from rauth.utils import CaseInsensitiveDict, ENTITY_METHODS, FORM_URLENCODED
from copy import deepcopy
from hashlib import sha1
@@ -114,6 +114,7 @@ def fake_request(self,
kwargs['data'] = dict(parse_qsl(kwargs['data']))
kwargs.setdefault('headers', {})
+ kwargs['headers'] = CaseInsensitiveDict(kwargs['headers'])
oauth_params = {'oauth_consumer_key': session.consumer_key,
'oauth_nonce': fake_nonce,
View
22 tests/test_utils.py
@@ -7,7 +7,7 @@
'''
from base import RauthTestCase
-from rauth.utils import absolute_url, parse_utf8_qsl
+from rauth.utils import absolute_url, CaseInsensitiveDict, parse_utf8_qsl
class UtilsTestCase(RauthTestCase):
@@ -27,3 +27,23 @@ def test_parse_utf8_qsl(self):
def test_both_kv_unicode(self):
d = parse_utf8_qsl(u'fü=bar&rauth=über')
self.assertEqual(d, {u'rauth': u'\xfcber', u'f\xfc': u'bar'})
+
+ def test_rauth_case_insensitive_dict(self):
+ d = CaseInsensitiveDict()
+ d.setdefault('Content-Type', 'foo')
+
+ d.update({'content-type': 'bar'})
+
+ self.assertEqual(1, len(d.keys()))
+ self.assertIn('content-type', d.keys())
+ self.assertEqual({'content-type': 'bar'}, d)
+
+ d.update({'CONTENT-TYPE': 'baz'})
+
+ self.assertEqual(1, len(d.keys()))
+ self.assertIn('content-type', d.keys())
+ self.assertEqual({'content-type': 'baz'}, d)
+
+ def test_rauth_case_insensitive_dict_list_of_tuples(self):
+ d = CaseInsensitiveDict([('Content-Type', 'foo')])
+ self.assertEqual(d, {'content-type': 'foo'})
Please sign in to comment.
Something went wrong with that request. Please try again.