Skip to content
Permalink
e70cc1d46d
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
1803 lines (1517 sloc) 65.1 KB
# Copyright 2012-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import base64
import re
import time
import logging
import datetime
import hashlib
import binascii
import functools
import weakref
import random
import os
import socket
import cgi
import dateutil.parser
from dateutil.tz import tzutc
import botocore
import botocore.awsrequest
import botocore.httpsession
from botocore.compat import (
json, quote, zip_longest, urlsplit, urlunsplit, OrderedDict,
six, urlparse, get_tzinfo_options, get_md5, MD5_AVAILABLE
)
from botocore.vendored.six.moves.urllib.request import getproxies, proxy_bypass
from botocore.exceptions import (
InvalidExpressionError, ConfigNotFound, InvalidDNSNameError, ClientError,
MetadataRetrievalError, EndpointConnectionError, ReadTimeoutError,
ConnectionClosedError, ConnectTimeoutError, UnsupportedS3ArnError,
UnsupportedS3AccesspointConfigurationError, SSOTokenLoadError,
InvalidRegionError,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Default ``timeout`` value used by ``IMDSFetcher`` for its HTTP session.
DEFAULT_METADATA_SERVICE_TIMEOUT = 1
# Default base URL that metadata paths are appended to.
METADATA_BASE_URL = 'http://169.254.169.254/'
# These are chars that do not need to be urlencoded.
# Based on RFC 3986, section 2.3 (unreserved characters).
SAFE_CHARS = '-._~'
# A label of lowercase letters, digits and hyphens that starts and ends
# with a letter or digit (the pattern needs at least two characters).
LABEL_RE = re.compile(r'[a-z0-9][a-z0-9\-]*[a-z0-9]')
# HTTP-level exceptions that the metadata fetchers log and swallow so the
# request can be attempted again.
RETRYABLE_HTTP_ERRORS = (
    ReadTimeoutError, EndpointConnectionError, ConnectionClosedError,
    ConnectTimeoutError,
)
# Host name components that are carried over onto the S3 Accelerate endpoint.
S3_ACCELERATE_WHITELIST = ['dualstack']
# In switching events from using service name / endpoint prefix to service
# id, we have to preserve compatibility. This maps the instances where either
# is different than the transformed service id.
# Keys are the legacy service name / endpoint prefix; values are the
# hyphenated service id they map to.
EVENT_ALIASES = {
    "a4b": "alexa-for-business",
    "alexaforbusiness": "alexa-for-business",
    "api.mediatailor": "mediatailor",
    "api.pricing": "pricing",
    "api.sagemaker": "sagemaker",
    "apigateway": "api-gateway",
    "application-autoscaling": "application-auto-scaling",
    "appstream2": "appstream",
    "autoscaling": "auto-scaling",
    "autoscaling-plans": "auto-scaling-plans",
    "ce": "cost-explorer",
    "cloudhsmv2": "cloudhsm-v2",
    "cloudsearchdomain": "cloudsearch-domain",
    "cognito-idp": "cognito-identity-provider",
    "config": "config-service",
    "cur": "cost-and-usage-report-service",
    "data.iot": "iot-data-plane",
    "data.jobs.iot": "iot-jobs-data-plane",
    "data.mediastore": "mediastore-data",
    "datapipeline": "data-pipeline",
    "devicefarm": "device-farm",
    "devices.iot1click": "iot-1click-devices-service",
    "directconnect": "direct-connect",
    "discovery": "application-discovery-service",
    "dms": "database-migration-service",
    "ds": "directory-service",
    "dynamodbstreams": "dynamodb-streams",
    "elasticbeanstalk": "elastic-beanstalk",
    "elasticfilesystem": "efs",
    "elasticloadbalancing": "elastic-load-balancing",
    "elasticmapreduce": "emr",
    "elastictranscoder": "elastic-transcoder",
    "elb": "elastic-load-balancing",
    "elbv2": "elastic-load-balancing-v2",
    "email": "ses",
    "entitlement.marketplace": "marketplace-entitlement-service",
    "es": "elasticsearch-service",
    "events": "eventbridge",
    "cloudwatch-events": "eventbridge",
    "iot-data": "iot-data-plane",
    "iot-jobs-data": "iot-jobs-data-plane",
    "iot1click-devices": "iot-1click-devices-service",
    "iot1click-projects": "iot-1click-projects",
    "kinesisanalytics": "kinesis-analytics",
    "kinesisvideo": "kinesis-video",
    "lex-models": "lex-model-building-service",
    "lex-runtime": "lex-runtime-service",
    "logs": "cloudwatch-logs",
    "machinelearning": "machine-learning",
    "marketplace-entitlement": "marketplace-entitlement-service",
    "marketplacecommerceanalytics": "marketplace-commerce-analytics",
    "metering.marketplace": "marketplace-metering",
    "meteringmarketplace": "marketplace-metering",
    "mgh": "migration-hub",
    "models.lex": "lex-model-building-service",
    "monitoring": "cloudwatch",
    "mturk-requester": "mturk",
    "opsworks-cm": "opsworkscm",
    "projects.iot1click": "iot-1click-projects",
    "resourcegroupstaggingapi": "resource-groups-tagging-api",
    "route53": "route-53",
    "route53domains": "route-53-domains",
    "runtime.lex": "lex-runtime-service",
    "runtime.sagemaker": "sagemaker-runtime",
    "sdb": "simpledb",
    "secretsmanager": "secrets-manager",
    "serverlessrepo": "serverlessapplicationrepository",
    "servicecatalog": "service-catalog",
    "states": "sfn",
    "stepfunctions": "sfn",
    "storagegateway": "storage-gateway",
    "streams.dynamodb": "dynamodb-streams",
    "tagging": "resource-groups-tagging-api"
}
def ensure_boolean(val):
    """Ensures a boolean value if a string or boolean is provided

    For strings, the value for True/False is case insensitive: only a
    string equal to ``'true'`` (in any casing) yields ``True``.

    :param val: A boolean, a string, or any other object.
    :return: ``val`` itself when it is already a boolean, the result of the
        case-insensitive comparison for strings, and ``False`` for any
        value that is neither (for example ``None``).
    :rtype: bool
    """
    if isinstance(val, bool):
        return val
    try:
        return val.lower() == 'true'
    except AttributeError:
        # Not a string-like value (e.g. None or a number): such a value
        # can never spell "true", so report False instead of crashing.
        return False
def is_json_value_header(shape):
    """Tell whether ``shape`` is the special ``jsonvalue`` header type.

    A shape qualifies when it has a ``serialization`` mapping whose
    ``jsonvalue`` entry is truthy and whose ``location`` is ``'header'``,
    and the shape itself is of type ``'string'``.

    :type shape: botocore.shape
    :param shape: Shape to be inspected for the jsonvalue trait.

    :return: True if this type is a jsonvalue, False otherwise
    :rtype: Bool
    """
    if not hasattr(shape, 'serialization'):
        return False
    serialization = shape.serialization
    return (
        serialization.get('jsonvalue', False) and
        serialization.get('location') == 'header' and
        shape.type_name == 'string'
    )
def get_service_module_name(service_model):
    """Return the module name for a service.

    This is the value used in both the documentation and client class name.
    The name is taken from the model metadata, preferring
    ``serviceAbbreviation``, then ``serviceFullName``, then the model's
    ``service_name``. The substrings ``Amazon`` and ``AWS`` are dropped and
    every run of non-word characters is removed.
    """
    metadata = service_model.metadata
    full_name = metadata.get('serviceFullName', service_model.service_name)
    raw_name = metadata.get('serviceAbbreviation', full_name)
    for vendor_prefix in ('Amazon', 'AWS'):
        raw_name = raw_name.replace(vendor_prefix, '')
    return re.sub(r'\W+', '', raw_name)
def normalize_url_path(path):
    """Normalize a URL path.

    An empty (falsy) path becomes ``'/'``; any other path is passed
    through ``remove_dot_segments``.
    """
    return remove_dot_segments(path) if path else '/'
def remove_dot_segments(url):
    """Collapse ``.`` and ``..`` segments and repeated slashes in a path.

    Follows RFC 3986, section 5.2.4 "Remove Dot Segments". AWS services
    also require consecutive slashes to be removed, so empty segments are
    dropped as well. A leading slash is preserved; a trailing slash is
    preserved only when at least one segment remains.
    """
    if not url:
        return ''
    kept = []
    for segment in url.split('/'):
        if not segment or segment == '.':
            # Empty segments come from repeated slashes; '.' is a no-op.
            continue
        if segment != '..':
            kept.append(segment)
        elif kept:
            # '..' discards the previous segment, if there is one.
            kept.pop()
    prefix = '/' if url[0] == '/' else ''
    suffix = '/' if (url[-1] == '/' and kept) else ''
    return prefix + '/'.join(kept) + suffix
def validate_jmespath_for_set(expression):
    """Check that ``expression`` is a plain dotted path usable for setting.

    Only dotted paths are supported: an empty expression, a lone ``'.'``,
    or anything containing ``[``, ``]`` or ``*`` is rejected.

    :raises InvalidExpressionError: If the expression is not supported.
    """
    if not expression or expression == '.':
        raise InvalidExpressionError(expression=expression)
    if any(token in expression for token in ('[', ']', '*')):
        raise InvalidExpressionError(expression=expression)
def set_value_from_jmespath(source, expression, value, is_first=True):
    """Set ``value`` in ``source`` at the location named by ``expression``.

    This takes a (limited) jmespath-like expression and sets a value based
    on it. Limitations:

    * Only handles dotted lookups
    * No offsets/wildcards/slices/etc.

    Intermediate keys that are missing from ``source`` are created as
    empty dictionaries.

    :param source: The (possibly nested) dictionary to modify in place.
    :param expression: Dotted path such as ``'a.b.c'``.
    :param value: The value to store under the final key.
    :param is_first: When true, the expression is validated with
        ``validate_jmespath_for_set`` before anything is modified.
    :raises InvalidExpressionError: If validation fails or a path component
        is empty.
    """
    if is_first:
        validate_jmespath_for_set(expression)
    container = source
    pending = expression
    while True:
        current_key, _, remainder = pending.partition('.')
        if not current_key:
            raise InvalidExpressionError(expression=pending)
        if not remainder:
            # Down to a single key: this is where the value goes.
            container[current_key] = value
            return
        if current_key not in container:
            # More path follows but this level does not exist yet, so
            # create it as an empty dictionary and keep descending.
            container[current_key] = {}
        container = container[current_key]
        pending = remainder
class _RetriesExceededError(Exception):
    """Internal exception raised once the allowed retries are used up."""
class BadIMDSRequestError(Exception):
    """Raised when the metadata service rejects a request as malformed.

    :param request: The request that was rejected. It is kept on the
        ``request`` attribute so handlers can log it.
    """

    def __init__(self, request):
        # Forward the request to the base class so that ``args`` is
        # populated. Without this the exception cannot be copied or
        # pickled (re-creation calls the class with ``args``) and its
        # string form is empty.
        super(BadIMDSRequestError, self).__init__(request)
        self.request = request
class IMDSFetcher(object):
    """Base class for making requests to the instance metadata service.

    Requests are sent through a ``URLLib3Session``. Access can be turned
    off with the ``AWS_EC2_METADATA_DISABLED`` environment variable, in
    which case every request attempt raises the retries-exceeded error.
    """

    _RETRIES_EXCEEDED_ERROR_CLS = _RetriesExceededError
    _TOKEN_PATH = 'latest/api/token'
    _TOKEN_TTL = '21600'

    def __init__(self, timeout=DEFAULT_METADATA_SERVICE_TIMEOUT,
                 num_attempts=1, base_url=METADATA_BASE_URL,
                 env=None, user_agent=None):
        """Create a fetcher.

        :param timeout: Timeout handed to the underlying HTTP session.
        :param num_attempts: How many times each request is attempted.
        :param base_url: URL that request paths are appended to.
        :param env: Mapping of environment variables; defaults to a copy
            of ``os.environ``.
        :param user_agent: Optional ``User-Agent`` header value.
        """
        self._timeout = timeout
        self._num_attempts = num_attempts
        self._base_url = base_url
        if env is None:
            env = os.environ.copy()
        disabled_setting = env.get('AWS_EC2_METADATA_DISABLED', 'false')
        self._disabled = disabled_setting.lower() == 'true'
        self._user_agent = user_agent
        self._session = botocore.httpsession.URLLib3Session(
            timeout=self._timeout,
            proxies=get_environ_proxies(self._base_url),
        )

    def _fetch_metadata_token(self):
        """Request a metadata token with a PUT to the token path.

        :return: The token text on a 200 response. ``None`` on a 403, 404
            or 405 response, on a read timeout, or when no attempt
            produced a usable answer.
        :raises BadIMDSRequestError: On a 400 response.
        """
        self._assert_enabled()
        url = self._base_url + self._TOKEN_PATH
        headers = {
            'x-aws-ec2-metadata-token-ttl-seconds': self._TOKEN_TTL,
        }
        self._add_user_agent(headers)
        request = botocore.awsrequest.AWSRequest(
            method='PUT', url=url, headers=headers)
        for _ in range(self._num_attempts):
            try:
                response = self._session.send(request.prepare())
                status = response.status_code
                if status == 200:
                    return response.text
                if status in (404, 403, 405):
                    return None
                if status in (400,):
                    raise BadIMDSRequestError(request)
                # Any other status falls through to the next attempt.
            except ReadTimeoutError:
                # A read timeout ends the token lookup right away.
                return None
            except RETRYABLE_HTTP_ERRORS as e:
                logger.debug(
                    "Caught retryable HTTP exception while making metadata "
                    "service request to %s: %s", url, e, exc_info=True)
        return None

    def _get_request(self, url_path, retry_func, token=None):
        """Make a get request to the Instance Metadata Service.

        :type url_path: str
        :param url_path: The path component of the URL to make a get request.
            This arg is appended to the base_url that was provided in the
            initializer.

        :type retry_func: callable
        :param retry_func: A function that takes the response as an argument
            and determines if it needs to retry. When ``None``, empty and
            non 200 OK responses are retried.

        :type token: str
        :param token: Metadata token to send along with GET requests to IMDS.

        :return: The first response that ``retry_func`` accepts.
        :raises _RetriesExceededError: (``_RETRIES_EXCEEDED_ERROR_CLS``) when
            access is disabled or every attempt needed a retry.
        """
        self._assert_enabled()
        if retry_func is None:
            retry_func = self._default_retry
        url = self._base_url + url_path
        headers = {}
        if token is not None:
            headers['x-aws-ec2-metadata-token'] = token
        self._add_user_agent(headers)
        for _ in range(self._num_attempts):
            try:
                request = botocore.awsrequest.AWSRequest(
                    method='GET', url=url, headers=headers)
                response = self._session.send(request.prepare())
                if retry_func(response):
                    continue
                return response
            except RETRYABLE_HTTP_ERRORS as e:
                logger.debug(
                    "Caught retryable HTTP exception while making metadata "
                    "service request to %s: %s", url, e, exc_info=True)
        raise self._RETRIES_EXCEEDED_ERROR_CLS()

    def _add_user_agent(self, headers):
        """Add the ``User-Agent`` header in place when one was configured."""
        if self._user_agent is None:
            return
        headers['User-Agent'] = self._user_agent

    def _assert_enabled(self):
        """Raise the retries-exceeded error if metadata access is disabled."""
        if not self._disabled:
            return
        logger.debug("Access to EC2 metadata has been disabled.")
        raise self._RETRIES_EXCEEDED_ERROR_CLS()

    def _default_retry(self, response):
        """Retry on any non-200 response or on an empty body."""
        if self._is_non_ok_response(response):
            return True
        return self._is_empty(response)

    def _is_non_ok_response(self, response):
        """Return True (and log the response) when the status is not 200."""
        if response.status_code == 200:
            return False
        self._log_imds_response(response, 'non-200', log_body=True)
        return True

    def _is_empty(self, response):
        """Return True (and log the response) when the body is empty."""
        if response.content:
            return False
        self._log_imds_response(response, 'no body', log_body=True)
        return True

    def _log_imds_response(self, response, reason_to_log, log_body=False):
        """Debug-log a response, optionally including its body."""
        statement = (
            "Metadata service returned %s response "
            "with status code of %s for url: %s"
        )
        details = [reason_to_log, response.status_code, response.url]
        if log_body:
            statement = statement + ", content body: %s"
            details.append(response.content)
        logger.debug(statement, *details)
class InstanceMetadataFetcher(IMDSFetcher):
    """Fetches IAM role credentials from the instance metadata service."""

    _URL_PATH = 'latest/meta-data/iam/security-credentials/'
    _REQUIRED_CREDENTIAL_FIELDS = [
        'AccessKeyId', 'SecretAccessKey', 'Token', 'Expiration'
    ]

    def retrieve_iam_role_credentials(self):
        """Look up the role name and its credentials.

        :return: A dictionary with ``role_name``, ``access_key``,
            ``secret_key``, ``token`` and ``expiry_time`` keys, or an empty
            dictionary when credentials could not be retrieved (retries
            exhausted, bad token request, or an incomplete response).
        :rtype: dict
        """
        try:
            token = self._fetch_metadata_token()
            role_name = self._get_iam_role(token)
            credentials = self._get_credentials(role_name, token)
            if self._contains_all_credential_fields(credentials):
                return {
                    'role_name': role_name,
                    'access_key': credentials['AccessKeyId'],
                    'secret_key': credentials['SecretAccessKey'],
                    'token': credentials['Token'],
                    'expiry_time': credentials['Expiration'],
                }
            else:
                # IMDS can return a 200 response that has a JSON formatted
                # error message (i.e. if ec2 is not trusted entity for the
                # attached role). We do not necessarily want to retry for
                # these and we also do not necessarily want to raise a key
                # error. So at least log the problematic response and return
                # an empty dictionary to signal that it was not able to
                # retrieve credentials. These error will contain both a
                # Code and Message key.
                if 'Code' in credentials and 'Message' in credentials:
                    # The trailing space matters: the two literals are
                    # concatenated into a single message.
                    logger.debug('Error response received when retrieving '
                                 'credentials: %s.', credentials)
                return {}
        except self._RETRIES_EXCEEDED_ERROR_CLS:
            logger.debug("Max number of attempts exceeded (%s) when "
                         "attempting to retrieve data from metadata service.",
                         self._num_attempts)
        except BadIMDSRequestError as e:
            logger.debug("Bad IMDS request: %s", e.request)
        return {}

    def _get_iam_role(self, token=None):
        """Return the text of the response listing the IAM role name."""
        return self._get_request(
            url_path=self._URL_PATH,
            retry_func=self._needs_retry_for_role_name,
            token=token,
        ).text

    def _get_credentials(self, role_name, token=None):
        """Return the JSON-decoded credentials document for ``role_name``."""
        r = self._get_request(
            url_path=self._URL_PATH + role_name,
            retry_func=self._needs_retry_for_credentials,
            token=token,
        )
        return json.loads(r.text)

    def _is_invalid_json(self, response):
        """Return True (and log) when the response text is not valid JSON."""
        try:
            json.loads(response.text)
            return False
        except ValueError:
            self._log_imds_response(response, 'invalid json')
            return True

    def _needs_retry_for_role_name(self, response):
        """Retry the role-name request on a non-200 or empty response."""
        return (
            self._is_non_ok_response(response) or
            self._is_empty(response)
        )

    def _needs_retry_for_credentials(self, response):
        """Retry on a non-200, empty, or non-JSON credentials response."""
        return (
            self._is_non_ok_response(response) or
            self._is_empty(response) or
            self._is_invalid_json(response)
        )

    def _contains_all_credential_fields(self, credentials):
        """Return True only if every required credential field is present.

        The first missing field is logged at debug level.
        """
        for field in self._REQUIRED_CREDENTIAL_FIELDS:
            if field not in credentials:
                logger.debug(
                    'Retrieved credentials is missing required field: %s',
                    field)
                return False
        return True
def merge_dicts(dict1, dict2, append_lists=False):
    """Given two dict, merge the second dict into the first.

    The dicts can have arbitrary nesting. ``dict1`` is modified in place
    and nothing is returned.

    :param dict1: The dictionary that receives the merged values.
    :param dict2: The dictionary whose values are merged into ``dict1``.
    :param append_lists: If true, instead of clobbering a list with the new
        value, append all of the new values onto the original list. This
        applies at every nesting level.
    """
    for key, new_value in dict2.items():
        if isinstance(new_value, dict):
            if key in dict1:
                # Pass ``append_lists`` down so that lists inside nested
                # dictionaries are treated the same as top-level ones.
                merge_dicts(dict1[key], new_value, append_lists)
            else:
                dict1[key] = new_value
        # If the value is a list and the ``append_lists`` flag is set,
        # append the new values onto the original list
        elif isinstance(new_value, list) and append_lists:
            # The value in dict1 must be a list in order to append new
            # values onto it.
            if key in dict1 and isinstance(dict1[key], list):
                dict1[key].extend(new_value)
            else:
                dict1[key] = new_value
        else:
            # At scalar types, we iterate and merge the
            # current dict that we're on.
            dict1[key] = new_value
def lowercase_dict(original):
    """Return a new dictionary whose keys are the lowercased originals.

    Values are carried over unchanged. When two keys differ only by case,
    the one iterated last wins.
    """
    return {key.lower(): original[key] for key in original}
def parse_key_val_file(filename, _open=open):
    """Read ``filename`` and parse its ``key=value`` lines.

    :param filename: Path of the file to read.
    :param _open: Callable used to open the file; defaults to ``open``.
    :return: The dictionary produced by ``parse_key_val_file_contents``.
    :raises ConfigNotFound: If an ``OSError`` occurs while reading.
    """
    try:
        with _open(filename) as handle:
            raw_text = handle.read()
        return parse_key_val_file_contents(raw_text)
    except OSError:
        raise ConfigNotFound(path=filename)
def parse_key_val_file_contents(contents):
    """Parse ``key=value`` lines from a string into a dictionary.

    This was originally extracted from the EC2 credential provider, which
    was fairly lenient in its parsing. Lines without an ``=`` are ignored;
    otherwise the line is split on the first ``=`` and both sides are
    stripped of surrounding whitespace.
    """
    parsed = {}
    for line in contents.splitlines():
        key, separator, val = line.partition('=')
        if not separator:
            # No '=' on this line, so there is nothing to parse.
            continue
        parsed[key.strip()] = val.strip()
    return parsed
def percent_encode_sequence(mapping, safe=SAFE_CHARS):
    """Urlencode a dict or list into a string.

    This is similar to urllib.urlencode except that:

    * It uses quote, and not quote_plus
    * It has a default list of safe chars that don't need
      to be encoded, which matches what AWS services expect.

    If any value in the input ``mapping`` is a list type,
    then each list element will be serialized. This is the equivalent
    to ``urlencode``'s ``doseq=True`` argument.

    This function should be preferred over the stdlib
    ``urlencode()`` function.

    :param mapping: Either a dict to urlencode or a list of
        ``(key, value)`` pairs.
    :param safe: Characters that are left unencoded. It is forwarded to
        ``percent_encode`` for every key and value.
    :return: The ``key=value`` pairs joined with ``&``.
    """
    if hasattr(mapping, 'items'):
        pairs = mapping.items()
    else:
        pairs = mapping
    encoded_pairs = []
    for key, value in pairs:
        # A list value expands to one pair per element.
        elements = value if isinstance(value, list) else [value]
        encoded_key = percent_encode(key, safe=safe)
        for element in elements:
            encoded_pairs.append(
                '%s=%s' % (encoded_key, percent_encode(element, safe=safe)))
    return '&'.join(encoded_pairs)
def percent_encode(input_str, safe=SAFE_CHARS):
    """Urlencodes a string.

    Whereas percent_encode_sequence handles taking a dict/sequence and
    producing a percent encoded string, this function deals only with
    taking a string (not a dict/sequence) and percent encoding it.

    If given the binary type, will simply URL encode it. If given the
    text type, will produce the binary type by UTF-8 encoding the
    text. If given something else, will convert it to the text type
    first.

    :param input_str: The value to encode.
    :param safe: Characters that are left unencoded.
    """
    if isinstance(input_str, six.binary_type):
        raw_bytes = input_str
    else:
        # Anything that is not already text is converted to text before
        # being UTF-8 encoded.
        if not isinstance(input_str, six.text_type):
            input_str = six.text_type(input_str)
        raw_bytes = input_str.encode('utf-8')
    return quote(raw_bytes, safe=safe)
def _parse_timestamp_with_tzinfo(value, tzinfo):
    """Parse timestamp with pluggable tzinfo options.

    :param value: A number, or a string holding either a number or a
        formatted timestamp.
    :param tzinfo: A callable returning the tzinfo used for epoch values.
    :raises ValueError: If the value cannot be parsed as a timestamp.
    """
    if isinstance(value, (int, float)):
        # Possibly an epoch time.
        return datetime.datetime.fromtimestamp(value, tzinfo())
    try:
        # The value may still be an epoch time spelled as a string.
        return datetime.datetime.fromtimestamp(float(value), tzinfo())
    except (TypeError, ValueError):
        pass
    try:
        # In certain cases, a timestamp marked with GMT can be parsed into a
        # different time zone, so here we provide a context which will
        # enforce that GMT == UTC.
        return dateutil.parser.parse(value, tzinfos={'GMT': tzutc()})
    except (TypeError, ValueError) as e:
        raise ValueError('Invalid timestamp "%s": %s' % (value, e))
def parse_timestamp(value):
    """Parse a timestamp into a datetime object.

    Supported formats:

    * iso8601
    * rfc822
    * epoch (value is an integer)

    Each tzinfo option from ``get_tzinfo_options()`` is tried in turn; an
    option that fails with ``OSError`` is logged and skipped.

    This will return a ``datetime.datetime`` object.

    :raises RuntimeError: If every tzinfo option failed with ``OSError``.
    """
    for tz_option in get_tzinfo_options():
        try:
            parsed = _parse_timestamp_with_tzinfo(value, tz_option)
        except OSError as e:
            logger.debug('Unable to parse timestamp with "%s" timezone info.',
                         tz_option.__name__, exc_info=e)
        else:
            return parsed
    raise RuntimeError('Unable to calculate correct timezone offset for '
                       '"%s"' % value)
def parse_to_aware_datetime(value):
    """Convert the passed in value to a datetime object with tzinfo.

    This function can be used to normalize all timestamp inputs. This
    function accepts a number of different types of inputs, but
    will always return a datetime.datetime object with time zone
    information.

    The input param ``value`` can be one of several types:

    * A datetime object (both naive and aware)
    * An integer representing the epoch time (can also be a string
      of the integer, i.e '0', instead of 0). The epoch time is
      considered to be UTC.
    * An iso8601 formatted timestamp. This does not need to be
      a complete timestamp, it can contain just the date portion
      without the time component.

    The returned value will be a datetime object that will have tzinfo.
    If no timezone info was provided in the input value, then UTC is
    assumed, not local time.
    """
    if isinstance(value, datetime.datetime):
        moment = value
    else:
        # Anything else is treated as a timestamp to parse. We document
        # this as being an iso8601 timestamp, although parse_timestamp is
        # a bit more flexible.
        moment = parse_timestamp(value)
    if moment.tzinfo is not None:
        return moment.astimezone(tzutc())
    # A case could be made that if no time zone is provided, we should use
    # the local time. However, to restore backwards compat, the previous
    # behavior was to assume UTC, which is what is done here.
    return moment.replace(tzinfo=tzutc())
def datetime2timestamp(dt, default_timezone=None):
    """Calculate the timestamp based on the given datetime instance.

    :type dt: datetime
    :param dt: A datetime object to be converted into timestamp
    :type default_timezone: tzinfo
    :param default_timezone: If it is provided as None, we treat it as tzutc().
                             But it is only used when dt is a naive datetime.
    :returns: The timestamp, as seconds since 1970-01-01 00:00:00 UTC.
    """
    if dt.tzinfo is None:
        zone = tzutc() if default_timezone is None else default_timezone
        dt = dt.replace(tzinfo=zone)
    # Shift to UTC wall-clock time, then measure the distance from the
    # epoch using naive datetimes.
    utc_naive = dt.replace(tzinfo=None) - dt.utcoffset()
    elapsed = utc_naive - datetime.datetime(1970, 1, 1)
    if hasattr(elapsed, "total_seconds"):
        return elapsed.total_seconds()  # Works in Python 2.7+
    return (
        elapsed.microseconds +
        (elapsed.seconds + elapsed.days * 24 * 3600) * 10**6
    ) / 10**6
def calculate_sha256(body, as_hex=False):
    """Calculate a sha256 checksum.

    This method will calculate the sha256 checksum of a file like
    object. Note that this method will iterate through the entire
    file contents. The caller is responsible for ensuring the proper
    starting position of the file and ``seek()``'ing the file back
    to its starting location if other consumers need to read from
    the file like object.

    :param body: Any file like object. The file must be opened
        in binary mode such that a ``.read()`` call returns bytes.
    :param as_hex: If True, then the hex digest is returned.
        If False, then the digest (as binary bytes) is returned.

    :returns: The sha256 checksum
    """
    hasher = hashlib.sha256()
    while True:
        # Read in 1 MiB pieces until read() reports end of data.
        piece = body.read(1024 * 1024)
        if piece == b'':
            break
        hasher.update(piece)
    return hasher.hexdigest() if as_hex else hasher.digest()
def calculate_tree_hash(body):
    """Calculate a tree hash checksum.

    For more information see:

    http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html

    :param body: Any file like object. This has the same constraints as
        the ``body`` param in calculate_sha256

    :rtype: str
    :returns: The hex version of the calculated tree hash
    """
    piece_size = 1024 * 1024
    level = []
    while True:
        piece = body.read(piece_size)
        if piece == b'':
            break
        level.append(hashlib.sha256(piece).digest())
    if not level:
        return hashlib.sha256(b'').hexdigest()
    while len(level) > 1:
        # Hash neighbouring digests together; an unpaired digest at the
        # end of the level is carried up unchanged.
        level = [
            left if right is None
            else hashlib.sha256(left + right).digest()
            for left, right in _in_pairs(level)
        ]
    return binascii.hexlify(level[0]).decode('ascii')
def _in_pairs(iterable):
    """Return an iterator over ``iterable`` taken two items at a time.

    For example::

        for a, b in _in_pairs([0, 1, 2, 3, 4]):
            print(a, b)

    will print::

        0, 1
        2, 3
        4, None
    """
    # The same iterator object is handed to zip_longest twice, so every
    # output pair consumes two consecutive items. zip_longest is a compat
    # import and is lazy: the pairs are not built up front.
    return zip_longest(*[iter(iterable)] * 2)
class CachedProperty(object):
    """A read only property that caches the initially computed value.

    This descriptor will only call the provided ``fget`` function once.
    The result is stored in the instance ``__dict__`` under the function's
    name, so subsequent access to this property returns the cached value.
    """

    def __init__(self, fget):
        self._fget = fget

    def __get__(self, obj, cls):
        if obj is not None:
            value = self._fget(obj)
            obj.__dict__[self._fget.__name__] = value
            return value
        # Accessed on the class itself: hand back the descriptor.
        return self
class ArgumentGenerator(object):
    """Generate sample input based on a shape model.

    This class contains a ``generate_skeleton`` method that will take
    an input/output shape (created from ``botocore.model``) and generate
    a sample dictionary corresponding to the input/output shape.

    The specific values used are place holder values. For strings either an
    empty string or the member name can be used, for numbers 0 or 0.0 is used.
    The intended usage of this class is to generate the *shape* of the input
    structure.

    This can be useful for operations that have complex input shapes.
    This allows a user to just fill in the necessary data instead of
    worrying about the specific structure of the input arguments.

    Example usage::

        s = botocore.session.get_session()
        ddb = s.get_service_model('dynamodb')
        arg_gen = ArgumentGenerator()
        sample_input = arg_gen.generate_skeleton(
            ddb.operation_model('CreateTable').input_shape)
        print("Sample input for dynamodb.CreateTable: %s" % sample_input)

    """

    def __init__(self, use_member_names=False):
        """
        :param use_member_names: When true, string members are filled with
            their member name instead of an enum value or empty string.
        """
        self._use_member_names = use_member_names

    def generate_skeleton(self, shape):
        """Generate a sample input.

        :type shape: ``botocore.model.Shape``
        :param shape: The input shape.

        :return: The generated skeleton input corresponding to the
            provided input shape.
        """
        stack = []
        return self._generate_skeleton(shape, stack)

    def _generate_skeleton(self, shape, stack, name=''):
        """Build the placeholder for one shape.

        ``stack`` holds the names of the shapes currently being expanded;
        it is what lets structures detect that they are recursive. A
        shape type that is not handled here yields ``None``.
        """
        stack.append(shape.name)
        try:
            if shape.type_name == 'structure':
                return self._generate_type_structure(shape, stack)
            elif shape.type_name == 'list':
                return self._generate_type_list(shape, stack)
            elif shape.type_name == 'map':
                return self._generate_type_map(shape, stack)
            elif shape.type_name == 'string':
                if self._use_member_names:
                    return name
                if shape.enum:
                    return random.choice(shape.enum)
                return ''
            elif shape.type_name in ['integer', 'long']:
                return 0
            elif shape.type_name in ['float', 'double']:
                # 'double' is a floating point number as well and gets
                # the same 0.0 placeholder as 'float'.
                return 0.0
            elif shape.type_name == 'boolean':
                return True
            elif shape.type_name == 'timestamp':
                return datetime.datetime(1970, 1, 1, 0, 0, 0)
        finally:
            stack.pop()

    def _generate_type_structure(self, shape, stack):
        """Return an ordered mapping of member name to placeholder."""
        if stack.count(shape.name) > 1:
            # This shape is already being expanded further up the stack,
            # so stop here instead of recursing forever.
            return {}
        skeleton = OrderedDict()
        for member_name, member_shape in shape.members.items():
            skeleton[member_name] = self._generate_skeleton(
                member_shape, stack, name=member_name)
        return skeleton

    def _generate_type_list(self, shape, stack):
        """Return a list holding a single placeholder element."""
        # For list elements we've arbitrarily decided to
        # return one element for the skeleton list.
        name = ''
        if self._use_member_names:
            name = shape.member.name
        return [
            self._generate_skeleton(shape.member, stack, name),
        ]

    def _generate_type_map(self, shape, stack):
        """Return a one-entry mapping of ``'KeyName'`` to a placeholder."""
        key_shape = shape.key
        value_shape = shape.value
        assert key_shape.type_name == 'string'
        return OrderedDict([
            ('KeyName', self._generate_skeleton(value_shape, stack)),
        ])
def is_valid_endpoint_url(endpoint_url):
    """Verify the endpoint_url is valid.

    :type endpoint_url: string
    :param endpoint_url: An endpoint_url.  Must have at least a scheme
        and a hostname.

    :return: A truthy value (the regex match) if the endpoint url is
        valid; ``False`` or ``None`` otherwise.
    """
    hostname = urlsplit(endpoint_url).hostname
    if hostname is None or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        # A single trailing dot is allowed and ignored.
        hostname = hostname[:-1]
    # Dot-separated labels of 1-63 letters, digits or hyphens, none of
    # which may start or end with a hyphen.
    return re.match(
        r"^((?!-)[A-Z\d-]{1,63}(?<!-)\.)*((?!-)[A-Z\d-]{1,63}(?<!-))$",
        hostname,
        re.IGNORECASE)
def validate_region_name(region_name):
    """Provided region_name must be a valid host label.

    ``None`` is accepted without any check.

    :raises InvalidRegionError: If ``region_name`` does not match the
        host label pattern.
    """
    if region_name is None:
        return
    host_label_pattern = r'^(?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(?<!-)$'
    if re.match(host_label_pattern, region_name) is None:
        raise InvalidRegionError(region_name=region_name)
def check_dns_name(bucket_name):
    """
    Check to see if the ``bucket_name`` complies with the
    restricted DNS naming conventions necessary to allow
    access via virtual-hosting style.

    Even though "." characters are perfectly valid in this DNS
    naming scheme, we are going to punt on any name containing a
    "." character because these will cause SSL cert validation
    problems if we try to use virtual-hosting style addressing.

    :return: True if the name is 3 to 63 characters long, has no "." and
        is matched in full by ``LABEL_RE``; False otherwise.
    """
    if '.' in bucket_name:
        return False
    if not 3 <= len(bucket_name) <= 63:
        # Wrong length
        return False
    label = LABEL_RE.match(bucket_name)
    # The pattern has to cover the whole name, not just a prefix of it.
    return label is not None and label.end() == len(bucket_name)
def fix_s3_host(request, signature_version, region_name,
                default_endpoint_url=None, **kwargs):
    """
    This handler looks at S3 requests just before they are signed.
    If there is a bucket name on the path (true for everything except
    ListAllBuckets) it checks to see if that bucket name conforms to
    the DNS naming conventions.  If it does, it alters the request to
    use ``virtual hosting`` style addressing rather than ``path-style``
    addressing.

    When the request context has ``use_global_endpoint`` set, the
    ``s3.amazonaws.com`` endpoint is used. A bucket name that is not DNS
    compatible is logged and the request is left as it was.
    """
    use_global = request.context.get('use_global_endpoint', False)
    if use_global:
        default_endpoint_url = 's3.amazonaws.com'
    try:
        switch_to_virtual_host_style(
            request, signature_version, default_endpoint_url)
    except InvalidDNSNameError as e:
        logger.debug('Not changing URI, bucket is not DNS compatible: %s',
                     e.kwargs['bucket_name'])
def switch_to_virtual_host_style(request, signature_version,
                                 default_endpoint_url=None, **kwargs):
    """
    This is a handler to force virtual host style s3 addressing no matter
    the signature version (which is taken in consideration for the default
    case). If the bucket is not DNS compatible an InvalidDNSName is thrown.

    :param request: A AWSRequest object that is about to be sent.
    :param signature_version: The signature version to sign with
    :param default_endpoint_url: The endpoint to use when switching to a
        virtual style. If None is supplied, the virtual host will be
        constructed from the url of the request.
    :raises InvalidDNSNameError: If the bucket name is not DNS compatible.
    """
    if request.auth_path is not None:
        # The auth_path has already been applied (this may be a
        # retried request).  We don't need to perform this
        # customization again.
        return
    if _is_get_bucket_location_request(request):
        # For the GetBucketLocation response, we should not be using
        # the virtual host style addressing so we can avoid any sigv4
        # issues.
        logger.debug("Request is GetBucketLocation operation, not checking "
                     "for DNS compatibility.")
        return

    parts = urlsplit(request.url)
    request.auth_path = parts.path
    path_parts = parts.path.split('/')

    # Retrieve what the endpoint we will be prepending the bucket name to.
    if default_endpoint_url is None:
        default_endpoint_url = parts.netloc

    if len(path_parts) <= 1:
        return
    bucket_name = path_parts[1]
    if not bucket_name:
        # If the bucket name is empty we should not be checking for
        # dns compatibility.
        return

    logger.debug('Checking for DNS compatible bucket for: %s',
                 request.url)
    if not check_dns_name(bucket_name):
        raise InvalidDNSNameError(bucket_name=bucket_name)

    # If the operation is on a bucket, the auth_path must be
    # terminated with a '/' character.
    if len(path_parts) == 2 and request.auth_path[-1] != '/':
        request.auth_path += '/'
    path_parts.remove(bucket_name)
    # At the very least the path must be a '/', such as with the
    # CreateBucket operation when DNS style is being used. If this
    # is not used you will get an empty path which is incorrect.
    path = '/'.join(path_parts) or '/'
    host = bucket_name + '.' + default_endpoint_url
    new_uri = urlunsplit((parts.scheme, host, path, parts.query, ''))
    request.url = new_uri
    logger.debug('URI updated to: %s', new_uri)
def _is_get_bucket_location_request(request):
    """Return True when the request URL ends with ``?location``."""
    url = request.url
    return url.endswith('?location')
def instance_cache(func):
    """Method decorator for caching method calls to a single instance.

    **This is not a general purpose caching decorator.**

    In order to use this, you *must* provide an ``_instance_cache``
    attribute on the instance.

    This decorator is used to cache method calls.  The cache is only
    scoped to a single instance though such that multiple instances
    will maintain their own cache.  In order to keep things simple,
    this decorator requires that you provide an ``_instance_cache``
    attribute on your instance.

    A cached ``None`` is treated as a miss, so a method returning
    ``None`` is invoked again on every call.
    """
    func_name = func.__name__

    @functools.wraps(func)
    def _cache_guard(self, *args, **kwargs):
        if kwargs:
            # Sort keyword arguments so their order does not affect the key.
            cache_key = (func_name, args, tuple(sorted(kwargs.items())))
        else:
            cache_key = (func_name, args)
        cached = self._instance_cache.get(cache_key)
        if cached is None:
            cached = func(self, *args, **kwargs)
            self._instance_cache[cache_key] = cached
        return cached

    return _cache_guard
def switch_host_s3_accelerate(request, operation_name, **kwargs):
    """Switches the current s3 endpoint with an S3 Accelerate endpoint

    :param request: The request whose ``url`` will be rewritten in place.
    :param operation_name: Name of the S3 operation being called.
        ``ListBuckets``, ``CreateBucket`` and ``DeleteBucket`` are left
        untouched.
    """
    # Note that when registered the switching of the s3 host happens
    # before it gets changed to virtual. So we are not concerned with ensuring
    # that the bucket name is translated to the virtual style here and we
    # can hard code the Accelerate endpoint.
    host_labels = urlsplit(request.url).netloc.split('.')
    # Carry over only the whitelisted feature labels of the current host.
    feature_labels = [
        label for label in host_labels if label in S3_ACCELERATE_WHITELIST
    ]
    endpoint = '.'.join(
        ['https://s3-accelerate'] + feature_labels + ['amazonaws.com']
    )
    if operation_name in ('ListBuckets', 'CreateBucket', 'DeleteBucket'):
        return
    _switch_hosts(request, endpoint, use_new_scheme=False)
def switch_host_with_param(request, param_name):
    """Switches the host using a parameter value from a JSON request body

    :param request: The request to update. ``request.data`` must be the
        UTF-8 encoded JSON body.
    :param param_name: The key in the JSON body holding the new endpoint.
        Nothing happens if the key is missing or its value is falsy.
    """
    body = json.loads(request.data.decode('utf-8'))
    new_endpoint = body.get(param_name)
    if not new_endpoint:
        return
    _switch_hosts(request, new_endpoint)
def _switch_hosts(request, new_endpoint, use_new_scheme=True):
    """Rewrite ``request.url`` so that it points at ``new_endpoint``.

    :param request: The request whose ``url`` attribute is replaced.
    :param new_endpoint: URL of the endpoint to switch to.
    :param use_new_scheme: Passed through to ``_get_new_endpoint``.
    """
    request.url = _get_new_endpoint(
        request.url, new_endpoint, use_new_scheme)
def _get_new_endpoint(original_endpoint, new_endpoint, use_new_scheme=True):
new_endpoint_components = urlsplit(new_endpoint)
original_endpoint_components = urlsplit(original_endpoint)
scheme = original_endpoint_components.scheme
if use_new_scheme:
scheme = new_endpoint_components.scheme
final_endpoint_components = (
scheme,
new_endpoint_components.netloc,
original_endpoint_components.path,
original_endpoint_components.query,
''
)
final_endpoint = urlunsplit(final_endpoint_components)
logger.debug('Updating URI from %s to %s' % (
original_endpoint, final_endpoint))
return final_endpoint
def deep_merge(base, extra):
    """Deeply merge two dictionaries, overriding existing keys in the base.

    The merge happens in place on ``base``; nothing is returned.

    :param base: The base dictionary which will be merged into.
    :param extra: The dictionary to merge into the base. Keys from this
        dictionary will take precedence.
    """
    for key, value in extra.items():
        # Only recurse when both sides hold a dict for this key.
        both_are_dicts = (
            key in base and
            isinstance(base[key], dict) and
            isinstance(value, dict)
        )
        if both_are_dicts:
            deep_merge(base[key], value)
        else:
            # Anything else is simply replaced by the value from extra.
            base[key] = value
def hyphenize_service_id(service_id):
    """Translate the form used for event emitters.

    Spaces become hyphens and the result is lower-cased.

    :param service_id: The service_id to convert.
    """
    lowered = service_id.lower()
    return lowered.replace(' ', '-')
class S3RegionRedirector(object):
    """Retry S3 requests against the region a bucket actually lives in.

    The redirector keeps a per-bucket cache of signing contexts (region,
    bucket and endpoint). It registers three handlers: one that inspects
    failed responses and rewrites the request for the correct region, one
    that applies a previously chosen endpoint to the request URL, and one
    that seeds the request context from the cache.

    :param endpoint_bridge: Object whose ``resolve('s3', region)`` returns a
        mapping containing ``endpoint_url``.
    :param client: The S3 client; only a weak proxy to it is retained.
    :param cache: Optional mapping of bucket name to signing context.
    """

    def __init__(self, endpoint_bridge, client, cache=None):
        self._endpoint_resolver = endpoint_bridge
        self._cache = cache
        if self._cache is None:
            self._cache = {}

        # This needs to be a weak ref in order to prevent memory leaks on
        # python 2.6
        self._client = weakref.proxy(client)

    def register(self, event_emitter=None):
        """Register the redirect handlers.

        :param event_emitter: The emitter to register on. Defaults to the
            client's ``meta.events``.
        """
        emitter = event_emitter or self._client.meta.events
        emitter.register('needs-retry.s3', self.redirect_from_error)
        emitter.register('before-call.s3', self.set_request_url)
        emitter.register('before-parameter-build.s3',
                         self.redirect_from_cache)

    def redirect_from_error(self, request_dict, response, operation, **kwargs):
        """
        An S3 request sent to the wrong region will return an error that
        contains the endpoint the request should be sent to. This handler
        will add the redirect information to the signing context and then
        redirect the request.

        :param request_dict: The request dictionary; its ``context`` and
            ``url`` are updated in place when a redirect is performed.
        :param response: ``None`` or a ``(http_response, parsed)`` pair.
        :param operation: The operation model (only ``name`` is read).
        :return: ``0`` when the request was redirected (retry immediately),
            otherwise ``None``.
        """
        if response is None:
            # This could be none if there was a ConnectionError or other
            # transport error.
            return

        if self._is_s3_accesspoint(request_dict.get('context', {})):
            logger.debug(
                'S3 request was previously to an accesspoint, not redirecting.'
            )
            return

        if request_dict.get('context', {}).get('s3_redirected'):
            logger.debug(
                'S3 request was previously redirected, not redirecting.')
            return

        error = response[1].get('Error', {})
        error_code = error.get('Code')
        response_metadata = response[1].get('ResponseMetadata', {})

        # We have to account for 400 responses because
        # if we sign a Head* request with the wrong region,
        # we'll get a 400 Bad Request but we won't get a
        # body saying it's an "AuthorizationHeaderMalformed".
        is_special_head_object = (
            error_code in ['301', '400'] and
            operation.name == 'HeadObject'
        )
        is_special_head_bucket = (
            error_code in ['301', '400'] and
            operation.name == 'HeadBucket' and
            'x-amz-bucket-region' in response_metadata.get('HTTPHeaders', {})
        )
        is_wrong_signing_region = (
            error_code == 'AuthorizationHeaderMalformed' and
            'Region' in error
        )
        is_redirect_status = response[0] is not None and \
            response[0].status_code in [301, 302, 307]
        is_permanent_redirect = error_code == 'PermanentRedirect'
        if not any([is_special_head_object, is_wrong_signing_region,
                    is_permanent_redirect, is_special_head_bucket,
                    is_redirect_status]):
            return

        bucket = request_dict['context']['signing']['bucket']
        client_region = request_dict['context'].get('client_region')
        new_region = self.get_bucket_region(bucket, response)

        if new_region is None:
            # Lazy logger arguments: only formatted when debug is enabled.
            logger.debug(
                "S3 client configured for region %s but the bucket %s is not "
                "in that region and the proper region could not be "
                "automatically determined.", client_region, bucket)
            return

        logger.debug(
            "S3 client configured for region %s but the bucket %s is in region"
            " %s; Please configure the proper region to avoid multiple "
            "unnecessary redirects and signing attempts.",
            client_region, bucket, new_region)
        endpoint = self._endpoint_resolver.resolve('s3', new_region)
        endpoint = endpoint['endpoint_url']

        signing_context = {
            'region': new_region,
            'bucket': bucket,
            'endpoint': endpoint
        }
        request_dict['context']['signing'] = signing_context

        # Remember the result so later requests for this bucket can be
        # pointed at the right endpoint up front.
        self._cache[bucket] = signing_context
        self.set_request_url(request_dict, request_dict['context'])

        request_dict['context']['s3_redirected'] = True

        # Return 0 so it doesn't wait to retry
        return 0

    def get_bucket_region(self, bucket, response):
        """
        There are multiple potential sources for the new region to redirect to,
        but they aren't all universally available for use. This will try to
        find region from response elements, but will fall back to calling
        HEAD on the bucket if all else fails.

        :param bucket: The bucket to find the region for. This is necessary if
            the region is not available in the error response.
        :param response: A response representing a service request that failed
            due to incorrect region configuration.
        :return: The region name, or ``None`` if it could not be determined.
        """
        # First try to source the region from the headers. The metadata is
        # treated as optional here, just as it is in redirect_from_error, so
        # a response without it falls through to the other sources instead
        # of raising a KeyError.
        service_response = response[1]
        response_headers = service_response.get(
            'ResponseMetadata', {}).get('HTTPHeaders', {})
        if 'x-amz-bucket-region' in response_headers:
            return response_headers['x-amz-bucket-region']

        # Next, check the error body
        region = service_response.get('Error', {}).get('Region', None)
        if region is not None:
            return region

        # Finally, HEAD the bucket. No other choice sadly.
        try:
            response = self._client.head_bucket(Bucket=bucket)
            headers = response['ResponseMetadata']['HTTPHeaders']
        except ClientError as e:
            # A failed HEAD may still carry the region header.
            headers = e.response['ResponseMetadata']['HTTPHeaders']

        region = headers.get('x-amz-bucket-region', None)
        return region

    def set_request_url(self, params, context, **kwargs):
        """Apply the endpoint stored in the signing context, if any.

        :param params: Request dictionary whose ``url`` may be replaced.
        :param context: Request context; ``context['signing']['endpoint']``
            is used when present.
        """
        endpoint = context.get('signing', {}).get('endpoint', None)
        if endpoint is not None:
            params['url'] = _get_new_endpoint(params['url'], endpoint, False)

    def redirect_from_cache(self, params, context, **kwargs):
        """
        This handler retrieves a given bucket's signing context from the cache
        and adds it into the request context.

        Access point requests are left untouched. When nothing is cached for
        the bucket, a signing context holding only the bucket name is set.
        """
        if self._is_s3_accesspoint(context):
            return
        bucket = params.get('Bucket')
        signing_context = self._cache.get(bucket)
        if signing_context is not None:
            context['signing'] = signing_context
        else:
            context['signing'] = {'bucket': bucket}

    def _is_s3_accesspoint(self, context):
        """Return True if the context carries ``s3_accesspoint`` data."""
        return 's3_accesspoint' in context
class InvalidArnException(ValueError):
    """Signals that a string could not be parsed as an ARN."""
class ArnParser(object):
    """Splits ARN strings into their named components."""

    def parse_arn(self, arn):
        """Break an ARN into a dictionary of its parts.

        :param arn: A string of the form
            ``arn:partition:service:region:account:resource``. Colons
            inside the resource portion are preserved.
        :return: A dict with the keys ``partition``, ``service``,
            ``region``, ``account`` and ``resource``.
        :raises InvalidArnException: If fewer than six colon-separated
            components are present.
        """
        # At most five splits so the resource keeps any colons it contains.
        components = arn.split(':', 5)
        if len(components) < 6:
            raise InvalidArnException(
                'Provided ARN: %s must be of the format: '
                'arn:partition:service:region:account:resource' % arn
            )
        field_names = (
            'partition', 'service', 'region', 'account', 'resource',
        )
        # components[0] is the leading "arn" label and is not returned.
        return dict(zip(field_names, components[1:]))
class S3ArnParamHandler(object):
    """Handler that accepts access-point ARNs in the ``Bucket`` parameter.

    When ``Bucket`` parses as an access-point ARN, the parameter is replaced
    by the access-point name and the remaining ARN details are stored on the
    request context under ``s3_accesspoint``.
    """

    _ACCESSPOINT_RESOURCE_REGEX = re.compile(
        r'^accesspoint[/:](?P<resource_name>.+)$'
    )
    # Operations for which the Bucket parameter is never treated as an ARN.
    _BLACKLISTED_OPERATIONS = [
        'CreateBucket'
    ]

    def __init__(self, arn_parser=None):
        if arn_parser is None:
            arn_parser = ArnParser()
        self._arn_parser = arn_parser

    def register(self, event_emitter):
        """Hook ``handle_arn`` into ``before-parameter-build.s3``."""
        event_emitter.register('before-parameter-build.s3', self.handle_arn)

    def handle_arn(self, params, model, context, **kwargs):
        """Rewrite ``params``/``context`` if ``Bucket`` is an ARN.

        :param params: The operation parameters (mutated in place).
        :param model: The operation model (only ``name`` is read).
        :param context: The request context (mutated in place).
        """
        if model.name in self._BLACKLISTED_OPERATIONS:
            return
        arn_details = self._get_arn_details_from_bucket_param(params)
        if arn_details is not None and \
                arn_details['resource_type'] == 'accesspoint':
            self._store_accesspoint(params, context, arn_details)

    def _get_arn_details_from_bucket_param(self, params):
        """Return parsed ARN details, or None if Bucket is not an ARN."""
        if 'Bucket' not in params:
            return None
        try:
            arn = params['Bucket']
            details = self._arn_parser.parse_arn(arn)
            self._add_resource_type_and_name(arn, details)
        except InvalidArnException:
            # Not an ARN at all; treat it as a plain bucket name.
            return None
        return details

    def _add_resource_type_and_name(self, arn, arn_details):
        """Annotate ``arn_details`` with the resource type and name.

        :raises UnsupportedS3ArnError: If the resource is not an
            access point.
        """
        match = self._ACCESSPOINT_RESOURCE_REGEX.match(arn_details['resource'])
        if not match:
            raise UnsupportedS3ArnError(arn=arn)
        arn_details['resource_type'] = 'accesspoint'
        arn_details['resource_name'] = match.group('resource_name')

    def _store_accesspoint(self, params, context, arn_details):
        # Ideally the access-point would be stored as a parameter in the
        # request where the serializer would then know how to serialize it,
        # but access-points are not modeled in S3 operations so it would fail
        # validation. Instead, we set the access-point to the bucket parameter
        # to have some value set when serializing the request and additional
        # information on the context from the arn to use in forming the
        # access-point endpoint.
        accesspoint_name = arn_details['resource_name']
        params['Bucket'] = accesspoint_name
        context['s3_accesspoint'] = {
            'name': accesspoint_name,
            'account': arn_details['account'],
            'partition': arn_details['partition'],
            'region': arn_details['region'],
        }
class S3EndpointSetter(object):
    """Chooses the final S3 endpoint for a request before it is signed.

    Depending on the request context and client configuration the URL is
    switched to an access-point endpoint, an S3 Accelerate endpoint, and/or
    rewritten by an addressing-style handler.

    :param endpoint_resolver: Object providing
        ``construct_endpoint(service, region)``; used to look up the DNS
        suffix for access-point endpoints.
    :param region: The client's configured region.
    :param s3_config: The client's S3 configuration dict (may be ``None``).
    :param endpoint_url: A custom endpoint URL, if one was configured.
    :param partition: The client's partition; defaults to ``aws``.
    """

    _DEFAULT_PARTITION = 'aws'
    _DEFAULT_DNS_SUFFIX = 'amazonaws.com'

    def __init__(self, endpoint_resolver, region=None,
                 s3_config=None, endpoint_url=None, partition=None):
        self._endpoint_resolver = endpoint_resolver
        self._region = region
        self._s3_config = s3_config
        if s3_config is None:
            self._s3_config = {}
        self._endpoint_url = endpoint_url
        self._partition = partition
        if partition is None:
            self._partition = self._DEFAULT_PARTITION

    def register(self, event_emitter):
        """Hook ``set_endpoint`` into the ``before-sign.s3`` event."""
        event_emitter.register('before-sign.s3', self.set_endpoint)

    def set_endpoint(self, request, **kwargs):
        """Rewrite ``request.url`` according to the applicable strategy.

        Access-point requests are handled exclusively; otherwise the
        accelerate switch and the addressing handler are applied in turn.
        """
        if self._use_accesspoint_endpoint(request):
            self._validate_accesspoint_supported(request)
            region_name = self._resolve_region_for_accesspoint_endpoint(
                request)
            self._switch_to_accesspoint_endpoint(request, region_name)
            return
        if self._use_accelerate_endpoint:
            switch_host_s3_accelerate(request=request, **kwargs)
        if self._s3_addressing_handler:
            self._s3_addressing_handler(request=request, **kwargs)

    def _use_accesspoint_endpoint(self, request):
        """Return True if the request context holds access-point data."""
        return 's3_accesspoint' in request.context

    def _validate_accesspoint_supported(self, request):
        """Reject client configurations that cannot be used with an ARN.

        :raises UnsupportedS3AccesspointConfigurationError: If a custom
            endpoint URL or accelerate is configured, or the ARN partition
            differs from the client's partition.
        """
        if self._endpoint_url:
            raise UnsupportedS3AccesspointConfigurationError(
                msg=(
                    'Client cannot use a custom "endpoint_url" when '
                    'specifying an access-point ARN.'
                )
            )
        if self._use_accelerate_endpoint:
            raise UnsupportedS3AccesspointConfigurationError(
                msg=(
                    'Client does not support s3 accelerate configuration '
                    'when an access-point ARN is specified.'
                )
            )
        request_partition = request.context['s3_accesspoint']['partition']
        if request_partition != self._partition:
            raise UnsupportedS3AccesspointConfigurationError(
                msg=(
                    'Client is configured for "%s" partition, but access-point'
                    ' ARN provided is for "%s" partition. The client and '
                    'access-point partition must be the same.' % (
                        self._partition, request_partition)
                )
            )

    def _resolve_region_for_accesspoint_endpoint(self, request):
        """Pick the region used to build the access-point endpoint.

        Uses the ARN's region unless ``use_arn_region`` is disabled in the
        S3 config, in which case the client's region is returned.
        """
        if self._s3_config.get('use_arn_region', True):
            accesspoint_region = request.context['s3_accesspoint']['region']
            # If we are using the region from the access point,
            # we will also want to make sure that we set it as the
            # signing region as well
            self._override_signing_region(request, accesspoint_region)
            return accesspoint_region
        return self._region

    def _switch_to_accesspoint_endpoint(self, request, region_name):
        """Replace ``request.url`` with the access-point endpoint URL."""
        original_components = urlsplit(request.url)
        accesspoint_endpoint = urlunsplit((
            original_components.scheme,
            self._get_accesspoint_netloc(request.context, region_name),
            self._get_accesspoint_path(
                original_components.path, request.context),
            original_components.query,
            ''
        ))
        # Lazy logger arguments: only formatted when debug is enabled.
        logger.debug(
            'Updating URI from %s to %s', request.url, accesspoint_endpoint)
        request.url = accesspoint_endpoint

    def _get_accesspoint_netloc(self, request_context, region_name):
        """Build ``<name>-<account>.s3-accesspoint[.dualstack].<region>.<suffix>``."""
        s3_accesspoint = request_context['s3_accesspoint']
        accesspoint_netloc_components = [
            '%s-%s' % (s3_accesspoint['name'], s3_accesspoint['account']),
            's3-accesspoint'
        ]
        if self._s3_config.get('use_dualstack_endpoint'):
            accesspoint_netloc_components.append('dualstack')
        accesspoint_netloc_components.extend(
            [
                region_name,
                self._get_dns_suffix(region_name)
            ]
        )
        return '.'.join(accesspoint_netloc_components)

    def _get_accesspoint_path(self, original_path, request_context):
        # The Bucket parameter was substituted with the access-point name as
        # some value was required in serializing the bucket name. Now that
        # we are making the request directly to the access point, we will
        # want to remove that access-point name from the path.
        name = request_context['s3_accesspoint']['name']
        # All S3 operations require at least a / in their path.
        return original_path.replace('/' + name, '', 1) or '/'

    def _get_dns_suffix(self, region_name):
        """Return the resolver's DNS suffix, or the default if unavailable."""
        resolved = self._endpoint_resolver.construct_endpoint(
            's3', region_name)
        dns_suffix = self._DEFAULT_DNS_SUFFIX
        if resolved and 'dnsSuffix' in resolved:
            dns_suffix = resolved['dnsSuffix']
        return dns_suffix

    def _override_signing_region(self, request, region_name):
        signing_context = {
            'region': region_name,
        }
        # S3SigV4Auth will use the context['signing']['region'] value to
        # sign with if present. This is used by the Bucket redirector
        # as well but we should be fine because the redirector is never
        # used in combination with the accesspoint setting logic.
        request.context['signing'] = signing_context

    @CachedProperty
    def _use_accelerate_endpoint(self):
        # Enable accelerate if the configuration is set to to true or the
        # endpoint being used matches one of the accelerate endpoints.

        # Accelerate has been explicitly configured.
        if self._s3_config.get('use_accelerate_endpoint'):
            return True

        # Accelerate mode is turned on automatically if an endpoint url is
        # provided that matches the accelerate scheme.
        if self._endpoint_url is None:
            return False

        # Accelerate is only valid for Amazon endpoints. Compare whole DNS
        # labels rather than a raw string suffix so that hosts such as
        # "s3-accelerate.notamazonaws.com" are not mistaken for Amazon's.
        netloc = urlsplit(self._endpoint_url).netloc
        parts = netloc.split('.')
        if parts[-2:] != ['amazonaws', 'com']:
            return False

        # The first part of the url should always be s3-accelerate.
        if parts[0] != 's3-accelerate':
            return False

        # Url parts between 's3-accelerate' and 'amazonaws.com' which
        # represent different url features.
        feature_parts = parts[1:-2]

        # There should be no duplicate url parts.
        if len(feature_parts) != len(set(feature_parts)):
            return False

        # Remaining parts must all be in the whitelist.
        return all(p in S3_ACCELERATE_WHITELIST for p in feature_parts)

    @CachedProperty
    def _addressing_style(self):
        # Use virtual host style addressing if accelerate is enabled or if
        # the given endpoint url is an accelerate endpoint.
        if self._use_accelerate_endpoint:
            return 'virtual'

        # If a particular addressing style is configured, use it.
        configured_addressing_style = self._s3_config.get('addressing_style')
        if configured_addressing_style:
            return configured_addressing_style
        # Nothing configured: no explicit addressing style.
        return None

    @CachedProperty
    def _s3_addressing_handler(self):
        # If virtual host style was configured, use it regardless of whether
        # or not the bucket looks dns compatible.
        if self._addressing_style == 'virtual':
            logger.debug("Using S3 virtual host style addressing.")
            return switch_to_virtual_host_style

        # If path style is configured, no additional steps are needed. If
        # endpoint_url was specified, don't default to virtual. We could
        # potentially default provided endpoint urls to virtual hosted
        # style, but for now it is avoided.
        if self._addressing_style == 'path' or self._endpoint_url is not None:
            logger.debug("Using S3 path style addressing.")
            return None

        logger.debug("Defaulting to S3 virtual host style addressing with "
                     "path style addressing fallback.")

        # By default, try to use virtual style with path fallback.
        return fix_s3_host
class ContainerMetadataFetcher(object):
    """Fetches JSON metadata from a container metadata endpoint over HTTP.

    Requests are retried up to ``RETRY_ATTEMPTS`` times, sleeping
    ``SLEEP_TIME`` seconds between attempts.

    :param session: HTTP session exposing ``send(prepared_request)``. When
        omitted, a ``URLLib3Session`` with a ``TIMEOUT_SECONDS`` timeout is
        created.
    :param sleep: Callable used to wait between retries.
    """

    TIMEOUT_SECONDS = 2
    RETRY_ATTEMPTS = 3
    SLEEP_TIME = 1
    IP_ADDRESS = '169.254.170.2'
    # Hosts that retrieve_full_uri() is willing to contact.
    _ALLOWED_HOSTS = [IP_ADDRESS, 'localhost', '127.0.0.1']

    def __init__(self, session=None, sleep=time.sleep):
        if session is None:
            session = botocore.httpsession.URLLib3Session(
                timeout=self.TIMEOUT_SECONDS
            )
        self._session = session
        self._sleep = sleep

    def retrieve_full_uri(self, full_url, headers=None):
        """Retrieve JSON metadata from container metadata.

        :type full_url: str
        :param full_url: The full URL of the metadata service.
            This should include the scheme as well, e.g
            "http://localhost:123/foo"

        :param headers: Optional extra headers to send with the request.

        :return: The parsed JSON response.
        :raises ValueError: If the URL's host is not an allowed host.
        """
        self._validate_allowed_url(full_url)
        return self._retrieve_credentials(full_url, headers)

    def _validate_allowed_url(self, full_url):
        """Raise ``ValueError`` unless the URL's host is allowed."""
        parsed = botocore.compat.urlparse(full_url)
        is_whitelisted_host = self._check_if_whitelisted_host(
            parsed.hostname)
        if not is_whitelisted_host:
            raise ValueError(
                "Unsupported host '%s'.  Can only "
                "retrieve metadata from these hosts: %s" %
                (parsed.hostname, ', '.join(self._ALLOWED_HOSTS)))

    def _check_if_whitelisted_host(self, host):
        """Return True if ``host`` is one of the allowed hosts."""
        return host in self._ALLOWED_HOSTS

    def retrieve_uri(self, relative_uri):
        """Retrieve JSON metadata from ECS metadata.

        :type relative_uri: str
        :param relative_uri: A relative URI, e.g "/foo/bar?id=123"

        :return: The parsed JSON response.

        """
        full_url = self.full_url(relative_uri)
        return self._retrieve_credentials(full_url)

    def _retrieve_credentials(self, full_url, extra_headers=None):
        """Fetch ``full_url``, retrying on ``MetadataRetrievalError``.

        :raises MetadataRetrievalError: Once ``RETRY_ATTEMPTS`` attempts
            have failed.
        """
        headers = {'Accept': 'application/json'}
        if extra_headers is not None:
            headers.update(extra_headers)
        attempts = 0
        while True:
            try:
                return self._get_response(
                    full_url, headers, self.TIMEOUT_SECONDS)
            except MetadataRetrievalError as e:
                logger.debug("Received error when attempting to retrieve "
                             "container metadata: %s", e, exc_info=True)
                attempts += 1
                if attempts >= self.RETRY_ATTEMPTS:
                    # Out of attempts: re-raise right away instead of
                    # sleeping first, which would only delay the error.
                    raise
                self._sleep(self.SLEEP_TIME)

    def _get_response(self, full_url, headers, timeout):
        """Issue a single GET and return the decoded JSON body.

        NOTE: ``timeout`` is accepted but not used in this method.

        :raises MetadataRetrievalError: On a non-200 status, a body that
            is not valid JSON, or a retryable HTTP error.
        """
        try:
            AWSRequest = botocore.awsrequest.AWSRequest
            request = AWSRequest(method='GET', url=full_url, headers=headers)
            response = self._session.send(request.prepare())
            response_text = response.content.decode('utf-8')
            if response.status_code != 200:
                raise MetadataRetrievalError(
                    error_msg=(
                        "Received non 200 response (%s) from ECS metadata: %s"
                    ) % (response.status_code, response_text))
            try:
                return json.loads(response_text)
            except ValueError:
                error_msg = (
                    "Unable to parse JSON returned from ECS metadata services"
                )
                logger.debug('%s:%s', error_msg, response_text)
                raise MetadataRetrievalError(error_msg=error_msg)
        except RETRYABLE_HTTP_ERRORS as e:
            error_msg = ("Received error when attempting to retrieve "
                         "ECS metadata: %s" % e)
            raise MetadataRetrievalError(error_msg=error_msg)

    def full_url(self, relative_uri):
        """Prefix ``relative_uri`` with ``http://`` and ``IP_ADDRESS``."""
        return 'http://%s%s' % (self.IP_ADDRESS, relative_uri)
def get_environ_proxies(url):
    """Return the proxies to use for ``url``.

    :param url: The URL about to be requested.
    :return: An empty dict when ``should_bypass_proxies`` says the URL must
        not be proxied, otherwise the result of ``getproxies()``.
    """
    if not should_bypass_proxies(url):
        return getproxies()
    return {}
def should_bypass_proxies(url):
    """
    Returns whether we should bypass proxies or not.

    :param url: The URL about to be requested; only its network location
        is passed to ``proxy_bypass``.
    :return: ``True`` if proxies should be bypassed, otherwise ``False``.
    """
    # NOTE: requests allowed for ip/cidr entries in no_proxy env that we don't
    # support current as urllib only checks DNS suffix
    # If the system proxy settings indicate that this URL should be bypassed,
    # don't proxy.
    # The proxy_bypass function is incredibly buggy on OS X in early versions
    # of Python 2.6, so allow this call to fail. Only catch the specific
    # exceptions we've seen, though: this call failing in other ways can reveal
    # legitimate problems.
    try:
        bypass = proxy_bypass(urlparse(url).netloc)
    except (TypeError, socket.gaierror):
        bypass = False
    return bool(bypass)
def get_encoding_from_headers(headers, default='ISO-8859-1'):
    """Returns encodings from given HTTP Header Dict.

    :param headers: dictionary to extract encoding from.
    :param default: default encoding if the content-type is text

    :return: The ``charset`` parameter of the ``content-type`` header with
        surrounding quotes removed; ``default`` if there is no charset but
        the media type contains ``text``; otherwise ``None``.
    """
    # The ``cgi`` module previously used here is deprecated (PEP 594) and
    # removed in Python 3.13, so parse the header with the email package.
    import email.message

    content_type = headers.get('content-type')

    if not content_type:
        return None

    message = email.message.Message()
    message['content-type'] = content_type
    charset = message.get_param('charset')
    if isinstance(charset, tuple):
        # RFC 2231 encoded parameters come back as
        # (charset, language, value); only the value is of interest.
        charset = charset[2]
    if charset is not None:
        return charset.strip("'\"")

    # Only look at the media type itself, not at any parameters.
    media_type = content_type.split(';', 1)[0]
    if 'text' in media_type:
        return default
    return None
def calculate_md5(body, **kwargs):
    """Return the base64-encoded MD5 digest of ``body`` as an ASCII string.

    :param body: Either ``bytes``/``bytearray`` or a file-like object.
    """
    if isinstance(body, (bytes, bytearray)):
        compute_digest = _calculate_md5_from_bytes
    else:
        # Anything that is not bytes-like is treated as a file object.
        compute_digest = _calculate_md5_from_file
    encoded = base64.b64encode(compute_digest(body))
    return encoded.decode('ascii')
def _calculate_md5_from_bytes(body_bytes):
md5 = get_md5(body_bytes)
return md5.digest()
def _calculate_md5_from_file(fileobj):
start_position = fileobj.tell()
md5 = get_md5()
for chunk in iter(lambda: fileobj.read(1024 * 1024), b''):
md5.update(chunk)
fileobj.seek(start_position)
return md5.digest()
def conditionally_calculate_md5(params, **kwargs):
    """Only add a Content-MD5 if the system supports it.

    The header is added to ``params['headers']`` only when MD5 is available,
    the body is non-empty and no ``Content-MD5`` header is already present.
    """
    headers = params['headers']
    body = params['body']
    if not MD5_AVAILABLE:
        return
    if not body or 'Content-MD5' in headers:
        return
    params['headers']['Content-MD5'] = calculate_md5(body, **kwargs)
class FileWebIdentityTokenLoader(object):
    """Callable that reads a web identity token from a file.

    :param web_identity_token_path: Path of the file holding the token.
    :param _open: Callable used to open the path; defaults to ``open``.
    """

    def __init__(self, web_identity_token_path, _open=open):
        self._web_identity_token_path = web_identity_token_path
        self._open = _open

    def __call__(self):
        """Open the token file and return its entire contents."""
        token_path = self._web_identity_token_path
        with self._open(token_path) as token_file:
            token = token_file.read()
        return token
class SSOTokenLoader(object):
    """Callable that looks up a cached SSO access token by start URL.

    :param cache: Mapping from cache key to token dict. A new empty dict
        is used when omitted.
    """

    def __init__(self, cache=None):
        self._cache = {} if cache is None else cache

    def _generate_cache_key(self, start_url):
        """Return the SHA-1 hex digest of the UTF-8 encoded start URL."""
        encoded_url = start_url.encode('utf-8')
        return hashlib.sha1(encoded_url).hexdigest()

    def __call__(self, start_url):
        """Return the ``accessToken`` cached for ``start_url``.

        :raises SSOTokenLoadError: If no cache entry exists for the URL or
            the entry has no ``accessToken`` key.
        """
        cache_key = self._generate_cache_key(start_url)
        try:
            return self._cache[cache_key]['accessToken']
        except KeyError:
            logger.debug('Failed to load SSO token:', exc_info=True)
            error_msg = (
                'The SSO access token has either expired or is otherwise '
                'invalid.'
            )
            raise SSOTokenLoadError(error_msg=error_msg)