Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for family to aiohttp session helper #102702

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 37 additions & 14 deletions homeassistant/helpers/aiohttp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ssl import SSLContext
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import aiohttp
from aiohttp import web
Expand All @@ -29,9 +29,8 @@


DATA_CONNECTOR = "aiohttp_connector"
DATA_CONNECTOR_NOTVERIFY = "aiohttp_connector_notverify"
DATA_CLIENTSESSION = "aiohttp_clientsession"
DATA_CLIENTSESSION_NOTVERIFY = "aiohttp_clientsession_notverify"

SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format(
APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info
)
Expand Down Expand Up @@ -88,22 +87,31 @@ async def json(
@callback
@bind_hass
def async_get_clientsession(
hass: HomeAssistant, verify_ssl: bool = True
hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession.

This method must be run in the event loop.
"""
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
session_key = _make_key(verify_ssl, family)
if DATA_CLIENTSESSION not in hass.data:
sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {}
hass.data[DATA_CLIENTSESSION] = sessions
else:
sessions = hass.data[DATA_CLIENTSESSION]

if key not in hass.data:
hass.data[key] = _async_create_clientsession(
if session_key not in sessions:
session = _async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=_async_register_default_clientsession_shutdown,
family=family,
)
sessions[session_key] = session
else:
session = sessions[session_key]

return cast(aiohttp.ClientSession, hass.data[key])
return session


@callback
Expand All @@ -112,6 +120,7 @@ def async_create_clientsession(
hass: HomeAssistant,
verify_ssl: bool = True,
auto_cleanup: bool = True,
family: int = 0,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies.
Expand All @@ -131,6 +140,7 @@ def async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=auto_cleanup_method,
family=family,
**kwargs,
)

Expand All @@ -143,11 +153,12 @@ def _async_create_clientsession(
verify_ssl: bool = True,
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
| None = None,
family: int = 0,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl),
connector=_async_get_connector(hass, verify_ssl, family),
json_serialize=json_dumps,
response_class=HassClientResponse,
**kwargs,
Expand Down Expand Up @@ -275,31 +286,43 @@ def _async_close_websession(event: Event) -> None:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession)


@callback
def _make_key(verify_ssl: bool = True, family: int = 0) -> tuple[bool, int]:
"""Make a key for connector or session pool."""
return (verify_ssl, family)
bdraco marked this conversation as resolved.
Show resolved Hide resolved


@callback
def _async_get_connector(
hass: HomeAssistant, verify_ssl: bool = True
hass: HomeAssistant, verify_ssl: bool = True, family: int = 0
) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp.

This method must be run in the event loop.
"""
key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY
connector_key = _make_key(verify_ssl, family)
if DATA_CONNECTOR not in hass.data:
connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {}
hass.data[DATA_CONNECTOR] = connectors
else:
connectors = hass.data[DATA_CONNECTOR]

if key in hass.data:
return cast(aiohttp.BaseConnector, hass.data[key])
if connector_key in connectors:
return connectors[connector_key]

if verify_ssl:
ssl_context: bool | SSLContext = ssl_util.get_default_context()
else:
ssl_context = ssl_util.get_default_no_verify_context()

connector = aiohttp.TCPConnector(
family=family,
enable_cleanup_closed=ENABLE_CLEANUP_CLOSED,
ssl=ssl_context,
limit=MAXIMUM_CONNECTIONS,
limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST,
)
hass.data[key] = connector
connectors[connector_key] = connector

async def _async_close_connector(event: Event) -> None:
"""Close connector pool."""
Expand Down
4 changes: 2 additions & 2 deletions tests/components/demo/test_media_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Platform,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION
from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION, _make_key
from homeassistant.setup import async_setup_component

from tests.typing import ClientSessionGenerator
Expand Down Expand Up @@ -483,7 +483,7 @@ async def get(self, url):
def detach(self):
"""Test websession detach."""

hass.data[DATA_CLIENTSESSION] = MockWebsession()
hass.data[DATA_CLIENTSESSION] = {_make_key(): MockWebsession()}

state = hass.states.get(TEST_ENTITY_ID)
assert state.state == STATE_PLAYING
Expand Down
96 changes: 65 additions & 31 deletions tests/helpers/test_aiohttp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,53 @@ def camera_client_fixture(hass, hass_client):
async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass)
verify_ssl = True
family = 0

assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)


async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession without ssl."""
client.async_get_clientsession(hass, verify_ssl=False)
verify_ssl = False
family = 0

assert isinstance(
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)


@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
)
async def test_get_clientsession(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
) -> None:
"""Test init clientsession combinations."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
assert isinstance(connector, aiohttp.TCPConnector)


async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None:
"""Test create clientsession with ssl."""
session = client.async_create_clientsession(hass, cookies={"bla": True})
assert isinstance(session, aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)

verify_ssl = True
family = 0

assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)


async def test_create_clientsession_without_ssl_and_cookies(
Expand All @@ -80,46 +107,53 @@ async def test_create_clientsession_without_ssl_and_cookies(
"""Test create clientsession without ssl."""
session = client.async_create_clientsession(hass, False, cookies={"bla": True})
assert isinstance(session, aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)

verify_ssl = False
family = 0

async def test_get_clientsession_cleanup(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass)
assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
assert isinstance(connector, aiohttp.TCPConnector)

assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)

hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done()

assert hass.data[client.DATA_CLIENTSESSION].closed
assert hass.data[client.DATA_CONNECTOR].closed


async def test_get_clientsession_cleanup_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass, verify_ssl=False)
@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
)
async def test_get_clientsession_cleanup(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
) -> None:
"""Test init clientsession cleanup."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)

assert isinstance(
hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession
)
assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
assert isinstance(connector, aiohttp.TCPConnector)

hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done()

assert hass.data[client.DATA_CLIENTSESSION_NOTVERIFY].closed
assert hass.data[client.DATA_CONNECTOR_NOTVERIFY].closed
assert client_session.closed
assert connector.closed


async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None:
"""Test closing clientsession does not work."""

verify_ssl = True
family = 0

with patch("aiohttp.ClientSession.close") as mock_close:
session = client.async_get_clientsession(hass)

assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession)
assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
assert isinstance(
hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)],
aiohttp.ClientSession,
)
assert isinstance(
hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector
)

with pytest.raises(RuntimeError):
await session.close()
Expand Down