diff --git a/google/auth/_default.py b/google/auth/_default.py index cedd8a99a..8dd8ac3b0 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -112,7 +112,9 @@ def _get_gae_credentials(): def _get_gce_credentials(): - if _metadata.ping(): + # TODO: Ping now requires a request argument. Figure out how to deal with + # that. + if _metadata.ping(request=None): return compute_engine.Credentials() diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 8768adfa6..4c0104a8c 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -25,7 +25,7 @@ from six.moves.urllib import parse as urlparse from google.auth import _helpers -from google.auth import transport +from google.auth import exceptions _METADATA_ROOT = 'http://metadata.google.internal/computeMetadata/v1/' @@ -42,10 +42,12 @@ _METADATA_DEFAULT_TIMEOUT = 3 -def ping(timeout=_METADATA_DEFAULT_TIMEOUT): +def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT): """Checks to see if the metadata server is available. Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. timeout (int): How long to wait for the metadata server to respond. Returns: @@ -58,8 +60,8 @@ def ping(timeout=_METADATA_DEFAULT_TIMEOUT): # the metadata resolution was particularly slow. The latter case is # "unlikely". try: - response = transport.request( - None, 'GET', _METADATA_IP_ROOT, headers=_METADATA_HEADERS, + response = request( + 'GET', _METADATA_IP_ROOT, headers=_METADATA_HEADERS, timeout=timeout, retries=False) return response.status == http_client.OK @@ -70,11 +72,12 @@ def ping(timeout=_METADATA_DEFAULT_TIMEOUT): return False -def get(http, path, root=_METADATA_ROOT, recursive=None): +def get(request, path, root=_METADATA_ROOT, recursive=None): """Fetch a resource from the metadata server. Args: - http (Any): The transport HTTP object. + request (google.auth.transport.Request): A callable used to make + HTTP requests. path (str): The resource to retrieve. For example, ``'instance/service-accounts/defualt'``. root (str): The full path to the metadata server root. @@ -88,13 +91,13 @@ def get(http, path, root=_METADATA_ROOT, recursive=None): returned as a string. Raises: - http_client.HTTPException: if an error occurred while retrieving - metadata. + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. """ url = urlparse.urljoin(root, path) url = _helpers.update_query(url, {'recursive': recursive}) - response = transport.request(http, 'GET', url, headers=_METADATA_HEADERS) + response = request('GET', url, headers=_METADATA_HEADERS) if response.status == http_client.OK: content = _helpers.from_bytes(response.data) @@ -103,17 +106,18 @@ def get(http, path, root=_METADATA_ROOT, recursive=None): else: return content else: - raise http_client.HTTPException( + raise exceptions.TransportError( 'Failed to retrieve {} from the Google Compute Engine' 'metadata service. Status: {} Response:\n{}'.format( - url, response.status, response.data)) + url, response.status, response.data), response) -def get_service_account_info(http, service_account='default'): +def get_service_account_info(request, service_account='default'): """Get information about a service account from the metadata server. Args: - http (Any): The transport HTTP object. + request (google.auth.transport.Request): A callable used to make + HTTP requests. service_account (str): The string 'default' or a service account email address. The determines which service account for which to acquire information. @@ -128,20 +132,21 @@ def get_service_account_info(http, service_account='default'): } Raises: - http_client.HTTPException: if an error occurred while retrieving - metadata. + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. """ return get( - http, + request, 'instance/service-accounts/{0}/'.format(service_account), recursive=True) -def get_service_account_token(http, service_account='default'): +def get_service_account_token(request, service_account='default'): """Get the OAuth 2.0 access token for a service account. Args: - http (Any): The transport HTTP object. + request (google.auth.transport.Request): A callable used to make + HTTP requests. service_account (str): The string 'default' or a service account email address. The determines which service account for which to acquire an access token. @@ -150,11 +155,11 @@ def get_service_account_token(http, service_account='default'): Union[str, datetime]: The access token and its expiration. Raises: - http_client.HTTPException: if an error occurred while retrieving - metadata. + google.auth.exceptions.TransportError: if an error occurred while + retrieving metadata. """ token_json = get( - http, + request, 'instance/service-accounts/{0}/token'.format(service_account)) token_expiry = _helpers.now() + datetime.timedelta( seconds=token_json['expires_in']) diff --git a/google/auth/credentials.py b/google/auth/credentials.py index d001665d8..c28fb0c16 100644 --- a/google/auth/credentials.py +++ b/google/auth/credentials.py @@ -72,11 +72,12 @@ def expired(self): return True @abc.abstractmethod - def refresh(self, http): + def refresh(self, request): """Refreshes the access token. Args: - http (Any): The transport http object. + request (google.auth.transport.Request): A callable used to make + HTTP requests. Raises: google.auth.exceptions.RefreshError: If the credentials could @@ -97,14 +98,15 @@ def apply(self, headers, token=None): headers[b'authorization'] = 'Bearer {}'.format( _helpers.from_bytes(token or self.token)) - def before_request(self, http, method, url, headers): + def before_request(self, request, method, url, headers): """Performs credential-specific before request logic. Refreshes the credentials if necessary, then calls :meth:`apply` to apply the token to the authentication header. Args: - http (Any): The transport HTTP object. + request (google.auth.transport.Request): A callable used to make + HTTP requests. method (str): The request's HTTP method. url (str): The request's URI. headers (Mapping): The request's headers. @@ -113,7 +115,7 @@ def before_request(self, http, method, url, headers): # (Subclasses may use these arguments to ascertain information about # the http request.) if not self.valid: - self.refresh(http) + self.refresh(request) self.apply(headers) diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py index 6ce029557..de6266c71 100644 --- a/google/auth/exceptions.py +++ b/google/auth/exceptions.py @@ -20,6 +20,11 @@ class GoogleAuthError(Exception): pass +class TransportError(Exception): + """Used to indicate an error occurred during an HTTP request.""" + pass + + class RefreshError(GoogleAuthError): """Used to indicate that an error occurred while refreshing the credentials' access token.""" diff --git a/google/auth/jwt.py b/google/auth/jwt.py index a6d9438e5..5fdc6e5fe 100644 --- a/google/auth/jwt.py +++ b/google/auth/jwt.py @@ -392,11 +392,11 @@ def _make_one_time_jwt(self, uri): token, _ = self._make_jwt(audience=audience) return token - def refresh(self, http): + def refresh(self, request): """Refreshes the access token. Args: - http (Any): The transport http object. + request (Any): Unused. """ # pylint: disable=unused-argument # (pylint doens't correctly recognize overriden methods.) @@ -414,7 +414,7 @@ def sign_bytes(self, message): """ return self._signer.sign(message) - def before_request(self, http, method, url, headers): + def before_request(self, request, method, url, headers): """Performs credential-specific before request logic. If an audience is specified it will refresh the credentials if @@ -423,7 +423,7 @@ def before_request(self, http, method, url, headers): authorization header in headers to the token. Args: - http (Any): The transport http object. + request (Any): Unused. method (str): The request's HTTP method. url (str): The request's URI. headers (Mapping): The request's headers. @@ -435,7 +435,7 @@ def before_request(self, http, method, url, headers): # there is a valid token and apply the auth headers. if self._audience: if not self.valid: - self.refresh(http) + self.refresh(request) self.apply(headers) # Otherwise, generate a one-time token using the URL # (without the query string and fragement) as the audience. diff --git a/google/auth/service_account.py b/google/auth/service_account.py index 6a80c4921..70243810a 100644 --- a/google/auth/service_account.py +++ b/google/auth/service_account.py @@ -81,7 +81,6 @@ from google.auth import credentials from google.auth import exceptions from google.auth import jwt -from google.auth import transport _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in sections _JWT_TOKEN_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:jwt-bearer' @@ -257,7 +256,7 @@ def _make_authorization_grant_assertion(self): return token @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, http): + def refresh(self, request): assertion = self._make_authorization_grant_assertion() body = urllib.parse.urlencode({ @@ -269,9 +268,8 @@ def refresh(self, http): 'content-type': 'application/x-www-form-urlencoded', } - response = transport.request( - http, method='POST', url=self._token_uri, headers=headers, - body=body) + response = request( + url=self._token_uri, method='POST', headers=headers, body=body) if response.status != http_client.OK: # Try to decode the response and extract details. diff --git a/google/auth/transport.py b/google/auth/transport.py deleted file mode 100644 index c5976fc5e..000000000 --- a/google/auth/transport.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2016 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Temporary tansport module.""" - -import urllib3 - - -def _default_http(): - return urllib3.PoolManager() - - -def request(http, *args, **kwargs): - """Make a request using the transport.""" - if http is None: - http = _default_http() - - return http.request(*args, **kwargs) diff --git a/google/auth/transport/__init__.py b/google/auth/transport/__init__.py new file mode 100644 index 000000000..e43147a77 --- /dev/null +++ b/google/auth/transport/__init__.py @@ -0,0 +1,82 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport - HTTP client library support. + +:mod:`google.auth` is designed to work with various HTTP client libraries such +as urllib3 and requests. In order to work across these libraries with different +interfaces some abstraction is needed. + +This module provides two interfaces that are implemented by transport adapters +to support HTTP libraries. :class:`Request` defines the interface expected by +:mod:`google.auth` to make requests. :class:`Response` defines the interface +for the return value of :class:`Request`. +""" + +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class Response(object): + """HTTP Response data.""" + + @abc.abstractproperty + def status(self): + """int: The HTTP status code.""" + + @abc.abstractproperty + def headers(self): + """Mapping: The HTTP response headers.""" + + @abc.abstractproperty + def data(self): + """bytes: The response body.""" + + +@six.add_metaclass(abc.ABCMeta) +class Request(object): + """Interface for a callable that makes HTTP requests. + + Specific transport implementations should provide an implementation of + this that adapts their specific request / response API. + """ + + @abc.abstractmethod + def __call__(self, url, method='GET', body=None, headers=None, + timeout=None, **kwargs): + """Make an HTTP request. + + Args: + url (str): The URI to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (bytes): The payload / body in HTTP request. + headers (Mapping): Request headers. + timeout (Optional(int)): The number of seconds to wait for a + response from the server. If not specified or if None, the + transport-specific default timeout will be used. + kwargs: Additionally arguments passed on to the transport's + request method. + + Returns: + Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + # pylint: disable=redundant-returns-doc, missing-raises-doc + # (pylint doesn't play well with abstract docstrings.) + raise NotImplementedError('__call__ must be implemented.') diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py new file mode 100644 index 000000000..51d8af499 --- /dev/null +++ b/google/auth/transport/urllib3.py @@ -0,0 +1,67 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport adapter for urllib3.""" + +from __future__ import absolute_import + +import urllib3 +import urllib3.exceptions + +from google.auth import exceptions +from google.auth import transport + + +class Request(transport.Request): + """Urllib3 request adapter + + Args: + http (urllib3.requests.RequestMethods): An instance of any urllib3 + class that implements :cls:`~urllib3.requests.RequestMethods`, + usually :cls:`urllib3.PoolManager`. + """ + def __init__(self, http): + self.http = http + + def __call__(self, url, method='GET', body=None, headers=None, + timeout=None, **kwargs): + """ + Args: + url (str): The URI to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (bytes): The payload / body in HTTP request. + headers (Mapping): Request headers. + timeout (Optional(int)): The number of seconds to wait for a + response from the server. If not specified or if None, the + urllib3 default timeout will be used. + kwargs: Additional arguments passed throught to the underlying + urllib3 :meth:`urlopen` method. + + Returns: + Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + # Urllib3 uses a sentinel default value for timeout, so only set it if + # specified. + if timeout is not None: + kwargs['timeout'] = timeout + + try: + return self.http.request( + method, url, body=body, headers=headers, **kwargs) + except urllib3.exceptions.HTTPError as exc: + raise exceptions.TransportError(exc) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index 24964de69..2b49feee8 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -20,32 +20,33 @@ from six.moves import http_client from google.auth import _helpers +from google.auth import exceptions from google.auth.compute_engine import _metadata -HTTP_OBJECT = mock.Mock() PATH = 'instance/service-accounts/default' @pytest.fixture def mock_request(): - with mock.patch('google.auth.transport.request') as request_mock: - def set_response(data, status=http_client.OK, headers=None): - response = mock.Mock() - response.status = status - response.data = _helpers.to_bytes(data) - response.headers = headers or {} - request_mock.return_value = response - return request_mock - yield set_response + request_mock = mock.Mock() + + def set_response(data, status=http_client.OK, headers=None): + response = mock.Mock() + response.status = status + response.data = _helpers.to_bytes(data) + response.headers = headers or {} + request_mock.return_value = response + return request_mock + + yield set_response def test_ping_success(mock_request): request_mock = mock_request('') - assert _metadata.ping() + assert _metadata.ping(request_mock) request_mock.assert_called_once_with( - None, 'GET', _metadata._METADATA_IP_ROOT, headers=_metadata._METADATA_HEADERS, @@ -57,7 +58,7 @@ def test_ping_failure(mock_request): request_mock = mock_request('') request_mock.side_effect = Exception() - assert not _metadata.ping() + assert not _metadata.ping(request_mock) def test_get_success_json(mock_request): @@ -65,10 +66,9 @@ def test_get_success_json(mock_request): request_mock = mock_request( data, headers={'content-type': 'application/json'}) - result = _metadata.get(HTTP_OBJECT, PATH) + result = _metadata.get(request_mock, PATH) request_mock.assert_called_once_with( - HTTP_OBJECT, 'GET', _metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS) @@ -79,10 +79,9 @@ def test_get_success_text(mock_request): data = 'foobar' request_mock = mock_request(data, headers={'content-type': 'text/plain'}) - result = _metadata.get(HTTP_OBJECT, PATH) + result = _metadata.get(request_mock, PATH) request_mock.assert_called_once_with( - HTTP_OBJECT, 'GET', _metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS) @@ -93,13 +92,12 @@ def test_get_failure(mock_request): request_mock = mock_request( 'Metadata error', status=http_client.NOT_FOUND) - with pytest.raises(http_client.HTTPException) as excinfo: - _metadata.get(HTTP_OBJECT, PATH) + with pytest.raises(exceptions.TransportError) as excinfo: + _metadata.get(request_mock, PATH) assert excinfo.match(r'Metadata error') request_mock.assert_called_once_with( - HTTP_OBJECT, 'GET', _metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS) @@ -111,10 +109,9 @@ def test_get_service_account_token(now, mock_request): json.dumps({'access_token': 'token', 'expires_in': 500}), headers={'content-type': 'application/json'}) - token, expiry = _metadata.get_service_account_token(HTTP_OBJECT) + token, expiry = _metadata.get_service_account_token(request_mock) request_mock.assert_called_once_with( - HTTP_OBJECT, 'GET', _metadata._METADATA_ROOT + PATH + '/token', headers=_metadata._METADATA_HEADERS) @@ -127,10 +124,9 @@ def test_get_service_account_info(mock_request): json.dumps({'foo': 'bar'}), headers={'content-type': 'application/json'}) - info = _metadata.get_service_account_info(HTTP_OBJECT) + info = _metadata.get_service_account_info(request_mock) request_mock.assert_called_once_with( - HTTP_OBJECT, 'GET', _metadata._METADATA_ROOT + PATH + '/?recursive=True', headers=_metadata._METADATA_HEADERS) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 7e328d3c4..bfa6d9e6e 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -159,18 +159,17 @@ def test__make_authorization_grant_assertion_subject(self): payload = jwt.decode(token, PUBLIC_CERT_BYTES) assert payload['sub'] == 'user@example.com' - @mock.patch('google.auth.transport.request') - def test_refresh_success(self, request_mock): + def test_refresh_success(self): response = mock.Mock() response.status = http_client.OK response.data = json.dumps({ 'access_token': 'token', 'expires_in': 500 }).encode('utf-8') - request_mock.return_value = response + request_mock = mock.Mock(return_value=response) # Refresh credentials - self.credentials.refresh(None) + self.credentials.refresh(request_mock) # Check request data assert request_mock.called @@ -191,49 +190,46 @@ def test_refresh_success(self, request_mock): # expired) assert self.credentials.valid - @mock.patch('google.auth.transport.request') - def test_refresh_error(self, request_mock): + def test_refresh_error(self): response = mock.Mock() response.status = http_client.BAD_REQUEST response.data = json.dumps({ 'error': 'error', 'error_description': 'error description' }).encode('utf-8') - request_mock.return_value = response + request_mock = mock.Mock(return_value=response) with pytest.raises(exceptions.RefreshError) as excinfo: - self.credentials.refresh(None) + self.credentials.refresh(request_mock) assert excinfo.match(r'error: error description') - @mock.patch('google.auth.transport.request') - def test_refresh_error_non_json(self, request_mock): + def test_refresh_error_non_json(self): response = mock.Mock() response.status = http_client.BAD_REQUEST response.data = 'non-json error'.encode('utf-8') - request_mock.return_value = response + request_mock = mock.Mock(return_value=response) with pytest.raises(exceptions.RefreshError) as excinfo: - self.credentials.refresh(None) + self.credentials.refresh(request_mock) assert excinfo.match(r'non-json error') - @mock.patch('google.auth.transport.request') - def test_before_request_refreshes(self, request_mock): + def test_before_request_refreshes(self): response = mock.Mock() response.status = http_client.OK response.data = json.dumps({ 'access_token': 'token', 'expires_in': 500 }).encode('utf-8') - request_mock.return_value = response + request_mock = mock.Mock(return_value=response) # Credentials should start as invalid assert not self.credentials.valid # before_request should cause a refresh self.credentials.before_request( - mock.Mock(), 'GET', 'http://example.com?a=1#3', {}) + request_mock, 'GET', 'http://example.com?a=1#3', {}) # The refresh endpoint should've been called. assert request_mock.called diff --git a/tests/test_transport.py b/tests/test_transport.py deleted file mode 100644 index 72f29c0b2..000000000 --- a/tests/test_transport.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import mock -import urllib3 - -from google.auth import transport - - -def test__default_http(): - http = transport._default_http() - assert isinstance(http, urllib3.PoolManager) - - -def test_request(): - http = mock.Mock() - transport.request(http, 'a', b='c') - http.request.assert_called_with('a', b='c') - - -@mock.patch('google.auth.transport._default_http') -def test_request_no_http(default_http_mock): - http_mock = mock.Mock() - default_http_mock.return_value = http_mock - - transport.request(None, 'a', b='c') - - http_mock.request.assert_called_with('a', b='c')