diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 02a9dd9dade605..2b8f1ec4065a0a 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -223,6 +223,7 @@ class ConfigEntry: "_async_cancel_retry_setup", "_on_unload", "reload_lock", + "_reauth_lock", "_tasks", "_background_tasks", "_integration_for_domain", @@ -321,6 +322,8 @@ def __init__( # Reload lock to prevent conflicting reloads self.reload_lock = asyncio.Lock() + # Reauth lock to prevent concurrent reauth flows + self._reauth_lock = asyncio.Lock() self._tasks: set[asyncio.Future[Any]] = set() self._background_tasks: set[asyncio.Future[Any]] = set() @@ -727,12 +730,28 @@ def async_start_reauth( data: dict[str, Any] | None = None, ) -> None: """Start a reauth flow.""" + # We will check this again in the task when we hold the lock, + # but we also check it now to try to avoid creating the task. if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})): # Reauth flow already in progress for this entry return - hass.async_create_task( - hass.config_entries.flow.async_init( + self._async_init_reauth(hass, context, data), + f"config entry reauth {self.title} {self.domain} {self.entry_id}", + ) + + async def _async_init_reauth( + self, + hass: HomeAssistant, + context: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, + ) -> None: + """Start a reauth flow.""" + async with self._reauth_lock: + if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})): + # Reauth flow already in progress for this entry + return + await hass.config_entries.flow.async_init( self.domain, context={ "source": SOURCE_REAUTH, @@ -742,9 +761,7 @@ def async_start_reauth( } | (context or {}), data=self.data | (data or {}), - ), - f"config entry reauth {self.title} {self.domain} {self.entry_id}", - ) + ) @callback def async_get_active_flows( @@ -754,7 +771,9 @@ def async_get_active_flows( return ( flow for flow in hass.config_entries.flow.async_progress_by_handler( - self.domain, match_context={"entry_id": self.entry_id} + self.domain, + match_context={"entry_id": self.entry_id}, + include_uninitialized=True, ) if flow["context"].get("source") in sources ) diff --git a/tests/components/smarttub/test_init.py b/tests/components/smarttub/test_init.py index 0e88f3ed7c7b9a..929ad687e11e04 100644 --- a/tests/components/smarttub/test_init.py +++ b/tests/components/smarttub/test_init.py @@ -42,6 +42,7 @@ async def test_setup_auth_failed( config_entry.add_to_hass(hass) with patch.object(hass.config_entries.flow, "async_init") as mock_flow_init: await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() assert config_entry.state is ConfigEntryState.SETUP_ERROR mock_flow_init.assert_called_with( DOMAIN, diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index d17c724cb2aa67..eb771b7e6a67b3 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3791,6 +3791,20 @@ async def test_reauth(hass: HomeAssistant) -> None: await hass.async_block_till_done() assert len(hass.config_entries.flow.async_progress()) == 2 + # Abort all existing flows + for flow in hass.config_entries.flow.async_progress(): + hass.config_entries.flow.async_abort(flow["flow_id"]) + await hass.async_block_till_done() + + # Check that we can't start duplicate reauth flows + # without blocking between flows + entry.async_start_reauth(hass, {"extra_context": "some_extra_context"}) + entry.async_start_reauth(hass, {"extra_context": "some_extra_context"}) + entry.async_start_reauth(hass, {"extra_context": "some_extra_context"}) + entry.async_start_reauth(hass, {"extra_context": "some_extra_context"}) + await hass.async_block_till_done() + assert len(hass.config_entries.flow.async_progress()) == 1 + async def test_get_active_flows(hass: HomeAssistant) -> None: """Test the async_get_active_flows helper."""