Skip to content

Commit

Permalink
Add ConcurrencyBackend.start_tls() (#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Aug 24, 2019
1 parent a4b93b9 commit 1872ae8
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 7 deletions.
40 changes: 39 additions & 1 deletion httpx/concurrency/asyncio.py
Expand Up @@ -18,8 +18,8 @@
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseStream,
ConcurrencyBackend,
Expand Down Expand Up @@ -194,6 +194,44 @@ async def connect(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def start_tls(
self,
stream: BaseStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseStream:

loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)

assert isinstance(stream, Stream)

stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = stream.stream_writer.transport

loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)

stream_reader.set_transport(transport)
stream.stream_reader = stream_reader
stream.stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
return stream

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
Expand Down
9 changes: 9 additions & 0 deletions httpx/concurrency/base.py
Expand Up @@ -116,6 +116,15 @@ async def connect(
) -> BaseStream:
raise NotImplementedError() # pragma: no cover

async def start_tls(
self,
stream: BaseStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover

Expand Down
4 changes: 2 additions & 2 deletions httpx/dispatch/asgi.py
@@ -1,10 +1,10 @@
import typing

from .base import AsyncDispatcher
from ..concurrency.base import ConcurrencyBackend
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import AsyncRequest, AsyncResponse
from .base import AsyncDispatcher


class ASGIDispatch(AsyncDispatcher):
Expand Down
2 changes: 1 addition & 1 deletion httpx/dispatch/connection.py
Expand Up @@ -2,7 +2,6 @@
import ssl
import typing

from .base import AsyncDispatcher
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import (
Expand All @@ -16,6 +15,7 @@
VerifyTypes,
)
from ..models import AsyncRequest, AsyncResponse, Origin
from .base import AsyncDispatcher
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection

Expand Down
2 changes: 1 addition & 1 deletion httpx/dispatch/connection_pool.py
@@ -1,6 +1,5 @@
import typing

from .base import AsyncDispatcher
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import (
Expand All @@ -13,6 +12,7 @@
VerifyTypes,
)
from ..models import AsyncRequest, AsyncResponse, Origin
from .base import AsyncDispatcher
from .connection import HTTPConnection

CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
Expand Down
2 changes: 1 addition & 1 deletion httpx/dispatch/threaded.py
@@ -1,4 +1,3 @@
from .base import AsyncDispatcher, Dispatcher
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import (
Expand All @@ -11,6 +10,7 @@
Response,
ResponseContent,
)
from .base import AsyncDispatcher, Dispatcher


class ThreadedDispatcher(AsyncDispatcher):
Expand Down
2 changes: 1 addition & 1 deletion httpx/dispatch/wsgi.py
@@ -1,9 +1,9 @@
import io
import typing

from .base import Dispatcher
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import Request, Response
from .base import Dispatcher


class WSGIDispatch(Dispatcher):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_concurrency.py
@@ -0,0 +1,31 @@
import sys

import pytest

from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig


@pytest.mark.xfail(
sys.version_info < (3, 7),
reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
)
@pytest.mark.asyncio
async def test_start_tls_on_socket_stream(https_server):
"""
See that the backend can make a connection without TLS then
start TLS on an existing connection.
"""
backend = AsyncioBackend()
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)

stream = await backend.connect("127.0.0.1", 8001, None, timeout)
assert stream.is_connection_dropped() is False
assert stream.stream_writer.get_extra_info("cipher", default=None) is None

stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
assert stream.is_connection_dropped() is False
assert stream.stream_writer.get_extra_info("cipher", default=None) is not None

await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")

0 comments on commit 1872ae8

Please sign in to comment.