diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 9b97783266..8fc625f1c0 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -18,8 +18,8 @@ from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import ( BaseBackgroundManager, - BasePoolSemaphore, BaseEvent, + BasePoolSemaphore, BaseQueue, BaseStream, ConcurrencyBackend, @@ -194,6 +194,44 @@ async def connect( stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout ) + async def start_tls( + self, + stream: BaseStream, + hostname: str, + ssl_context: ssl.SSLContext, + timeout: TimeoutConfig, + ) -> BaseStream: + + loop = self.loop + if not hasattr(loop, "start_tls"): # pragma: no cover + raise NotImplementedError( + "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" + ) + + assert isinstance(stream, Stream) + + stream_reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(stream_reader) + transport = stream.stream_writer.transport + + loop_start_tls = loop.start_tls # type: ignore + transport = await asyncio.wait_for( + loop_start_tls( + transport=transport, + protocol=protocol, + sslcontext=ssl_context, + server_hostname=hostname, + ), + timeout=timeout.connect_timeout, + ) + + stream_reader.set_transport(transport) + stream.stream_reader = stream_reader + stream.stream_writer = asyncio.StreamWriter( + transport=transport, protocol=protocol, reader=stream_reader, loop=loop + ) + return stream + async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 9bfd54d4a4..bf2aed4f1f 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -116,6 +116,15 @@ async def connect( ) -> BaseStream: raise NotImplementedError() # pragma: no cover + async def start_tls( + self, + stream: BaseStream, + hostname: str, + ssl_context: ssl.SSLContext, + timeout: TimeoutConfig, + ) -> BaseStream: + raise NotImplementedError() # pragma: no cover + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py index 6c1fc267da..c56d757c71 100644 --- a/httpx/dispatch/asgi.py +++ b/httpx/dispatch/asgi.py @@ -1,10 +1,10 @@ import typing -from .base import AsyncDispatcher -from ..concurrency.base import ConcurrencyBackend from ..concurrency.asyncio import AsyncioBackend +from ..concurrency.base import ConcurrencyBackend from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..models import AsyncRequest, AsyncResponse +from .base import AsyncDispatcher class ASGIDispatch(AsyncDispatcher): diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 0e9819cb98..87c2a6489e 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -2,7 +2,6 @@ import ssl import typing -from .base import AsyncDispatcher from ..concurrency.asyncio import AsyncioBackend from ..concurrency.base import ConcurrencyBackend from ..config import ( @@ -16,6 +15,7 @@ VerifyTypes, ) from ..models import AsyncRequest, AsyncResponse, Origin +from .base import AsyncDispatcher from .http2 import HTTP2Connection from .http11 import HTTP11Connection diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 0e14bf8354..eb990a9618 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -1,6 +1,5 @@ import typing -from .base import AsyncDispatcher from ..concurrency.asyncio import AsyncioBackend from ..concurrency.base import ConcurrencyBackend from ..config import ( @@ -13,6 +12,7 @@ VerifyTypes, ) from ..models import AsyncRequest, AsyncResponse, Origin +from .base import AsyncDispatcher from .connection import HTTPConnection CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] diff --git a/httpx/dispatch/threaded.py b/httpx/dispatch/threaded.py index 8176608729..7454a9e0ad 100644 --- a/httpx/dispatch/threaded.py +++ b/httpx/dispatch/threaded.py @@ -1,4 +1,3 @@ -from .base import AsyncDispatcher, Dispatcher from ..concurrency.base import ConcurrencyBackend from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..models import ( @@ -11,6 +10,7 @@ Response, ResponseContent, ) +from .base import AsyncDispatcher, Dispatcher class ThreadedDispatcher(AsyncDispatcher): diff --git a/httpx/dispatch/wsgi.py b/httpx/dispatch/wsgi.py index 0cbe1095e2..73a6fc1f5c 100644 --- a/httpx/dispatch/wsgi.py +++ b/httpx/dispatch/wsgi.py @@ -1,9 +1,9 @@ import io import typing -from .base import Dispatcher from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..models import Request, Response +from .base import Dispatcher class WSGIDispatch(Dispatcher): diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000000..870a592d01 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,31 @@ +import sys + +import pytest + +from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig + + +@pytest.mark.xfail( + sys.version_info < (3, 7), + reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()", +) +@pytest.mark.asyncio +async def test_start_tls_on_socket_stream(https_server): + """ + See that the backend can make a connection without TLS then + start TLS on an existing connection. + """ + backend = AsyncioBackend() + ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig()) + timeout = TimeoutConfig(5) + + stream = await backend.connect("127.0.0.1", 8001, None, timeout) + assert stream.is_connection_dropped() is False + assert stream.stream_writer.get_extra_info("cipher", default=None) is None + + stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout) + assert stream.is_connection_dropped() is False + assert stream.stream_writer.get_extra_info("cipher", default=None) is not None + + await stream.write(b"GET / HTTP/1.1\r\n\r\n") + assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")