Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local address support. #100

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def __init__(
http2: bool = False,
ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
local_addr: bytes = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check - Do we want local_addr or source_addr here?
What are trio, asyncio, and the sync stdlib using for their naming?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asyncio uses local_addr, trio uses local_address, and the stdlib socket module uses source_address (and curio uses source_addr). There's really no consistency.

):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_addr = local_addr

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -98,7 +100,7 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
ssl_context = self.ssl_context if scheme == b"https" else None
try:
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
hostname, port, ssl_context, timeout, self.local_addr
)
except Exception:
self.connect_failed = True
Expand Down
8 changes: 7 additions & 1 deletion httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class AsyncConnectionPool(AsyncHTTPTransport):
* **keepalive_expiry** - `Optional[float]` - The maximum time to allow
before closing a keep-alive connection.
* **http2** - `bool` - Enable HTTP/2 support.
* **local_addr** - `Optional[bytes]` - Local address to connect from
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Miniature nitpick. Let's use a full stop to match the other cases. "Local address to connect from." 😃

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

"""

def __init__(
Expand All @@ -85,12 +86,14 @@ def __init__(
max_keepalive: int = None,
keepalive_expiry: float = None,
http2: bool = False,
local_addr: bytes = None,
):
self._ssl_context = SSLContext() if ssl_context is None else ssl_context
self._max_connections = max_connections
self._max_keepalive = max_keepalive
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._local_addr = local_addr
self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
self._backend = AutoBackend()
Expand Down Expand Up @@ -141,7 +144,10 @@ async def request(

if connection is None:
connection = AsyncHTTPConnection(
origin=origin, http2=self._http2, ssl_context=self._ssl_context,
origin=origin,
http2=self._http2,
ssl_context=self._ssl_context,
local_addr=self._local_addr,
)
logger.trace("created connection=%r", connection)
await self._add_to_pool(connection, timeout=timeout)
Expand Down
9 changes: 8 additions & 1 deletion httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,20 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
local_addr: Optional[bytes],
) -> SocketStream:
host = hostname.decode("ascii")
connect_timeout = timeout.get("connect")
exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError}
with map_exceptions(exc_map):
local_addrport = None
if local_addr:
local_addrport = (local_addr, 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

local_addrport = None if local_addr is None else (local_addr, 0)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

stream_reader, stream_writer = await asyncio.wait_for(
asyncio.open_connection(host, port, ssl=ssl_context), connect_timeout,
asyncio.open_connection(
host, port, ssl=ssl_context, local_addr=local_addrport
),
connect_timeout,
)
return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer
Expand Down
5 changes: 4 additions & 1 deletion httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
local_addr: Optional[bytes],
) -> AsyncSocketStream:
return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout)
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout, local_addr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon we ought to use .open_tcp_stream(..., local_addr=local_addr) here.
Just makes it super obvious visually that it's an optional extra.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

)

def create_lock(self) -> AsyncLock:
return self.backend.create_lock()
Expand Down
1 change: 1 addition & 0 deletions httpcore/_backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
local_addr: Optional[bytes],
) -> AsyncSocketStream:
raise NotImplementedError() # pragma: no cover

Expand Down
6 changes: 5 additions & 1 deletion httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,17 @@ def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
local_addr: Optional[bytes],
) -> SyncSocketStream:
address = (hostname.decode("ascii"), port)
connect_timeout = timeout.get("connect")
exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError}

with map_exceptions(exc_map):
sock = socket.create_connection(address, connect_timeout)
local_addrport = None
if local_addr:
local_addrport = (local_addr, 0)
sock = socket.create_connection(address, connect_timeout, local_addrport) # type: ignore
if ssl_context is not None:
sock = ssl_context.wrap_socket(
sock, server_hostname=hostname.decode("ascii")
Expand Down
3 changes: 3 additions & 0 deletions httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
local_addr: Optional[bytes],
) -> AsyncSocketStream:
if local_addr:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment that it is currently supported in trio master, and should be expected from version 0.16.1 onwards?

Also, what will our implementation change look like once trio support does land here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm adding a comment.

I believe the code would just add something like local_address=local_addr to the open_tcp_stream call, although it's possible that there would be a type mismatch, as the local_addr parameter is None or a bytes, and trio claims to want None or a str (although it just passes it into socket.bind, which should accept a bytes).

raise NotImplementedError()
connect_timeout = none_as_inf(timeout.get("connect"))
exc_map = {
trio.TooSlowError: ConnectTimeout,
Expand Down
4 changes: 3 additions & 1 deletion httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def __init__(
http2: bool = False,
ssl_context: SSLContext = None,
socket: SyncSocketStream = None,
local_addr: bytes = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_addr = local_addr

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -98,7 +100,7 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
ssl_context = self.ssl_context if scheme == b"https" else None
try:
return self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
hostname, port, ssl_context, timeout, self.local_addr
)
except Exception:
self.connect_failed = True
Expand Down
8 changes: 7 additions & 1 deletion httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class SyncConnectionPool(SyncHTTPTransport):
* **keepalive_expiry** - `Optional[float]` - The maximum time to allow
before closing a keep-alive connection.
* **http2** - `bool` - Enable HTTP/2 support.
* **local_addr** - `Optional[bytes]` - Local address to connect from
"""

