diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py index a140095e..09c651fc 100644 --- a/httpcore/_backends/anyio.py +++ b/httpcore/_backends/anyio.py @@ -14,6 +14,7 @@ WriteTimeout, map_exceptions, ) +from .._ssl import _normalize_server_hostname from .._utils import is_socket_readable from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream @@ -64,6 +65,8 @@ async def start_tls( anyio.EndOfStream: ConnectError, ssl.SSLError: ConnectError, } + server_hostname = _normalize_server_hostname(server_hostname) + with map_exceptions(exc_map): try: with anyio.fail_after(timeout): diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py index 4018a09c..ad9a0ed5 100644 --- a/httpcore/_backends/sync.py +++ b/httpcore/_backends/sync.py @@ -16,6 +16,7 @@ WriteTimeout, map_exceptions, ) +from .._ssl import _normalize_server_hostname from .._utils import is_socket_readable from .base import SOCKET_OPTION, NetworkBackend, NetworkStream @@ -41,6 +42,7 @@ def __init__( self._sock = sock self._incoming = ssl.MemoryBIO() self._outgoing = ssl.MemoryBIO() + server_hostname = _normalize_server_hostname(server_hostname) self.ssl_obj = ssl_context.wrap_bio( incoming=self._incoming, @@ -151,6 +153,8 @@ def start_tls( socket.timeout: ConnectTimeout, OSError: ConnectError, } + server_hostname = _normalize_server_hostname(server_hostname) + with map_exceptions(exc_map): try: if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index 6f53f5f2..326674c1 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -15,6 +15,7 @@ WriteTimeout, map_exceptions, ) +from .._ssl import _normalize_server_hostname from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream @@ -62,6 +63,8 @@ async def start_tls( trio.TooSlowError: ConnectTimeout, trio.BrokenResourceError: ConnectError, } + server_hostname = _normalize_server_hostname(server_hostname) + ssl_stream = trio.SSLStream( self._stream, ssl_context=ssl_context, diff --git a/httpcore/_ssl.py b/httpcore/_ssl.py index c99c5a67..66514269 100644 --- a/httpcore/_ssl.py +++ b/httpcore/_ssl.py @@ -1,8 +1,17 @@ +from __future__ import annotations + import ssl import certifi +def _normalize_server_hostname(server_hostname: str | None) -> str | None: + if server_hostname is None: + return None + + return server_hostname.rstrip(".") + + def default_ssl_context() -> ssl.SSLContext: context = ssl.create_default_context() context.load_verify_locations(certifi.where()) diff --git a/tests/test_ssl.py b/tests/test_ssl.py new file mode 100644 index 00000000..60a4d06d --- /dev/null +++ b/tests/test_ssl.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import socket +import typing + +from httpcore._backends.sync import SyncStream +from httpcore._ssl import _normalize_server_hostname + + +class CapturingSSLContext: + def __init__(self) -> None: + self.server_hostname: str | None = None + + def wrap_socket( + self, sock: socket.socket, server_hostname: str | None = None + ) -> socket.socket: + self.server_hostname = server_hostname + return sock + + +def test_normalize_server_hostname() -> None: + assert _normalize_server_hostname(None) is None + assert _normalize_server_hostname("example.com") == "example.com" + assert _normalize_server_hostname("example.com.") == "example.com" + + +def test_sync_start_tls_normalizes_trailing_dot_hostname() -> None: + local, remote = socket.socketpair() + ssl_context = CapturingSSLContext() + + try: + stream = SyncStream(local) + tls_stream = stream.start_tls( + typing.cast(typing.Any, ssl_context), server_hostname="example.com." + ) + + assert ssl_context.server_hostname == "example.com" + tls_stream.close() + finally: + local.close() + remote.close()