Skip to content

Commit

Permalink
Remove DEFAULT mode, renamed to FORWARD and TUNNEL modes
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeray Diaz Diaz committed Apr 13, 2020
1 parent 0e4ae38 commit 7066bf0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 94 deletions.
48 changes: 6 additions & 42 deletions httpcore/_async/http_proxy.py
Expand Up @@ -18,8 +18,8 @@ class AsyncHTTPProxy(AsyncConnectionPool):
service as a 3-tuple of (scheme, host, port).
* **proxy_headers** - `Optional[List[Tuple[bytes, bytes]]]` - A list of
proxy headers to include.
* **proxy_mode** - `str` - A proxy mode to operate in. May be "DEFAULT",
"FORWARD_ONLY", or "TUNNEL_ONLY".
* **proxy_mode** - `str` - A proxy mode to operate in. May be "FORWARD",
or "TUNNEL".
* **ssl_context** - `Optional[SSLContext]` - An SSL context to use for
verifying connections.
* **max_connections** - `Optional[int]` - The maximum number of concurrent
Expand All @@ -32,15 +32,15 @@ class AsyncHTTPProxy(AsyncConnectionPool):
def __init__(
self,
proxy_origin: Origin,
proxy_mode: str,
proxy_headers: Headers = None,
proxy_mode: str = "DEFAULT",
ssl_context: SSLContext = None,
max_connections: int = None,
max_keepalive: int = None,
keepalive_expiry: float = None,
http2: bool = False,
):
assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
assert proxy_mode in ("FORWARD", "TUNNEL")

