Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tunnel proxy: HTTP requests only #57

Merged
merged 18 commits on Apr 29, 2020
Merged
41 changes: 27 additions & 14 deletions httpcore/_async/connection.py
@@ -1,7 +1,7 @@
from ssl import SSLContext
from typing import List, Optional, Tuple, Union

from .._backends.auto import AsyncLock, AutoBackend
from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .base import (
AsyncByteStream,
Expand All @@ -15,11 +15,16 @@

class AsyncHTTPConnection(AsyncHTTPTransport):
def __init__(
self, origin: Origin, http2: bool = False, ssl_context: SSLContext = None,
self,
origin: Origin,
http2: bool = False,
ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -48,14 +53,11 @@ async def request(
timeout: TimeoutDict = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], AsyncByteStream]:
assert url[:3] == self.origin

async with self.request_lock:
if self.state == ConnectionState.PENDING:
try:
await self._connect(timeout)
except Exception:
self.connect_failed = True
raise
if not self.socket:
self.socket = await self._open_socket(timeout)
self._create_connection(self.socket)
elif self.state in (ConnectionState.READY, ConnectionState.IDLE):
pass
elif self.state == ConnectionState.ACTIVE and self.is_http2:
Expand All @@ -66,20 +68,30 @@ async def request(
assert self.connection is not None
return await self.connection.request(method, url, headers, stream, timeout)

async def _connect(self, timeout: TimeoutDict = None) -> None:
async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
socket = await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
try:
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
except Exception:
self.connect_failed = True
raise

def _create_connection(self, socket: AsyncSocketStream) -> None:
http_version = socket.get_http_version()
if http_version == "HTTP/2":
self.is_http2 = True
self.connection = AsyncHTTP2Connection(socket=socket, backend=self.backend)
self.connection = AsyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
)
else:
self.is_http11 = True
self.connection = AsyncHTTP11Connection(socket=socket)
self.connection = AsyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
)