def __init__(
Expand All @@ -85,12 +86,14 @@ def __init__(
max_keepalive: int = None,
keepalive_expiry: float = None,
http2: bool = False,
local_addr: bytes = None,
):
self._ssl_context = SSLContext() if ssl_context is None else ssl_context
self._max_connections = max_connections
self._max_keepalive = max_keepalive
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._local_addr = local_addr
self._connections: Dict[Origin, Set[SyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
self._backend = SyncBackend()
Expand Down Expand Up @@ -141,7 +144,10 @@ def request(

if connection is None:
connection = SyncHTTPConnection(
origin=origin, http2=self._http2, ssl_context=self._ssl_context,
origin=origin,
http2=self._http2,
ssl_context=self._ssl_context,
local_addr=self._local_addr,
)
logger.trace("created connection=%r", connection)
self._add_to_pool(connection, timeout=timeout)
Expand Down
19 changes: 19 additions & 0 deletions tests/async_tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,25 @@ async def test_http_proxy(
assert reason == b"OK"


@pytest.mark.parametrize("local_addr", [b"0.0.0.0"])
@pytest.mark.asyncio
# This doesn't run with trio, since trio doesn't support local_addr.
async def test_http_request_local_addr(local_addr: str) -> None:
async with httpcore.AsyncConnectionPool(local_addr=local_addr) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
http_version, status_code, reason, headers, stream = await http.request(
method, url, headers
)
body = await read_body(stream)

assert http_version == b"HTTP/1.1"
assert status_code == 200
assert reason == b"OK"
assert len(http._connections[url[:3]]) == 1 # type: ignore


# mitmproxy does not support forwarding HTTPS requests
@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "TUNNEL_ONLY"])
@pytest.mark.usefixtures("async_environment")
Expand Down
19 changes: 19 additions & 0 deletions tests/sync_tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,25 @@ def test_http_proxy(
assert reason == b"OK"


@pytest.mark.parametrize("local_addr", [b"0.0.0.0"])

# This doesn't run with trio, since trio doesn't support local_addr.
def test_http_request_local_addr(local_addr: str) -> None:
with httpcore.SyncConnectionPool(local_addr=local_addr) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
http_version, status_code, reason, headers, stream = http.request(
method, url, headers
)
body = read_body(stream)

assert http_version == b"HTTP/1.1"
assert status_code == 200
assert reason == b"OK"
assert len(http._connections[url[:3]]) == 1 # type: ignore


# mitmproxy does not support forwarding HTTPS requests
@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "TUNNEL_ONLY"])

Expand Down