self.proxy_origin = proxy_origin
self.proxy_headers = [] if proxy_headers is None else proxy_headers
Expand All @@ -64,15 +64,11 @@ async def request(
if self._keepalive_expiry is not None:
await self._keepalive_sweep()

if (
self.proxy_mode == "DEFAULT" and url[0] == b"http"
) or self.proxy_mode == "FORWARD_ONLY":
# By default HTTP requests should be forwarded.
if self.proxy_mode == "FORWARD":
return await self._forward_request(
method, url, headers=headers, stream=stream, timeout=timeout
)
else:
# By default HTTPS should be tunnelled.
return await self._tunnel_request(
method, url, headers=headers, stream=stream, timeout=timeout
)
Expand Down Expand Up @@ -144,9 +140,8 @@ async def _tunnel_request(
# [proxy-headers]
target = b"%b:%d" % (url[1], url[2])
connect_url = self.proxy_origin + (target,)
proxy_headers = self._get_tunnel_proxy_headers(headers)
proxy_response = await proxy_connection.request(
b"CONNECT", connect_url, headers=proxy_headers, timeout=timeout
b"CONNECT", connect_url, headers=self.proxy_headers, timeout=timeout
)
proxy_status_code = proxy_response[1]
proxy_reason_phrase = proxy_response[2]
Expand Down Expand Up @@ -182,34 +177,3 @@ async def _tunnel_request(
response[4], connection=connection, callback=self._response_closed
)
return response[0], response[1], response[2], response[3], wrapped_stream

def _get_tunnel_proxy_headers(self, request_headers: Headers = None) -> Headers:
"""Returns the headers for the CONNECT request to the tunnel proxy.
These do not include _all_ the request headers, but we make sure Host
is present as it's required for any h11 connection. If not in the proxy
headers we try to pull it from the request headers.
We also attach `Accept: */*` if not present in the user's proxy headers.
"""
proxy_headers = []
should_add_accept_header = True
should_add_host_header = True
for header in self.proxy_headers:
proxy_headers.append(header)
if header[0] == b"accept":
should_add_accept_header = False
if header[0] == b"host":
should_add_host_header = False

if should_add_accept_header:
proxy_headers.append((b"accept", b"*/*"))

if should_add_host_header and request_headers:
try:
host_header = next(h for h in request_headers if h[0] == b"host")
proxy_headers.append(host_header)
except StopIteration:
pass

return proxy_headers
48 changes: 6 additions & 42 deletions httpcore/_sync/http_proxy.py
Expand Up @@ -18,8 +18,8 @@ class SyncHTTPProxy(SyncConnectionPool):
service as a 3-tuple of (scheme, host, port).
* **proxy_headers** - `Optional[List[Tuple[bytes, bytes]]]` - A list of
proxy headers to include.
* **proxy_mode** - `str` - A proxy mode to operate in. May be "DEFAULT",
"FORWARD_ONLY", or "TUNNEL_ONLY".
* **proxy_mode** - `str` - A proxy mode to operate in. May be "FORWARD",
or "TUNNEL".
* **ssl_context** - `Optional[SSLContext]` - An SSL context to use for
verifying connections.
* **max_connections** - `Optional[int]` - The maximum number of concurrent
Expand All @@ -32,15 +32,15 @@ class SyncHTTPProxy(SyncConnectionPool):
def __init__(
self,
proxy_origin: Origin,
proxy_mode: str,
proxy_headers: Headers = None,
proxy_mode: str = "DEFAULT",
ssl_context: SSLContext = None,
max_connections: int = None,
max_keepalive: int = None,
keepalive_expiry: float = None,
http2: bool = False,
):
assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
assert proxy_mode in ("FORWARD", "TUNNEL")

self.proxy_origin = proxy_origin
self.proxy_headers = [] if proxy_headers is None else proxy_headers
Expand All @@ -64,15 +64,11 @@ def request(
if self._keepalive_expiry is not None:
self._keepalive_sweep()

if (
self.proxy_mode == "DEFAULT" and url[0] == b"http"
) or self.proxy_mode == "FORWARD_ONLY":
# By default HTTP requests should be forwarded.
if self.proxy_mode == "FORWARD":
return self._forward_request(
method, url, headers=headers, stream=stream, timeout=timeout
)
else:
# By default HTTPS should be tunnelled.
return self._tunnel_request(
method, url, headers=headers, stream=stream, timeout=timeout
)
Expand Down Expand Up @@ -144,9 +140,8 @@ def _tunnel_request(
# [proxy-headers]
target = b"%b:%d" % (url[1], url[2])
connect_url = self.proxy_origin + (target,)
proxy_headers = self._get_tunnel_proxy_headers(headers)
proxy_response = proxy_connection.request(
b"CONNECT", connect_url, headers=proxy_headers, timeout=timeout
b"CONNECT", connect_url, headers=self.proxy_headers, timeout=timeout
)
proxy_status_code = proxy_response[1]
proxy_reason_phrase = proxy_response[2]
Expand Down Expand Up @@ -182,34 +177,3 @@ def _tunnel_request(
response[4], connection=connection, callback=self._response_closed
)
return response[0], response[1], response[2], response[3], wrapped_stream

def _get_tunnel_proxy_headers(self, request_headers: Headers = None) -> Headers:
"""Returns the headers for the CONNECT request to the tunnel proxy.
These do not include _all_ the request headers, but we make sure Host
is present as it's required for any h11 connection. If not in the proxy
headers we try to pull it from the request headers.
We also attach `Accept: */*` if not present in the user's proxy headers.
"""
proxy_headers = []
should_add_accept_header = True
should_add_host_header = True
for header in self.proxy_headers:
proxy_headers.append(header)
if header[0] == b"accept":
should_add_accept_header = False
if header[0] == b"host":
should_add_host_header = False

if should_add_accept_header:
proxy_headers.append((b"accept", b"*/*"))

if should_add_host_header and request_headers:
try:
host_header = next(h for h in request_headers if h[0] == b"host")
proxy_headers.append(host_header)
except StopIteration:
pass

return proxy_headers
15 changes: 10 additions & 5 deletions tests/async_tests/test_interfaces.py
Expand Up @@ -177,15 +177,20 @@ async def test_http_request_cannot_reuse_dropped_connection() -> None:
assert len(http._connections[url[:3]]) == 1 # type: ignore


@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"])
@pytest.mark.parametrize("proxy_mode", ["FORWARD", "TUNNEL"])
@pytest.mark.usefixtures("async_environment")
async def test_http_proxy(
proxy_server: typing.Tuple[bytes, bytes, int], proxy_mode: str
) -> None:
async with httpcore.AsyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
# Tunnel requires the host header to be present,
# Forwarding will use the request headers
proxy_headers = headers if proxy_mode == "TUNNEL" else None
async with httpcore.AsyncHTTPProxy(
proxy_server, proxy_mode, proxy_headers=proxy_headers
) as http:
http_version, status_code, reason, headers, stream = await http.request(
method, url, headers
)
Expand Down
15 changes: 10 additions & 5 deletions tests/sync_tests/test_interfaces.py
Expand Up @@ -177,15 +177,20 @@ def test_http_request_cannot_reuse_dropped_connection() -> None:
assert len(http._connections[url[:3]]) == 1 # type: ignore


@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"])
@pytest.mark.parametrize("proxy_mode", ["FORWARD", "TUNNEL"])

def test_http_proxy(
proxy_server: typing.Tuple[bytes, bytes, int], proxy_mode: str
) -> None:
with httpcore.SyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
# Tunnel requires the host header to be present,
# Forwarding will use the request headers
proxy_headers = headers if proxy_mode == "TUNNEL" else None
with httpcore.SyncHTTPProxy(
proxy_server, proxy_mode, proxy_headers=proxy_headers
) as http:
http_version, status_code, reason, headers, stream = http.request(
method, url, headers
)
Expand Down

0 comments on commit 7066bf0

Please sign in to comment.