From dd624c0b7011597899c08e6e02aa1be612efb233 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 24 Oct 2023 14:40:44 -0500 Subject: [PATCH 1/2] Add support for family to aiohttp session helper needed for #98003 --- homeassistant/helpers/aiohttp_client.py | 51 +++++++++---- tests/helpers/test_aiohttp_client.py | 96 +++++++++++++++++-------- 2 files changed, 102 insertions(+), 45 deletions(-) diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index 20351efff530a8..b8d810d899b7b8 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -7,7 +7,7 @@ from ssl import SSLContext import sys from types import MappingProxyType -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import aiohttp from aiohttp import web @@ -29,9 +29,8 @@ DATA_CONNECTOR = "aiohttp_connector" -DATA_CONNECTOR_NOTVERIFY = "aiohttp_connector_notverify" DATA_CLIENTSESSION = "aiohttp_clientsession" -DATA_CLIENTSESSION_NOTVERIFY = "aiohttp_clientsession_notverify" + SERVER_SOFTWARE = "{0}/{1} aiohttp/{2} Python/{3[0]}.{3[1]}".format( APPLICATION_NAME, __version__, aiohttp.__version__, sys.version_info ) @@ -88,22 +87,31 @@ async def json( @callback @bind_hass def async_get_clientsession( - hass: HomeAssistant, verify_ssl: bool = True + hass: HomeAssistant, verify_ssl: bool = True, family: int = 0 ) -> aiohttp.ClientSession: """Return default aiohttp ClientSession. This method must be run in the event loop. """ - key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY + session_key = _make_key(verify_ssl, family) + if DATA_CLIENTSESSION not in hass.data: + sessions: dict[tuple[bool, int], aiohttp.ClientSession] = {} + hass.data[DATA_CLIENTSESSION] = sessions + else: + sessions = hass.data[DATA_CLIENTSESSION] - if key not in hass.data: - hass.data[key] = _async_create_clientsession( + if session_key not in sessions: + session = _async_create_clientsession( hass, verify_ssl, auto_cleanup_method=_async_register_default_clientsession_shutdown, + family=family, ) + sessions[session_key] = session + else: + session = sessions[session_key] - return cast(aiohttp.ClientSession, hass.data[key]) + return session @callback @@ -112,6 +120,7 @@ def async_create_clientsession( hass: HomeAssistant, verify_ssl: bool = True, auto_cleanup: bool = True, + family: int = 0, **kwargs: Any, ) -> aiohttp.ClientSession: """Create a new ClientSession with kwargs, i.e. for cookies. @@ -131,6 +140,7 @@ def async_create_clientsession( hass, verify_ssl, auto_cleanup_method=auto_cleanup_method, + family=family, **kwargs, ) @@ -143,11 +153,12 @@ def _async_create_clientsession( verify_ssl: bool = True, auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None] | None = None, + family: int = 0, **kwargs: Any, ) -> aiohttp.ClientSession: """Create a new ClientSession with kwargs, i.e. for cookies.""" clientsession = aiohttp.ClientSession( - connector=_async_get_connector(hass, verify_ssl), + connector=_async_get_connector(hass, verify_ssl, family), json_serialize=json_dumps, response_class=HassClientResponse, **kwargs, @@ -275,18 +286,29 @@ def _async_close_websession(event: Event) -> None: hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _async_close_websession) +@callback +def _make_key(verify_ssl: bool = True, family: int = 0) -> tuple[bool, int]: + """Make a key for connector or session pool.""" + return (verify_ssl, family) + + @callback def _async_get_connector( - hass: HomeAssistant, verify_ssl: bool = True + hass: HomeAssistant, verify_ssl: bool = True, family: int = 0 ) -> aiohttp.BaseConnector: """Return the connector pool for aiohttp. This method must be run in the event loop. """ - key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY + connector_key = _make_key(verify_ssl, family) + if DATA_CONNECTOR not in hass.data: + connectors: dict[tuple[bool, int], aiohttp.BaseConnector] = {} + hass.data[DATA_CONNECTOR] = connectors + else: + connectors = hass.data[DATA_CONNECTOR] - if key in hass.data: - return cast(aiohttp.BaseConnector, hass.data[key]) + if connector_key in connectors: + return connectors[connector_key] if verify_ssl: ssl_context: bool | SSLContext = ssl_util.get_default_context() @@ -294,12 +316,13 @@ def _async_get_connector( ssl_context = ssl_util.get_default_no_verify_context() connector = aiohttp.TCPConnector( + family=family, enable_cleanup_closed=ENABLE_CLEANUP_CLOSED, ssl=ssl_context, limit=MAXIMUM_CONNECTIONS, limit_per_host=MAXIMUM_CONNECTIONS_PER_HOST, ) - hass.data[key] = connector + connectors[connector_key] = connector async def _async_close_connector(event: Event) -> None: """Close connector pool.""" diff --git a/tests/helpers/test_aiohttp_client.py b/tests/helpers/test_aiohttp_client.py index daeb324b19f1e3..46b389722e89b3 100644 --- a/tests/helpers/test_aiohttp_client.py +++ b/tests/helpers/test_aiohttp_client.py @@ -52,26 +52,53 @@ def camera_client_fixture(hass, hass_client): async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None: """Test init clientsession with ssl.""" client.async_get_clientsession(hass) + verify_ssl = True + family = 0 - assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession) - assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) + client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)] + assert isinstance(client_session, aiohttp.ClientSession) + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + assert isinstance(connector, aiohttp.TCPConnector) async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None: """Test init clientsession without ssl.""" client.async_get_clientsession(hass, verify_ssl=False) + verify_ssl = False + family = 0 - assert isinstance( - hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession - ) - assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector) + client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)] + assert isinstance(client_session, aiohttp.ClientSession) + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + assert isinstance(connector, aiohttp.TCPConnector) + + +@pytest.mark.parametrize( + ("verify_ssl", "expected_family"), + [(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)], +) +async def test_get_clientsession( + hass: HomeAssistant, verify_ssl: bool, expected_family: int +) -> None: + """Test init clientsession combinations.""" + client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family) + client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)] + assert isinstance(client_session, aiohttp.ClientSession) + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)] + assert isinstance(connector, aiohttp.TCPConnector) async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> None: """Test create clientsession with ssl.""" session = client.async_create_clientsession(hass, cookies={"bla": True}) assert isinstance(session, aiohttp.ClientSession) - assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) + + verify_ssl = True + family = 0 + + assert client.DATA_CLIENTSESSION not in hass.data + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + assert isinstance(connector, aiohttp.TCPConnector) async def test_create_clientsession_without_ssl_and_cookies( @@ -80,46 +107,53 @@ async def test_create_clientsession_without_ssl_and_cookies( """Test create clientsession without ssl.""" session = client.async_create_clientsession(hass, False, cookies={"bla": True}) assert isinstance(session, aiohttp.ClientSession) - assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector) + verify_ssl = False + family = 0 -async def test_get_clientsession_cleanup(hass: HomeAssistant) -> None: - """Test init clientsession with ssl.""" - client.async_get_clientsession(hass) + assert client.DATA_CLIENTSESSION not in hass.data + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + assert isinstance(connector, aiohttp.TCPConnector) - assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession) - assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) - hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE) - await hass.async_block_till_done() - - assert hass.data[client.DATA_CLIENTSESSION].closed - assert hass.data[client.DATA_CONNECTOR].closed - - -async def test_get_clientsession_cleanup_without_ssl(hass: HomeAssistant) -> None: - """Test init clientsession with ssl.""" - client.async_get_clientsession(hass, verify_ssl=False) +@pytest.mark.parametrize( + ("verify_ssl", "expected_family"), + [(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)], +) +async def test_get_clientsession_cleanup( + hass: HomeAssistant, verify_ssl: bool, expected_family: int +) -> None: + """Test init clientsession cleanup.""" + client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family) - assert isinstance( - hass.data[client.DATA_CLIENTSESSION_NOTVERIFY], aiohttp.ClientSession - ) - assert isinstance(hass.data[client.DATA_CONNECTOR_NOTVERIFY], aiohttp.TCPConnector) + client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)] + assert isinstance(client_session, aiohttp.ClientSession) + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)] + assert isinstance(connector, aiohttp.TCPConnector) hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE) await hass.async_block_till_done() - assert hass.data[client.DATA_CLIENTSESSION_NOTVERIFY].closed - assert hass.data[client.DATA_CONNECTOR_NOTVERIFY].closed + assert client_session.closed + assert connector.closed async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None: """Test closing clientsession does not work.""" + + verify_ssl = True + family = 0 + with patch("aiohttp.ClientSession.close") as mock_close: session = client.async_get_clientsession(hass) - assert isinstance(hass.data[client.DATA_CLIENTSESSION], aiohttp.ClientSession) - assert isinstance(hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector) + assert isinstance( + hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)], + aiohttp.ClientSession, + ) + assert isinstance( + hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector + ) with pytest.raises(RuntimeError): await session.close() From b52aa6ff8313bec870928148ed7438b39b6a8be1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 24 Oct 2023 16:31:35 -0500 Subject: [PATCH 2/2] fix test --- tests/components/demo/test_media_player.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/components/demo/test_media_player.py b/tests/components/demo/test_media_player.py index ff6274af1b575f..b1bd77a74a1469 100644 --- a/tests/components/demo/test_media_player.py +++ b/tests/components/demo/test_media_player.py @@ -16,7 +16,7 @@ Platform, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION +from homeassistant.helpers.aiohttp_client import DATA_CLIENTSESSION, _make_key from homeassistant.setup import async_setup_component from tests.typing import ClientSessionGenerator @@ -483,7 +483,7 @@ async def get(self, url): def detach(self): """Test websession detach.""" - hass.data[DATA_CLIENTSESSION] = MockWebsession() + hass.data[DATA_CLIENTSESSION] = {_make_key(): MockWebsession()} state = hass.states.get(TEST_ENTITY_ID) assert state.state == STATE_PLAYING