Permalink
Browse files

Adding support for IAM Roles. Closes #811.

  • Loading branch information...
garnaat committed Jun 12, 2012
1 parent bc05cb1 commit 304f2ed9cd1c0c3828cb511a7a94933e5b3c3088
View
@@ -104,25 +104,29 @@ class AnonAuthHandler(AuthHandler, HmacKeys):
"""
Implements Anonymous requests.
"""
-
+
capability = ['anon']
-
+
def __init__(self, host, config, provider):
AuthHandler.__init__(self, host, config, provider)
-
+
def add_auth(self, http_request, **kwargs):
pass
class HmacAuthV1Handler(AuthHandler, HmacKeys):
""" Implements the HMAC request signing used by S3 and GS."""
-
+
capability = ['hmac-v1', 's3']
-
+
def __init__(self, host, config, provider):
AuthHandler.__init__(self, host, config, provider)
HmacKeys.__init__(self, host, config, provider)
self._hmac_256 = None
-
+
+ def update_provider(self, provider):
+ super(HmacAuthV1Handler, self).update_provider(provider)
+ self._hmac_256 = None
+
def add_auth(self, http_request, **kwargs):
headers = http_request.headers
method = http_request.method
@@ -148,12 +152,16 @@ class HmacAuthV2Handler(AuthHandler, HmacKeys):
Implements the simplified HMAC authorization used by CloudFront.
"""
capability = ['hmac-v2', 'cloudfront']
-
+
def __init__(self, host, config, provider):
AuthHandler.__init__(self, host, config, provider)
HmacKeys.__init__(self, host, config, provider)
self._hmac_256 = None
-
+
+ def update_provider(self, provider):
+ super(HmacAuthV2Handler, self).update_provider(provider)
+ self._hmac_256 = None
+
def add_auth(self, http_request, **kwargs):
headers = http_request.headers
if 'Date' not in headers:
@@ -164,16 +172,16 @@ def add_auth(self, http_request, **kwargs):
headers['Authorization'] = ("%s %s:%s" %
(auth_hdr,
self._provider.access_key, b64_hmac))
-
+
class HmacAuthV3Handler(AuthHandler, HmacKeys):
"""Implements the new Version 3 HMAC authorization used by Route53."""
-
+
capability = ['hmac-v3', 'route53', 'ses']
-
+
def __init__(self, host, config, provider):
AuthHandler.__init__(self, host, config, provider)
HmacKeys.__init__(self, host, config, provider)
-
+
def add_auth(self, http_request, **kwargs):
headers = http_request.headers
if 'Date' not in headers:
@@ -188,9 +196,9 @@ class HmacAuthV3HTTPHandler(AuthHandler, HmacKeys):
"""
Implements the new Version 3 HMAC authorization used by DynamoDB.
"""
-
+
capability = ['hmac-v3-http']
-
+
def __init__(self, host, config, provider):
AuthHandler.__init__(self, host, config, provider)
HmacKeys.__init__(self, host, config, provider)
@@ -234,7 +242,7 @@ def string_to_sign(self, http_request):
'',
http_request.body])
return string_to_sign, headers_to_sign
-
+
def add_auth(self, req, **kwargs):
"""
Add AWS3 authentication to a request.
@@ -367,7 +375,7 @@ def _calc_signature(self, params, verb, path, server_name):
class POSTPathQSV2AuthHandler(QuerySignatureV2AuthHandler, AuthHandler):
"""
Query Signature V2 Authentication relocating signed query
- into the path and allowing POST requests with Content-Types.
+ into the path and allowing POST requests with Content-Types.
"""
capability = ['mws']
@@ -401,7 +409,7 @@ def get_auth_handler(host, config, provider, requested_capability=None):
:type host: string
:param host: The name of the host
- :type config:
+ :type config:
:param config:
:type provider:
@@ -422,13 +430,13 @@ def get_auth_handler(host, config, provider, requested_capability=None):
ready_handlers.append(handler(host, config, provider))
except boto.auth_handler.NotReadyToAuthenticate:
pass
-
+
if not ready_handlers:
checked_handlers = auth_handlers
names = [handler.__name__ for handler in checked_handlers]
raise boto.exception.NoAuthHandlerFound(
'No handler was ready to authenticate. %d handlers were checked.'
- ' %s '
+ ' %s '
'Check your credentials' % (len(names), str(names)))
if len(ready_handlers) > 1:
View
@@ -55,6 +55,7 @@
import time
import urllib, urlparse
import xml.sax
+from xml.etree import ElementTree
import auth
import auth_handler
@@ -477,7 +478,8 @@ def __init__(self, host, aws_access_key_id=None, aws_secret_access_key=None,
# Allow overriding Provider
self.provider = provider
else:
- self.provider = Provider(provider,
+ self._provider_type = provider
+ self.provider = Provider(self._provider_type,
aws_access_key_id,
aws_secret_access_key,
security_token)
@@ -733,6 +735,10 @@ def _mexe(self, request, sender=None, override_num_retries=None,
num_retries = override_num_retries
i = 0
connection = self.get_http_connection(request.host, self.is_secure)
+ # The original headers/params are stored so that we can restore them
+ # if credentials are refreshed.
+ original_headers = request.headers.copy()
+ original_params = request.params.copy()
while i <= num_retries:
# Use binary exponential backoff to desynchronize client requests
next_sleep = random.random() * (2 ** i)
@@ -767,6 +773,13 @@ def _mexe(self, request, sender=None, override_num_retries=None,
msg += 'Retrying in %3.1f seconds' % next_sleep
boto.log.debug(msg)
body = response.read()
+ elif self._credentials_expired(response):
+ # The same request object is used so the security token and
+ # access key params are cleared because they are no longer
+ # valid.
+ request.params = original_params.copy()
+ request.headers = original_headers.copy()
+ self._renew_credentials()
elif response.status < 300 or response.status >= 400 or \
not location:
self.put_http_connection(request.host, self.is_secure,
@@ -809,6 +822,31 @@ def _mexe(self, request, sender=None, override_num_retries=None,
msg = 'Please report this exception as a Boto Issue!'
raise BotoClientError(msg)
+ def _credentials_expired(self, response):
+ # It is possible that we could be using temporary credentials that are
+ # now expired. We want to detect when this happens so that we can
+ # refresh the credentials. Subclasses can override this method and
+ # determine whether or not the response indicates that the credentials
+ # are invalid. If this method returns True, the credentials will be
+ # renewed.
+ if response.status != 403:
+ return False
+ try:
+ for event, node in ElementTree.iterparse(response, events=['start']):
+ if node.tag.endswith('Code'):
+ if node.text == 'ExpiredToken':
+ return True
+ except ElementTree.ParseError:
+ return False
+ return False
+
+ def _renew_credentials(self):
+ # By resetting the provider with a new provider, this will trigger the
+ # lookup process for finding the new set of credentials.
+ boto.log.debug("Refreshing credentials.")
+ self.provider = Provider(self._provider_type)
+ self._auth_handler.update_provider(self.provider)
+
def build_base_http_request(self, method, path, auth_path,
params=None, headers=None, data='', host=None):
path = self.get_path(path)
View
@@ -28,6 +28,8 @@
import warnings
from datetime import datetime
from datetime import timedelta
+from xml.etree import ElementTree
+
import boto
from boto.connection import AWSQueryConnection
from boto.resultset import ResultSet
@@ -60,7 +62,7 @@
class EC2Connection(AWSQueryConnection):
- APIVersion = boto.config.get('Boto', 'ec2_version', '2012-03-01')
+ APIVersion = boto.config.get('Boto', 'ec2_version', '2012-06-01')
DefaultRegionName = boto.config.get('Boto', 'ec2_region_name', 'us-east-1')
DefaultRegionEndpoint = boto.config.get('Boto', 'ec2_region_endpoint',
'ec2.us-east-1.amazonaws.com')
@@ -92,6 +94,15 @@ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None,
def _required_auth_capability(self):
return ['ec2']
+ def _credentials_expired(self, response):
+ if response.status != 400:
+ return False
+ for event, node in ElementTree.iterparse(response, events=['start']):
+ if node.tag.endswith('Code'):
+ if node.text == 'RequestExpired':
+ return True
+ return False
+
def get_params(self):
"""
Returns a dictionary containing the value of of all of the keyword
@@ -520,7 +531,8 @@ def run_instances(self, image_id, min_count=1, max_count=1,
private_ip_address=None,
placement_group=None, client_token=None,
security_group_ids=None,
- additional_info=None, tenancy=None):
+ additional_info=None, instance_profile_name=None,
+ instance_profile_arn=None, tenancy=None):
"""
Runs an image on EC2.
@@ -688,6 +700,10 @@ def run_instances(self, image_id, min_count=1, max_count=1,
params['ClientToken'] = client_token
if additional_info:
params['AdditionalInfo'] = additional_info
+ if instance_profile_name:
+ params['IamInstanceProfile.Name'] = instance_profile_name
+ if instance_profile_arn:
+ params['IamInstanceProfile.Arn'] = instance_profile_arn
return self.get_object('RunInstances', params, Reservation, verb='POST')
def terminate_instances(self, instance_ids=None):
View
@@ -184,6 +184,9 @@ def startElement(self, name, attrs, connection):
return self.eventsSet
elif name == 'networkInterfaceSet':
self.interfaces = ResultSet([('item', NetworkInterface)])
+ elif name == 'iamInstanceProfile':
+ self.instance_profile = SubParse('iamInstanceProfile')
+ return self.instance_profile
return None
def endElement(self, name, value, connection):
Oops, something went wrong.

0 comments on commit 304f2ed

Please sign in to comment.