diff --git a/elastic_transport/_node/_base.py b/elastic_transport/_node/_base.py index d26f54a..d939ded 100644 --- a/elastic_transport/_node/_base.py +++ b/elastic_transport/_node/_base.py @@ -22,6 +22,7 @@ from typing import Any, ClassVar, List, NamedTuple, Optional, Tuple, Union from .._models import ApiResponseMeta, HttpHeaders, NodeConfig +from .._utils import is_ipaddress from .._version import __version__ from ..client_utils import DEFAULT, DefaultType @@ -295,7 +296,7 @@ def ssl_context_from_node_config(node_config: NodeConfig) -> ssl.SSLContext: # step if the user doesn't pass a preconfigured SSLContext. if node_config.verify_certs: ctx.verify_mode = ssl.CERT_REQUIRED - ctx.check_hostname = True + ctx.check_hostname = not is_ipaddress(node_config.host) else: ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE diff --git a/elastic_transport/_utils.py b/elastic_transport/_utils.py index cc45b1d..7389651 100644 --- a/elastic_transport/_utils.py +++ b/elastic_transport/_utils.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +import re +from typing import Any, Dict, Union def fixup_module_metadata(module_name: str, namespace: Dict[str, Any]) -> None: @@ -31,3 +32,48 @@ def fix_one(obj: Any) -> None: for objname in namespace["__all__"]: obj = namespace[objname] fix_one(obj) + + +IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" +IPV4_RE = re.compile("^" + IPV4_PAT + "$") + +HEX_PAT = "[0-9A-Fa-f]{1,4}" +LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT) +_subs = {"hex": HEX_PAT, "ls32": LS32_PAT} +_variations = [ + # 6( h16 ":" ) ls32 + "(?:%(hex)s:){6}%(ls32)s", + # "::" 5( h16 ":" ) ls32 + "::(?:%(hex)s:){5}%(ls32)s", + # [ h16 ] "::" 4( h16 ":" ) ls32 + "(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s", + # [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32 + "(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s", + # [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32 + "(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s", + # [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32 + "(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s", + # [ *4( h16 ":" ) h16 ] "::" ls32 + "(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s", + # [ *5( h16 ":" ) h16 ] "::" h16 + "(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s", + # [ *6( h16 ":" ) h16 ] "::" + "(?:(?:%(hex)s:){0,6}%(hex)s)?::", +] +IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")" +UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~" +ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+" +BRACELESS_IPV6_ADDRZ_PAT = IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?" +BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + BRACELESS_IPV6_ADDRZ_PAT + "$") + + +def is_ipaddress(hostname: Union[str, bytes]) -> bool: + """Detects whether the hostname given is an IPv4 or IPv6 address. + Also detects IPv6 addresses with Zone IDs. + """ + # Copied from urllib3. License: MIT + if isinstance(hostname, bytes): + # IDN A-label bytes are ASCII compatible. + hostname = hostname.decode("ascii") + hostname = hostname.strip("[]") + return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname)) diff --git a/setup.py b/setup.py index fee51cb..567c0b7 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,8 @@ "pytest-cov", "pytest-mock", "pytest-asyncio", + "pytest-httpserver", + "trustme", "mock", "requests", "aiohttp", diff --git a/tests/async_/test_httpserver.py b/tests/async_/test_httpserver.py new file mode 100644 index 0000000..03b1152 --- /dev/null +++ b/tests/async_/test_httpserver.py @@ -0,0 +1,35 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +import pytest + +from elastic_transport import AsyncTransport + +pytestmark = pytest.mark.asyncio + + +async def test_simple_request(https_server_ip_node_config): + with warnings.catch_warnings(): + warnings.simplefilter("error") + + t = AsyncTransport([https_server_ip_node_config]) + + resp, data = await t.perform_request("GET", "/foobar") + assert resp.status == 200 + assert data == {"foo": "bar"} diff --git a/tests/conftest.py b/tests/conftest.py index 26df560..129e1ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,8 @@ import ssl import pytest +import trustme +from pytest_httpserver import HTTPServer from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig from elastic_transport._node import NodeApiResponse @@ -100,3 +102,28 @@ def elastic_transport_logging(): logger = logging.getLogger(f"elastic_transport.{name}") for handler in logger.handlers[:]: logger.removeHandler(handler) + + +@pytest.fixture(scope="session") +def https_server_ip_node_config(tmp_path_factory: pytest.TempPathFactory) -> NodeConfig: + ca = trustme.CA() + tmpdir = tmp_path_factory.mktemp("certs") + ca_cert_path = str(tmpdir / "ca.pem") + ca.cert_pem.write_to_path(ca_cert_path) + + localhost_cert = ca.issue_cert("127.0.0.1") + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + + crt = localhost_cert.cert_chain_pems[0] + key = localhost_cert.private_key_pem + with crt.tempfile() as crt_file, key.tempfile() as key_file: + context.load_cert_chain(crt_file, key_file) + + server = HTTPServer(ssl_context=context) + server.expect_request("/foobar").respond_with_json({"foo": "bar"}) + + server.start() + yield NodeConfig("https", "127.0.0.1", server.port, ca_certs=ca_cert_path) + server.clear() + if server.is_running(): + server.stop() diff --git a/tests/node/test_base.py b/tests/node/test_base.py index 1e5a41f..f6f2bf3 100644 --- a/tests/node/test_base.py +++ b/tests/node/test_base.py @@ -17,7 +17,13 @@ import pytest -from elastic_transport import AiohttpHttpNode, RequestsHttpNode, Urllib3HttpNode +from elastic_transport import ( + AiohttpHttpNode, + NodeConfig, + RequestsHttpNode, + Urllib3HttpNode, +) +from elastic_transport._node._base import ssl_context_from_node_config @pytest.mark.parametrize( @@ -26,3 +32,17 @@ def test_unknown_parameter(node_cls): with pytest.raises(TypeError): node_cls(unknown_option=1) + + +@pytest.mark.parametrize( + "host, check_hostname", + [ + ("127.0.0.1", False), + ("::1", False), + ("localhost", True), + ], +) +def test_ssl_context_from_node_config(host, check_hostname): + node_config = NodeConfig("https", host, 443) + ctx = ssl_context_from_node_config(node_config) + assert ctx.check_hostname == check_hostname diff --git a/tests/test_httpserver.py b/tests/test_httpserver.py new file mode 100644 index 0000000..e7f7f76 --- /dev/null +++ b/tests/test_httpserver.py @@ -0,0 +1,34 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +import pytest + +from elastic_transport import Transport + + +@pytest.mark.parametrize("node_class", ["urllib3", "requests"]) +def test_simple_request(node_class, https_server_ip_node_config): + with warnings.catch_warnings(): + warnings.simplefilter("error") + + t = Transport([https_server_ip_node_config], node_class=node_class) + + resp, data = t.perform_request("GET", "/foobar") + assert resp.status == 200 + assert data == {"foo": "bar"} diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..ea1cc41 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,55 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from elastic_transport._utils import is_ipaddress + + +@pytest.mark.parametrize( + "addr", + [ + # IPv6 + "::1", + "::", + "FE80::8939:7684:D84b:a5A4%251", + # IPv4 + "127.0.0.1", + "8.8.8.8", + b"127.0.0.1", + # IPv6 w/ Zone IDs + "FE80::8939:7684:D84b:a5A4%251", + b"FE80::8939:7684:D84b:a5A4%251", + "FE80::8939:7684:D84b:a5A4%19", + b"FE80::8939:7684:D84b:a5A4%19", + ], +) +def test_is_ipaddress(addr): + assert is_ipaddress(addr) + + +@pytest.mark.parametrize( + "addr", + [ + "www.python.org", + b"www.python.org", + "v2.sg.media-imdb.com", + b"v2.sg.media-imdb.com", + ], +) +def test_is_not_ipaddress(addr): + assert not is_ipaddress(addr)