Skip to content

Commit

Permalink
Close outstanding connections (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchristie committed May 27, 2020
1 parent 687e9fb commit 353509d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 3 deletions.
5 changes: 5 additions & 0 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,8 @@ async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
await self.connection.start_tls(hostname, timeout)
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)
self.socket = self.connection.socket

async def aclose(self) -> None:
async with self.request_lock:
if self.connection is not None:
await self.connection.aclose()
2 changes: 1 addition & 1 deletion httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def request(
stream: AsyncByteStream = None,
timeout: TimeoutDict = None,
) -> Tuple[bytes, int, bytes, Headers, AsyncByteStream]:
assert url[0] in (b'http', b'https')
assert url[0] in (b"http", b"https")
origin = url_to_origin(url)

if self._keepalive_expiry is not None:
Expand Down
5 changes: 5 additions & 0 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,8 @@ def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
self.connection.start_tls(hostname, timeout)
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)
self.socket = self.connection.socket

def close(self) -> None:
with self.request_lock:
if self.connection is not None:
self.connection.close()
2 changes: 1 addition & 1 deletion httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def request(
stream: SyncByteStream = None,
timeout: TimeoutDict = None,
) -> Tuple[bytes, int, bytes, Headers, SyncByteStream]:
assert url[0] in (b'http', b'https')
assert url[0] in (b"http", b"https")
origin = url_to_origin(url)

if self._keepalive_expiry is not None:
Expand Down
3 changes: 2 additions & 1 deletion httpcore/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import typing

from ._types import URL, Origin

_LOGGER_INITIALIZED = False
Expand Down Expand Up @@ -52,6 +53,6 @@ def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None:

def url_to_origin(url: URL) -> Origin:
scheme, host, explicit_port = url[:3]
default_port = {b'http': 80, b'https': 443}[scheme]
default_port = {b"http": 80, b"https": 443}[scheme]
port = default_port if explicit_port is None else explicit_port
return scheme, host, port

0 comments on commit 353509d

Please sign in to comment.