Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
WriteTimeout,
map_exceptions,
)
from .._ssl import _normalize_server_hostname
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream

Expand Down Expand Up @@ -64,6 +65,8 @@ async def start_tls(
anyio.EndOfStream: ConnectError,
ssl.SSLError: ConnectError,
}
server_hostname = _normalize_server_hostname(server_hostname)

with map_exceptions(exc_map):
try:
with anyio.fail_after(timeout):
Expand Down
4 changes: 4 additions & 0 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
WriteTimeout,
map_exceptions,
)
from .._ssl import _normalize_server_hostname
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream

Expand All @@ -41,6 +42,7 @@ def __init__(
self._sock = sock
self._incoming = ssl.MemoryBIO()
self._outgoing = ssl.MemoryBIO()
server_hostname = _normalize_server_hostname(server_hostname)

self.ssl_obj = ssl_context.wrap_bio(
incoming=self._incoming,
Expand Down Expand Up @@ -151,6 +153,8 @@ def start_tls(
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
server_hostname = _normalize_server_hostname(server_hostname)

with map_exceptions(exc_map):
try:
if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover
Expand Down
3 changes: 3 additions & 0 deletions httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
WriteTimeout,
map_exceptions,
)
from .._ssl import _normalize_server_hostname
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream


Expand Down Expand Up @@ -62,6 +63,8 @@ async def start_tls(
trio.TooSlowError: ConnectTimeout,
trio.BrokenResourceError: ConnectError,
}
server_hostname = _normalize_server_hostname(server_hostname)

ssl_stream = trio.SSLStream(
self._stream,
ssl_context=ssl_context,
Expand Down
9 changes: 9 additions & 0 deletions httpcore/_ssl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

import ssl

import certifi


def _normalize_server_hostname(server_hostname: str | None) -> str | None:
if server_hostname is None:
return None

return server_hostname.rstrip(".")


def default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
context.load_verify_locations(certifi.where())
Expand Down
41 changes: 41 additions & 0 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import socket
import typing

from httpcore._backends.sync import SyncStream
from httpcore._ssl import _normalize_server_hostname


class CapturingSSLContext:
def __init__(self) -> None:
self.server_hostname: str | None = None

def wrap_socket(
self, sock: socket.socket, server_hostname: str | None = None
) -> socket.socket:
self.server_hostname = server_hostname
return sock


def test_normalize_server_hostname() -> None:
assert _normalize_server_hostname(None) is None
assert _normalize_server_hostname("example.com") == "example.com"
assert _normalize_server_hostname("example.com.") == "example.com"


def test_sync_start_tls_normalizes_trailing_dot_hostname() -> None:
local, remote = socket.socketpair()
ssl_context = CapturingSSLContext()

try:
stream = SyncStream(local)
tls_stream = stream.start_tls(
typing.cast(typing.Any, ssl_context), server_hostname="example.com."
)

assert ssl_context.server_hostname == "example.com"
tls_stream.close()
finally:
local.close()
remote.close()
Loading