Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from .._exceptions import ProxyError
from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._trace import Trace
from ..backends.base import AsyncNetworkBackend
from .connection import AsyncHTTPConnection
from .connection_pool import AsyncConnectionPool
Expand Down Expand Up @@ -46,6 +48,8 @@ def __init__(
max_connections: Optional[int] = 10,
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: str = None,
uds: str = None,
Expand All @@ -69,6 +73,10 @@ def __init__(
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
Expand All @@ -84,6 +92,8 @@ def __init__(
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
local_address=local_address,
Expand All @@ -107,6 +117,8 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)

Expand Down Expand Up @@ -177,6 +189,8 @@ def __init__(
ssl_context: ssl.SSLContext = None,
proxy_headers: Sequence[Tuple[bytes, bytes]] = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
network_backend: AsyncNetworkBackend = None,
) -> None:
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
Expand All @@ -189,6 +203,8 @@ def __init__(
self._ssl_context = ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._connect_lock = AsyncLock()
self._connected = False

Expand Down Expand Up @@ -224,16 +240,48 @@ async def handle_async_request(self, request: Request) -> Response:
raise ProxyError(msg)

stream = connect_response.extensions["network_stream"]
stream = await stream.start_tls(
ssl_context=self._ssl_context,
server_hostname=self._remote_origin.host.decode("ascii"),
timeout=timeout,

# Upgrade the stream to SSL
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
self._connection = AsyncHTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("connection.start_tls", request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)

# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import AsyncHTTP2Connection

self._connection = AsyncHTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)

self._connected = True
return await self._connection.handle_async_request(request)

Expand Down
64 changes: 56 additions & 8 deletions httpcore/_sync/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from .._exceptions import ProxyError
from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
from .._ssl import default_ssl_context
from .._synchronization import Lock
from .._trace import Trace
from ..backends.base import NetworkBackend
from .connection import HTTPConnection
from .connection_pool import ConnectionPool
Expand Down Expand Up @@ -46,6 +48,8 @@ def __init__(
max_connections: Optional[int] = 10,
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: str = None,
uds: str = None,
Expand All @@ -69,6 +73,10 @@ def __init__(
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
Expand All @@ -84,6 +92,8 @@ def __init__(
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
local_address=local_address,
Expand All @@ -107,6 +117,8 @@ def create_connection(self, origin: Origin) -> ConnectionInterface:
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)

Expand Down Expand Up @@ -177,6 +189,8 @@ def __init__(
ssl_context: ssl.SSLContext = None,
proxy_headers: Sequence[Tuple[bytes, bytes]] = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
network_backend: NetworkBackend = None,
) -> None:
self._connection: ConnectionInterface = HTTPConnection(
Expand All @@ -189,6 +203,8 @@ def __init__(
self._ssl_context = ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._connect_lock = Lock()
self._connected = False

Expand Down Expand Up @@ -224,16 +240,48 @@ def handle_request(self, request: Request) -> Response:
raise ProxyError(msg)

stream = connect_response.extensions["network_stream"]
stream = stream.start_tls(
ssl_context=self._ssl_context,
server_hostname=self._remote_origin.host.decode("ascii"),
timeout=timeout,

# Upgrade the stream to SSL
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
self._connection = HTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("connection.start_tls", request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)

# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import HTTP2Connection

self._connection = HTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = HTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)

self._connected = True
return self._connection.handle_request(request)

Expand Down
2 changes: 1 addition & 1 deletion scripts/coverage
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ fi

set -x

${PREFIX}coverage report --show-missing --skip-covered --fail-under=93
${PREFIX}coverage report --show-missing --skip-covered --fail-under=100
99 changes: 97 additions & 2 deletions tests/_async/test_http_proxy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import ssl

import hpack
import hyperframe.frame
import pytest

from httpcore import AsyncHTTPProxy, Origin, ProxyError
from httpcore.backends.mock import AsyncMockBackend
from httpcore.backends.base import AsyncNetworkStream
from httpcore.backends.mock import AsyncMockBackend, AsyncMockStream


@pytest.mark.anyio
Expand Down Expand Up @@ -64,7 +69,9 @@ async def test_proxy_tunneling():
"""
network_backend = AsyncMockBackend(
[
b"HTTP/1.1 200 OK\r\n" b"\r\n",
# The initial response to the proxy CONNECT
b"HTTP/1.1 200 OK\r\n\r\n",
# The actual response from the remote server
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
Expand Down Expand Up @@ -111,6 +118,94 @@ async def test_proxy_tunneling():
)


# We need to adapt the mock backend here slightly in order to deal
# with the proxy case. We do not want the initial connection to the proxy
# to indicate an HTTP/2 connection, but we do want it to indicate HTTP/2
# once the SSL upgrade has taken place.
class HTTP1ThenHTTP2Stream(AsyncMockStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str = None,
timeout: float = None,
) -> AsyncNetworkStream:
self._http2 = True
return self


class HTTP1ThenHTTP2Backend(AsyncMockBackend):
async def connect_tcp(
self, host: str, port: int, timeout: float = None, local_address: str = None
) -> AsyncNetworkStream:
return HTTP1ThenHTTP2Stream(list(self._buffer))


@pytest.mark.anyio
async def test_proxy_tunneling_http2():
"""
Send an HTTP/2 request via a proxy.
"""
network_backend = HTTP1ThenHTTP2Backend(
[
# The initial response to the proxy CONNECT
b"HTTP/1.1 200 OK\r\n\r\n",
# The actual response from the remote server
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
],
)

async with AsyncHTTPProxy(
proxy_url="http://localhost:8080/",
max_connections=10,
network_backend=network_backend,
http2=True,
) as proxy:
# Sending an intial request, which once complete will return to the pool, IDLE.
async with proxy.stream("GET", "https://example.com/") as response:
info = [repr(c) for c in proxy.connections]
assert info == [
"<AsyncTunnelHTTPConnection ['https://example.com:443', HTTP/2, ACTIVE, Request Count: 1]>"
]
await response.aread()

assert response.status == 200
assert response.content == b"Hello, world!"
info = [repr(c) for c in proxy.connections]
assert info == [
"<AsyncTunnelHTTPConnection ['https://example.com:443', HTTP/2, IDLE, Request Count: 1]>"
]
assert proxy.connections[0].is_idle()
assert proxy.connections[0].is_available()
assert not proxy.connections[0].is_closed()

# A connection on a tunneled proxy can only handle HTTPS requests to the same origin.
assert not proxy.connections[0].can_handle_request(
Origin(b"http", b"example.com", 80)
)
assert not proxy.connections[0].can_handle_request(
Origin(b"http", b"other.com", 80)
)
assert proxy.connections[0].can_handle_request(
Origin(b"https", b"example.com", 443)
)
assert not proxy.connections[0].can_handle_request(
Origin(b"https", b"other.com", 443)
)


@pytest.mark.anyio
async def test_proxy_tunneling_with_403():
"""
Expand Down
Loading