Skip to content

Commit

Permalink
Add support for family to aiohttp session helper
Browse files Browse the repository at this point in the history
needed for #98003
  • Loading branch information
bdraco committed Oct 24, 2023
1 parent f733f20 commit dd624c0
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 45 deletions.
51 changes: 37 additions & 14 deletions homeassistant/helpers/aiohttp_client.py
Expand Up @@ -7,7 +7,7 @@
from ssl import SSLContext
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import aiohttp
from aiohttp import web
Expand All @@ -29,9 +29,8 @@


DATA_CONNECTOR = "aiohttp_connector"
DATA_CONNECTOR_NOTVERIFY = "aiohttp_connector_notverify"
DATA_CLIENTSESSION = "aiohttp_clientsession"
DATA_CLIENTSESSION_NOTVERIFY = "aiohttp_clientsession_notverify"

SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format(
APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info
)
Expand Down Expand Up @@ -88,22 +87,31 @@ async def json(
@callback
@bind_hass
def async_get_clientsession(
hass: HomeAssistant, verify_ssl: bool = True
hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession.
This method must be run in the event loop.
"""
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
session_key = _make_key(verify_ssl, family)
if DATA_CLIENTSESSION not in hass.data:
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
hass.data[DATA_CLIENTSESSION] = sessions
else:
sessions = hass.data[DATA_CLIENTSESSION]

if key not in hass.data:
hass.data[key] = _async_create_clientsession(
if session_key not in sessions:
session = _async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=_async_register_default_clientsession_shutdown,
family=family,
)
sessions[session_key] = session
else:
session = sessions[session_key]

return cast(aiohttp.ClientSession, hass.data[key])
return session


@callback
Expand All @@ -112,6 +120,7 @@ def async_create_clientsession(
hass: HomeAssistant,
verify_ssl: bool = True,
auto_cleanup: bool = True,
family: int = 0,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies.
Expand All @@ -131,6 +140,7 @@ def async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=auto_cleanup_method,
family=family,
**kwargs,
)

Expand All @@ -143,11 +153,12 @@ def _async_create_clientsession(
verify_ssl: bool = True,
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
| None = None,
family: int = 0,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl),
connector=_async_get_connector(hass, verify_ssl, family),
json_serialize=json_dumps,
response_class=HassClientResponse,
**kwargs,
Expand Down Expand Up @@ -275,31 +286,43 @@ def _async_close_websession(event: Event) -> None:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession)


@callback
def _make_key(verify_ssl: bool = True, family: int = 0) -> tuple[bool, int]:
"""Make a key for connector or session pool."""
return (verify_ssl, family)


@callback
def _async_get_connector(
hass: HomeAssistant, verify_ssl: bool = True
hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp.
This method must be run in the event loop.
"""
key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY
connector_key = _make_key(verify_ssl, family)
if DATA_CONNECTOR not in hass.data:
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
hass.data[DATA_CONNECTOR] = connectors
else:
connectors = hass.data[DATA_CONNECTOR]

if key in hass.data:
return cast(aiohttp.BaseConnector, hass.data[key])
if connector_key in connectors:
return connectors[connector_key]

if verify_ssl:
ssl_context: bool | SSLContext = ssl_util.get_default_context()
else:
ssl_context = ssl_util.get_default_no_verify_context()

connector = aiohttp.TCPConnector(
family=family,
enable_cleanup_closed=ENABLE_CLEANUP_CLOSED,
ssl=ssl_context,
limit=MAXIMUM_CONNECTIONS,
limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST,
)
hass.data[key] = connector
connectors[connector_key] = connector

async def _async_close_connector(event: Event) -> None:
"""Close connector pool."""
Expand Down
96 changes: 65 additions & 31 deletions tests/helpers/test_aiohttp_client.py
Expand Up @@ -52,26 +52,53 @@ def camera_client_fixture(hass, hass_client):
async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass)
verify_ssl = True
family = 0

assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)


async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession without ssl."""
client.async_get_clientsession(hass, verify_ssl=False)
verify_ssl = False
family = 0

assert isinstance(
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)


@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
)
async def test_get_clientsession(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
) -> None:
"""Test init clientsession combinations."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
assert isinstance(connector, aiohttp.TCPConnector)


async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None:
"""Test create clientsession with ssl."""
session = client.async_create_clientsession(hass, cookies={"bla": True})
assert isinstance(session, aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)

verify_ssl = True
family = 0

assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)


async def test_create_clientsession_without_ssl_and_cookies(
Expand All @@ -80,46 +107,53 @@ async def test_create_clientsession_without_ssl_and_cookies(
"""Test create clientsession without ssl."""
session = client.async_create_clientsession(hass, False, cookies={"bla": True})
assert isinstance(session, aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)

verify_ssl = False
family = 0

async def test_get_clientsession_cleanup(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass)
assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)

assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)

hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done()

assert hass.data[client.DATA_CLIENTSESSION].closed
assert hass.data[client.DATA_CONNECTOR].closed


async def test_get_clientsession_cleanup_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass, verify_ssl=False)
@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
)
async def test_get_clientsession_cleanup(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
) -> None:
"""Test init clientsession cleanup."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)

assert isinstance(
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
assert isinstance(connector, aiohttp.TCPConnector)

hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done()

assert hass.data[client.DATA_CLIENTSESSION_NOTVERIFY].closed
assert hass.data[client.DATA_CONNECTOR_NOTVERIFY].closed
assert client_session.closed
assert connector.closed


async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None:
"""Test closing clientsession does not work."""

verify_ssl = True
family = 0

with patch("aiohttp.ClientSession.close") as mock_close:
session = client.async_get_clientsession(hass)

assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
assert isinstance(
hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)],
aiohttp.ClientSession,
)
assert isinstance(
hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector
)

with pytest.raises(RuntimeError):
await session.close()
Expand Down

0 comments on commit dd624c0

Please sign in to comment.