From a786f7d915a81772072d3a6d744edea52f559449 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 15 Dec 2020 11:37:36 -0600 Subject: [PATCH] Add support for headers on Response and Exceptions --- elastic_transport/_version.py | 2 +- elastic_transport/compat.py | 7 ++ elastic_transport/connection/base.py | 6 +- elastic_transport/connection/http_requests.py | 4 +- elastic_transport/connection/http_urllib3.py | 27 ++---- elastic_transport/exceptions.py | 11 ++- elastic_transport/response.py | 70 ++++++++++++++-- elastic_transport/transport.py | 16 ++-- setup.py | 4 +- tests/test_connection.py | 3 + tests/test_httpbin.py | 82 +++++++++++++++++++ tests/test_response.py | 46 ++++++++++- 12 files changed, 237 insertions(+), 41 deletions(-) create mode 100644 tests/test_httpbin.py diff --git a/elastic_transport/_version.py b/elastic_transport/_version.py index 887f65b..1d005ac 100644 --- a/elastic_transport/_version.py +++ b/elastic_transport/_version.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. -__version__ = "0.1.0b0" +__version__ = "7.11.0" diff --git a/elastic_transport/compat.py b/elastic_transport/compat.py index 7ebb9ae..0e88e3d 100644 --- a/elastic_transport/compat.py +++ b/elastic_transport/compat.py @@ -22,8 +22,15 @@ except NameError: string_types = (str, bytes) +try: + from collections.abc import Mapping, MutableMapping +except ImportError: + from collections import Mapping, MutableMapping + __all__ = [ "urlparse", "urlencode", "string_types", + "Mapping", + "MutableMapping", ] diff --git a/elastic_transport/connection/base.py b/elastic_transport/connection/base.py index c0b10c0..2a18830 100644 --- a/elastic_transport/connection/base.py +++ b/elastic_transport/connection/base.py @@ -180,7 +180,7 @@ def log_request_fail( if response is not None: logger.debug("< %s", response) - def _raise_error(self, status, raw_data): + def _raise_error(self, status, headers, raw_data): """Locate appropriate exception and raise it. Attempts to decode the raw data as JSON for better usability. """ @@ -188,7 +188,9 @@ def _raise_error(self, status, raw_data): raw_data = json.loads(six.ensure_str(raw_data, "utf-8", "ignore")) except Exception: pass - raise HTTP_EXCEPTIONS.get(status, APIError)(message=raw_data, status=status) + raise HTTP_EXCEPTIONS.get(status, APIError)( + message=raw_data, status=status, headers=headers + ) def _gzip_compress(self, body): buf = io.BytesIO() diff --git a/elastic_transport/connection/http_requests.py b/elastic_transport/connection/http_requests.py index 894d2e9..108530b 100644 --- a/elastic_transport/connection/http_requests.py +++ b/elastic_transport/connection/http_requests.py @@ -176,7 +176,9 @@ def perform_request( status=response.status_code, response=raw_data, ) - self._raise_error(response.status_code, raw_data) + self._raise_error( + status=response.status_code, headers=response.headers, raw_data=raw_data + ) self.log_request_success( method=method, diff --git a/elastic_transport/connection/http_urllib3.py b/elastic_transport/connection/http_urllib3.py index 9b4b7ec..87dbb18 100644 --- a/elastic_transport/connection/http_urllib3.py +++ b/elastic_transport/connection/http_urllib3.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import ssl import time import warnings @@ -26,7 +25,7 @@ from ..compat import urlencode from ..exceptions import ConnectionError, ConnectionTimeout -from ..utils import DEFAULT +from ..utils import DEFAULT, normalize_headers from .base import Connection CA_CERTS = None @@ -39,25 +38,13 @@ pass -def create_ssl_context(**kwargs): - """ - A helper function around creating an SSL context - - https://docs.python.org/3/library/ssl.html#context-creation - - Accepts kwargs in the same manner as `create_default_context`. - """ - ctx = ssl.create_default_context(**kwargs) - return ctx - - class Urllib3HttpConnection(Connection): """ Default connection class using the `urllib3` library and the http protocol. :arg host: hostname of the node (default: localhost) - :arg port: port to use (integer, default: 9200) - :arg url_prefix: optional url prefix for elasticsearch + :arg port: port to use (integer) + :arg url_prefix: optional url prefix :arg timeout: default timeout in seconds (float, default: 10) :arg use_ssl: use ssl for the connection if `True` :arg verify_certs: whether to verify SSL certificates @@ -223,6 +210,7 @@ def perform_request( request_headers = self.headers.copy() request_headers.update(headers or ()) + request_headers = normalize_headers(request_headers) if self.http_compress and body: body = self._gzip_compress(body) @@ -236,6 +224,7 @@ def perform_request( headers=request_headers, **kw ) + response_headers = dict(response.headers) duration = time.time() - start raw_data = response.data.decode("utf-8", "surrogatepass") except Exception as e: @@ -262,7 +251,9 @@ def perform_request( status=response.status, response=raw_data, ) - self._raise_error(response.status, raw_data) + self._raise_error( + status=response.status, headers=response_headers, raw_data=raw_data + ) self.log_request_success( method=method, @@ -273,7 +264,7 @@ def perform_request( duration=duration, ) - return response.status, response.getheaders(), raw_data + return response.status, response_headers, raw_data def close(self): """ diff --git a/elastic_transport/exceptions.py b/elastic_transport/exceptions.py index 28b0681..90efb34 100644 --- a/elastic_transport/exceptions.py +++ b/elastic_transport/exceptions.py @@ -17,6 +17,8 @@ from six import add_metaclass, python_2_unicode_compatible +from .response import Headers + HTTP_EXCEPTIONS = {} @@ -38,17 +40,22 @@ class TransportError(Exception): most recently raised (index=0) to least recently raised (index=N) If an HTTP status code is available with the error it - will be stored under 'status'. + will be stored under 'status'. If HTTP headers are available + they are stored under 'headers'. """ status = None - def __init__(self, message, errors=(), status=None): + def __init__(self, message, errors=(), status=None, headers=None): super(TransportError, self).__init__(message) self.errors = tuple(errors) self.message = message if status is not None: self.status = status + if headers is not None: + self.headers = Headers(headers) + else: + self.headers = None def __repr__(self): parts = [repr(self.message)] diff --git a/elastic_transport/response.py b/elastic_transport/response.py index 35be86e..435e222 100644 --- a/elastic_transport/response.py +++ b/elastic_transport/response.py @@ -15,13 +15,71 @@ # specific language governing permissions and limitations # under the License. +from .compat import Mapping + + +class Headers(Mapping): + """HTTP headers""" + + def __init__(self, initial=None): + self._internal = {} + if initial: + for key, val in dict(initial).items(): + self._internal[self._normalize_key(key)] = (key, val) + + def __getitem__(self, item): + return self._internal[self._normalize_key(item)][1] + + def __eq__(self, other): + if isinstance(other, Mapping): + return dict(self.items()) == dict(other.items()) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Mapping): + return dict(self.items()) != dict(other.items()) + return NotImplemented + + def __iter__(self): + return iter(self.keys()) + + def __len__(self): + return len(self._internal) + + def __contains__(self, item): + return self._normalize_key(item) in self._internal + + def __repr__(self): + return repr(dict(self.items())) + + def __str__(self): + return str(dict(self.items())) + + def get(self, key, default=None): + return self._internal.get(self._normalize_key(key), (None, default))[1] + + def keys(self): + return [key for _, (key, _) in self._internal.items()] + + def values(self): + return [val for _, (_, val) in self._internal.items()] + + def items(self): + return [(key, val) for _, (key, val) in self._internal.items()] + + def copy(self): + return dict(self.items()) + + def _normalize_key(self, key): + return key.lower() if hasattr(key, "lower") else key + class Response(object): """HTTP response""" - def __init__(self, headers, status, body): - self.headers = headers + def __init__(self, status, headers, body): self.status = status + self.headers = Headers(headers) self.body = body def __repr__(self): @@ -67,12 +125,12 @@ def __ne__(self, other): class DictResponse(Response, dict): - def __init__(self, headers, status, body): - Response.__init__(self, headers, status, body) + def __init__(self, status, headers, body): + Response.__init__(self, status=status, headers=headers, body=body) dict.__init__(self, body) class ListResponse(Response, list): - def __init__(self, headers, status, body): - Response.__init__(self, headers, status, body) + def __init__(self, status, headers, body): + Response.__init__(self, status=status, headers=headers, body=body) list.__init__(self, body) diff --git a/elastic_transport/transport.py b/elastic_transport/transport.py index 7518ffc..e08bc69 100644 --- a/elastic_transport/transport.py +++ b/elastic_transport/transport.py @@ -240,7 +240,7 @@ def perform_request( connection = self.get_connection() try: - status, headers_response, data = connection.perform_request( + resp_status, resp_headers, data = connection.perform_request( method, path, params, @@ -253,7 +253,7 @@ def perform_request( if method == "HEAD" and e.status == 404: return Response( status=404, - headers={}, + headers=e.headers, body=False, ) @@ -289,14 +289,14 @@ def perform_request( if method == "HEAD": return Response( - status=status, - headers=headers_response, - body=200 <= status < 300, + status=resp_status, + headers=resp_headers, + body=200 <= resp_status < 300, ) if data: data = self.deserializer.loads( - data, headers_response.get("content-type") + data, resp_headers.get("content-type") ) # After the body is deserialized put the data @@ -307,8 +307,8 @@ def perform_request( elif isinstance(data, dict): response_cls = DictResponse return response_cls( - status=status, - headers=headers_response, + status=resp_status, + headers=resp_headers, body=data, ) diff --git a/setup.py b/setup.py index cab33c3..d472643 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ }, packages=packages, install_requires=[ - "urllib3>=1.21.1", + "urllib3>=1.21.1, <2", "six>=1.12", "certifi", ], @@ -57,7 +57,7 @@ "develop": ["pytest", "pytest-cov", "pytest-mock", "mock", "requests"], }, classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Developers", "Operating System :: OS Independent", diff --git a/tests/test_connection.py b/tests/test_connection.py index 6147749..87a070a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -140,6 +140,7 @@ def test_timeout_is_10_seconds_by_default(self): with patch.object(conn.pool, "urlopen") as pool_urlopen: resp = Mock() resp.status = 200 + resp.headers = {} pool_urlopen.return_value = resp conn.perform_request("GET", "/") @@ -158,6 +159,7 @@ def test_timeout_override_default(self, request_timeout): with patch.object(conn.pool, "urlopen") as pool_urlopen: resp = Mock() resp.status = 200 + resp.headers = {} pool_urlopen.return_value = resp conn.perform_request("GET", "/", request_timeout=request_timeout) @@ -244,6 +246,7 @@ def test_failed_request_logs(self, logger): resp = Mock() resp.data = b'{"answer":42}' resp.status = 500 + resp.headers = {} pool_urlopen.return_value = resp with pytest.raises(TransportError) as e: diff --git a/tests/test_httpbin.py b/tests/test_httpbin.py new file mode 100644 index 0000000..0141c76 --- /dev/null +++ b/tests/test_httpbin.py @@ -0,0 +1,82 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from elastic_transport import NotFoundError, Transport + + +@pytest.mark.parametrize("connection_class", ["urllib3", "requests"]) +def test_simple_request(connection_class): + t = Transport("https://httpbin.org", connection_class=connection_class) + resp = t.perform_request( + "GET", + "/anything", + headers={"Custom": "headeR"}, + params={"Query": "String"}, + body={"JSON": "body"}, + ) + assert resp.status == 200 + assert resp["method"] == "GET" + assert resp["url"] == "https://httpbin.org/anything?Query=String" + assert resp["args"] == {"Query": "String"} + assert resp["data"] == '{"JSON":"body"}' + assert resp["json"] == {"JSON": "body"} + + request_headers = { + "Content-Type": "application/json", + "Content-Length": "15", + "Custom": "headeR", + "Host": "httpbin.org", + } + assert all(v == resp["headers"][k] for k, v in request_headers.items()) + + assert resp.headers["content-type"] == "application/json" + assert resp.headers["Content-Type"] == "application/json" + + +@pytest.mark.parametrize("connection_class", ["urllib3", "requests"]) +@pytest.mark.parametrize("status", [200, 404]) +def test_head_request(connection_class, status): + t = Transport("https://httpbin.org", connection_class=connection_class) + resp = t.perform_request( + "HEAD", + "/status/%d" % status, + headers={"Custom": "headeR"}, + params={"Query": "String"}, + body={"JSON": "body"}, + ) + assert resp.status == status + assert bool(resp) is (status == 200) + + +@pytest.mark.parametrize("connection_class", ["urllib3", "requests"]) +def test_get_404_request(connection_class): + t = Transport("https://httpbin.org", connection_class=connection_class) + with pytest.raises(NotFoundError) as e: + t.perform_request( + "GET", + "/status/404", + headers={"Custom": "headeR"}, + params={"Query": "String"}, + body={"JSON": "body"}, + ) + + resp = e.value + assert resp.status == 404 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.headers["Content-Type"] == "text/html; charset=utf-8" diff --git a/tests/test_response.py b/tests/test_response.py index 21e67f3..ed5a52e 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -19,7 +19,7 @@ import pytest -from elastic_transport.response import DictResponse, ListResponse, Response +from elastic_transport.response import DictResponse, Headers, ListResponse, Response resp_dict = DictResponse(status=200, headers={}, body={"key": "val"}) resp_list = ListResponse(status=404, headers={}, body=["a", 2, 3, {"k": "v"}]) @@ -47,6 +47,7 @@ def test_response_not_equals(): def test_response_attributes(): assert resp_bool.status == 200 assert resp_bool.body is False + assert resp_bool.headers == {} def test_response_len(): @@ -109,3 +110,46 @@ def test_response_json(): def test_response_instance_checks(): assert isinstance(resp_list, list) assert isinstance(resp_dict, dict) + + +def test_headers_empty(): + assert isinstance(resp_dict.headers, Headers) + + headers = Headers() + + assert len(headers) == 0 + with pytest.raises(KeyError) as e: + headers["Content-Length"] + assert str(e.value) == "'content-length'" + with pytest.raises(KeyError) as e: + headers[None] + assert str(e.value) == "None" + assert list(headers.items()) == [] + + +@pytest.mark.parametrize( + "initial", [{"Content-Length": "0"}, [("Content-Length", "0")]] +) +@pytest.mark.parametrize( + "get_key", ["content-length", "Content-Length", "cOntent-length"] +) +def test_headers_with_items(initial, get_key): + headers = Headers(initial) + + assert len(headers) == 1 + assert list(headers.keys()) == ["Content-Length"] + assert list(headers.values()) == ["0"] + assert list(headers.items()) == [("Content-Length", "0")] + assert headers.get(get_key) == "0" + assert headers[get_key] == "0" + + other_headers = dict(initial) + assert headers == other_headers + assert headers == Headers(other_headers) + + other_headers["content-type"] = "application/json" + assert headers != other_headers + assert headers != Headers(other_headers) + + assert isinstance(headers.copy(), dict) + assert headers.copy() == headers