Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions kmsauth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import datetime
import base64
import os
import sys
import copy

from botocore.vendored import six
Expand All @@ -16,7 +15,22 @@

TOKEN_SKEW = 3
TIME_FORMAT = "%Y%m%dT%H%M%SZ"
PY2 = sys.version[0] == '2'


def ensure_text(str_or_bytes, encoding='utf-8'):
"""Ensures an input is a string, decoding if it is bytes.
"""
if not isinstance(str_or_bytes, six.text_type):
return str_or_bytes.decode(encoding)
return str_or_bytes


def ensure_bytes(str_or_bytes, encoding='utf-8', errors='strict'):
"""Ensures an input is bytes, encoding if it is a string.
"""
if isinstance(str_or_bytes, six.text_type):
return str_or_bytes.encode(encoding, errors)
return str_or_bytes


class KMSTokenValidator(object):
Expand Down Expand Up @@ -205,12 +219,8 @@ def decrypt_token(self, username, token):
version < self.minimum_token_version):
raise TokenValidationError('Unacceptable token version.')
try:
if PY2:
token_bytes = bytes(token)
else:
token_bytes = bytes(token, 'utf8')
token_key = '{0}{1}{2}{3}'.format(
hashlib.sha256(token_bytes).hexdigest(),
hashlib.sha256(ensure_bytes(token)).hexdigest(),
_from,
self.to_auth_context,
user_type
Expand Down Expand Up @@ -418,7 +428,7 @@ def _cache_token(self, token, not_after):
os.makedirs(cachedir)
with open(self.token_cache_file, 'w') as f:
json.dump({
'token': token,
'token': ensure_text(token),
'not_after': not_after,
'auth_context': self.auth_context
}, f)
Expand Down Expand Up @@ -470,11 +480,7 @@ def get_token(self):
Plaintext=payload,
EncryptionContext=self.auth_context
)['CiphertextBlob']
if PY2:
token_bytes = bytes(token)
else:
token_bytes = bytes(token, 'utf8')
token = base64.b64encode(token_bytes)
token = base64.b64encode(ensure_bytes(token))
except (ConnectionError, EndpointConnectionError) as e:
logging.exception('Failure connecting to AWS: {}'.format(str(e)))
raise ServiceConnectionError()
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ flake8==2.3.0
# Measures code coverage and emits coverage reports
# Licence: BSD
# Upstream url: https://pypi.python.org/pypi/coverage
coverage==3.7.1
coverage==4.4.2

# tool to check your Python code against some of the style conventions
# License: Expat License
Expand All @@ -22,7 +22,7 @@ pep8==1.5.7
# nose makes testing easier
# License: GNU Library or Lesser General Public License (LGPL)
# Upstream url: http://readthedocs.org/docs/nose
nose==1.3.3
nose==1.3.7

# Mocking and Patching Library for Testing
# License: BSD
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/kmsauth/kmsauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def test_get_username(self):
def test_get_token(self, boto_mock):
kms_mock = MagicMock()
kms_mock.encrypt = MagicMock(
return_value={'CiphertextBlob': 'encrypted'}
return_value={'CiphertextBlob': b'encrypted'}
)
boto_mock.return_value = kms_mock
client = kmsauth.KMSTokenGenerator(
Expand Down