@property
def state(self) -> ConnectionState:
Expand All @@ -99,3 +111,4 @@ def mark_as_ready(self) -> None:
async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
await self.connection.start_tls(hostname, timeout)
self.socket = self.connection.socket
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion httpcore/_async/http11.py
Expand Up @@ -123,7 +123,7 @@ async def _receive_response_data(
event = await self._receive_event(timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, h11.EndOfMessage):
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
yeraydiazdiaz marked this conversation as resolved.
Show resolved Hide resolved
break

async def _receive_event(self, timeout: TimeoutDict) -> H11Event:
Expand Down
54 changes: 25 additions & 29 deletions httpcore/_async/http_proxy.py
Expand Up @@ -8,13 +8,6 @@
from .connection_pool import AsyncConnectionPool, ResponseByteStream


async def read_body(stream: AsyncByteStream) -> bytes:
try:
return b"".join([chunk async for chunk in stream])
finally:
await stream.aclose()


class AsyncHTTPProxy(AsyncConnectionPool):
"""
A connection pool for making HTTP requests via an HTTP proxy.
Expand All @@ -26,7 +19,8 @@ class AsyncHTTPProxy(AsyncConnectionPool):
* **proxy_headers** - `Optional[List[Tuple[bytes, bytes]]]` - A list of
proxy headers to include.
* **proxy_mode** - `str` - A proxy mode to operate in. May be "DEFAULT",
"FORWARD_ONLY", or "TUNNEL_ONLY".
yeraydiazdiaz marked this conversation as resolved.
Show resolved Hide resolved
"FORWARD_ONLY", or "TUNNEL_ONLY". "DEFAULT" is identical to "FORWARD_ONLY"
but is kept for backward compatibility purposes.
* **ssl_context** - `Optional[SSLContext]` - An SSL context to use for
verifying connections.
* **max_connections** - `Optional[int]` - The maximum number of concurrent
Expand All @@ -39,8 +33,8 @@ class AsyncHTTPProxy(AsyncConnectionPool):
def __init__(
self,
proxy_origin: Origin,
proxy_mode: str,
proxy_headers: Headers = None,
proxy_mode: str = "DEFAULT",
yeraydiazdiaz marked this conversation as resolved.
Show resolved Hide resolved
ssl_context: SSLContext = None,
max_connections: int = None,
max_keepalive: int = None,
Expand Down Expand Up @@ -140,47 +134,49 @@ async def _tunnel_request(
connection = await self._get_connection_from_pool(origin)

if connection is None:
connection = AsyncHTTPConnection(
origin=origin, http2=False, ssl_context=self._ssl_context,
# First, create a connection to the proxy server
proxy_connection = AsyncHTTPConnection(
origin=self.proxy_origin, http2=False, ssl_context=self._ssl_context,
)
async with self._thread_lock:
self._connections.setdefault(origin, set())
self._connections[origin].add(connection)

# Establish the connection by issuing a CONNECT request...
# Issue a CONNECT request...

# CONNECT www.example.org:80 HTTP/1.1
# [proxy-headers]
target = b"%b:%d" % (url[1], url[2])
connect_url = self.proxy_origin + (target,)
connect_headers = self.proxy_headers
proxy_response = await connection.request(
b"CONNECT", connect_url, headers=connect_headers, timeout=timeout
proxy_response = await proxy_connection.request(
b"CONNECT", connect_url, headers=self.proxy_headers, timeout=timeout
)
proxy_status_code = proxy_response[1]
proxy_reason_phrase = proxy_response[2]
proxy_stream = proxy_response[4]

# Ingest any request body.
await read_body(proxy_stream)
# Read the response data without closing the socket
async for _ in proxy_stream:
pass

# If the proxy responds with an error, then drop the connection
# from the pool, and raise an exception.
# See if the tunnel was successfully established.
if proxy_status_code < 200 or proxy_status_code > 299:
async with self._thread_lock:
self._connections[connection.origin].remove(connection)
if not self._connections[connection.origin]:
del self._connections[connection.origin]
msg = "%d %s" % (proxy_status_code, proxy_reason_phrase.decode("ascii"))
raise ProxyError(msg)

# Upgrade to TLS.
await connection.start_tls(target, timeout)
# The CONNECT request is successful, so we have now SWITCHED PROTOCOLS.
# This means the proxy connection is now unusable, and we must create
# a new one for regular requests, making sure to use the same socket to
# retain the tunnel.
connection = AsyncHTTPConnection(
origin=origin,
http2=False,
ssl_context=self._ssl_context,
socket=proxy_connection.socket,
)
await self._add_to_pool(connection)

# Once the connection has been established we can send requests on
# it as normal.
response = await connection.request(
method, url, headers=headers, stream=stream, timeout=timeout
method, url, headers=headers, stream=stream, timeout=timeout,
)
wrapped_stream = ResponseByteStream(
response[4], connection=connection, callback=self._response_closed
Expand Down
6 changes: 3 additions & 3 deletions httpcore/_backends/asyncio.py
Expand Up @@ -107,9 +107,9 @@ async def start_tls(

transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
transport,
protocol,
ssl_context,
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
server_hostname=hostname.decode("ascii"),
),
timeout=timeout.get("connect"),
Expand Down
41 changes: 27 additions & 14 deletions httpcore/_sync/connection.py
@@ -1,7 +1,7 @@
from ssl import SSLContext
from typing import List, Optional, Tuple, Union

from .._backends.auto import SyncLock, SyncBackend
from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .base import (
SyncByteStream,
Expand All @@ -15,11 +15,16 @@

class SyncHTTPConnection(SyncHTTPTransport):
def __init__(
self, origin: Origin, http2: bool = False, ssl_context: SSLContext = None,
self,
origin: Origin,
http2: bool = False,
ssl_context: SSLContext = None,
socket: SyncSocketStream = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -48,14 +53,11 @@ def request(
timeout: TimeoutDict = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], SyncByteStream]:
assert url[:3] == self.origin

with self.request_lock:
if self.state == ConnectionState.PENDING:
try:
self._connect(timeout)
except Exception:
self.connect_failed = True
raise
if not self.socket:
self.socket = self._open_socket(timeout)
self._create_connection(self.socket)
elif self.state in (ConnectionState.READY, ConnectionState.IDLE):
pass
elif self.state == ConnectionState.ACTIVE and self.is_http2:
Expand All @@ -66,20 +68,30 @@ def request(
assert self.connection is not None
return self.connection.request(method, url, headers, stream, timeout)

def _connect(self, timeout: TimeoutDict = None) -> None:
def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
socket = self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
try:
return self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
except Exception:
self.connect_failed = True
raise

def _create_connection(self, socket: SyncSocketStream) -> None:
http_version = socket.get_http_version()
if http_version == "HTTP/2":
self.is_http2 = True
self.connection = SyncHTTP2Connection(socket=socket, backend=self.backend)
self.connection = SyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
)
else:
self.is_http11 = True
self.connection = SyncHTTP11Connection(socket=socket)
self.connection = SyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
)

@property
def state(self) -> ConnectionState:
Expand All @@ -99,3 +111,4 @@ def mark_as_ready(self) -> None:
def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
self.connection.start_tls(hostname, timeout)
self.socket = self.connection.socket
2 changes: 1 addition & 1 deletion httpcore/_sync/http11.py
Expand Up @@ -123,7 +123,7 @@ def _receive_response_data(
event = self._receive_event(timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, h11.EndOfMessage):
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
break

def _receive_event(self, timeout: TimeoutDict) -> H11Event:
Expand Down