Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion elastic_transport/_node/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, ClassVar, List, NamedTuple, Optional, Tuple, Union

from .._models import ApiResponseMeta, HttpHeaders, NodeConfig
from .._utils import is_ipaddress
from .._version import __version__
from ..client_utils import DEFAULT, DefaultType

Expand Down Expand Up @@ -295,7 +296,7 @@ def ssl_context_from_node_config(node_config: NodeConfig) -> ssl.SSLContext:
# step if the user doesn't pass a preconfigured SSLContext.
if node_config.verify_certs:
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = True
ctx.check_hostname = not is_ipaddress(node_config.host)
else:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
Expand Down
48 changes: 47 additions & 1 deletion elastic_transport/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict
import re
from typing import Any, Dict, Union


def fixup_module_metadata(module_name: str, namespace: Dict[str, Any]) -> None:
Expand All @@ -31,3 +32,48 @@ def fix_one(obj: Any) -> None:
for objname in namespace["__all__"]:
obj = namespace[objname]
fix_one(obj)


IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
IPV4_RE = re.compile("^" + IPV4_PAT + "$")

HEX_PAT = "[0-9A-Fa-f]{1,4}"
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT)
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT}
_variations = [
# 6( h16 ":" ) ls32
"(?:%(hex)s:){6}%(ls32)s",
# "::" 5( h16 ":" ) ls32
"::(?:%(hex)s:){5}%(ls32)s",
# [ h16 ] "::" 4( h16 ":" ) ls32
"(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s",
# [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
"(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s",
# [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
"(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s",
# [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
"(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s",
# [ *4( h16 ":" ) h16 ] "::" ls32
"(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s",
# [ *5( h16 ":" ) h16 ] "::" h16
"(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s",
# [ *6( h16 ":" ) h16 ] "::"
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
]
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~"
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
BRACELESS_IPV6_ADDRZ_PAT = IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?"
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + BRACELESS_IPV6_ADDRZ_PAT + "$")


def is_ipaddress(hostname: Union[str, bytes]) -> bool:
"""Detects whether the hostname given is an IPv4 or IPv6 address.
Also detects IPv6 addresses with Zone IDs.
"""
# Copied from urllib3. License: MIT
if isinstance(hostname, bytes):
# IDN A-label bytes are ASCII compatible.
hostname = hostname.decode("ascii")
hostname = hostname.strip("[]")
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname))
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
"pytest-cov",
"pytest-mock",
"pytest-asyncio",
"pytest-httpserver",
"trustme",
"mock",
"requests",
"aiohttp",
Expand Down
35 changes: 35 additions & 0 deletions tests/async_/test_httpserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import warnings

import pytest

from elastic_transport import AsyncTransport

pytestmark = pytest.mark.asyncio


async def test_simple_request(https_server_ip_node_config):
with warnings.catch_warnings():
warnings.simplefilter("error")

t = AsyncTransport([https_server_ip_node_config])

resp, data = await t.perform_request("GET", "/foobar")
assert resp.status == 200
assert data == {"foo": "bar"}
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import ssl

import pytest
import trustme
from pytest_httpserver import HTTPServer

from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig
from elastic_transport._node import NodeApiResponse
Expand Down Expand Up @@ -100,3 +102,28 @@ def elastic_transport_logging():
logger = logging.getLogger(f"elastic_transport.{name}")
for handler in logger.handlers[:]:
logger.removeHandler(handler)


@pytest.fixture(scope="session")
def https_server_ip_node_config(tmp_path_factory: pytest.TempPathFactory) -> NodeConfig:
ca = trustme.CA()
tmpdir = tmp_path_factory.mktemp("certs")
ca_cert_path = str(tmpdir / "ca.pem")
ca.cert_pem.write_to_path(ca_cert_path)

localhost_cert = ca.issue_cert("127.0.0.1")
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

crt = localhost_cert.cert_chain_pems[0]
key = localhost_cert.private_key_pem
with crt.tempfile() as crt_file, key.tempfile() as key_file:
context.load_cert_chain(crt_file, key_file)

server = HTTPServer(ssl_context=context)
server.expect_request("/foobar").respond_with_json({"foo": "bar"})

server.start()
yield NodeConfig("https", "127.0.0.1", server.port, ca_certs=ca_cert_path)
server.clear()
if server.is_running():
server.stop()
22 changes: 21 additions & 1 deletion tests/node/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@

import pytest

from elastic_transport import AiohttpHttpNode, RequestsHttpNode, Urllib3HttpNode
from elastic_transport import (
AiohttpHttpNode,
NodeConfig,
RequestsHttpNode,
Urllib3HttpNode,
)
from elastic_transport._node._base import ssl_context_from_node_config


@pytest.mark.parametrize(
Expand All @@ -26,3 +32,17 @@
def test_unknown_parameter(node_cls):
with pytest.raises(TypeError):
node_cls(unknown_option=1)


@pytest.mark.parametrize(
"host, check_hostname",
[
("127.0.0.1", False),
("::1", False),
("localhost", True),
],
)
def test_ssl_context_from_node_config(host, check_hostname):
node_config = NodeConfig("https", host, 443)
ctx = ssl_context_from_node_config(node_config)
assert ctx.check_hostname == check_hostname
34 changes: 34 additions & 0 deletions tests/test_httpserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import warnings

import pytest

from elastic_transport import Transport


@pytest.mark.parametrize("node_class", ["urllib3", "requests"])
def test_simple_request(node_class, https_server_ip_node_config):
with warnings.catch_warnings():
warnings.simplefilter("error")

t = Transport([https_server_ip_node_config], node_class=node_class)

resp, data = t.perform_request("GET", "/foobar")
assert resp.status == 200
assert data == {"foo": "bar"}
55 changes: 55 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

from elastic_transport._utils import is_ipaddress


@pytest.mark.parametrize(
"addr",
[
# IPv6
"::1",
"::",
"FE80::8939:7684:D84b:a5A4%251",
# IPv4
"127.0.0.1",
"8.8.8.8",
b"127.0.0.1",
# IPv6 w/ Zone IDs
"FE80::8939:7684:D84b:a5A4%251",
b"FE80::8939:7684:D84b:a5A4%251",
"FE80::8939:7684:D84b:a5A4%19",
b"FE80::8939:7684:D84b:a5A4%19",
],
)
def test_is_ipaddress(addr):
assert is_ipaddress(addr)


@pytest.mark.parametrize(
"addr",
[
"www.python.org",
b"www.python.org",
"v2.sg.media-imdb.com",
b"v2.sg.media-imdb.com",
],
)
def test_is_not_ipaddress(addr):
assert not is_ipaddress(addr)