Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in starting reauth flows #103130

Merged
merged 3 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ class ConfigEntry:
"_async_cancel_retry_setup",
"_on_unload",
"reload_lock",
"_reauth_lock",
"_tasks",
"_background_tasks",
"_integration_for_domain",
Expand Down Expand Up @@ -321,6 +322,8 @@ def __init__(

# Reload lock to prevent conflicting reloads
self.reload_lock = asyncio.Lock()
# Reauth lock to prevent concurrent reauth flows
self._reauth_lock = asyncio.Lock()

self._tasks: set[asyncio.Future[Any]] = set()
self._background_tasks: set[asyncio.Future[Any]] = set()
Expand Down Expand Up @@ -727,12 +730,28 @@ def async_start_reauth(
data: dict[str, Any] | None = None,
) -> None:
"""Start a reauth flow."""
# We will check this again in the task when we hold the lock,
# but we also check it now to try to avoid creating the task.
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
# Reauth flow already in progress for this entry
return
bdraco marked this conversation as resolved.
Show resolved Hide resolved

hass.async_create_task(
hass.config_entries.flow.async_init(
self._async_init_reauth(hass, context, data),
f"config entry reauth {self.title} {self.domain} {self.entry_id}",
)

async def _async_init_reauth(
self,
hass: HomeAssistant,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> None:
"""Start a reauth flow."""
async with self._reauth_lock:
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
# Reauth flow already in progress for this entry
return
await hass.config_entries.flow.async_init(
self.domain,
context={
"source": SOURCE_REAUTH,
Expand All @@ -742,9 +761,7 @@ def async_start_reauth(
}
| (context or {}),
data=self.data | (data or {}),
),
f"config entry reauth {self.title} {self.domain} {self.entry_id}",
)
)

@callback
def async_get_active_flows(
Expand All @@ -754,7 +771,9 @@ def async_get_active_flows(
return (
flow
for flow in hass.config_entries.flow.async_progress_by_handler(
self.domain, match_context={"entry_id": self.entry_id}
self.domain,
match_context={"entry_id": self.entry_id},
include_uninitialized=True,
)
if flow["context"].get("source") in sources
)
Expand Down
1 change: 1 addition & 0 deletions tests/components/smarttub/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def test_setup_auth_failed(
config_entry.add_to_hass(hass)
with patch.object(hass.config_entries.flow, "async_init") as mock_flow_init:
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
assert config_entry.state is ConfigEntryState.SETUP_ERROR
mock_flow_init.assert_called_with(
DOMAIN,
Expand Down
14 changes: 14 additions & 0 deletions tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3791,6 +3791,20 @@ async def test_reauth(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
assert len(hass.config_entries.flow.async_progress()) == 2

# Abort all existing flows
for flow in hass.config_entries.flow.async_progress():
hass.config_entries.flow.async_abort(flow["flow_id"])
await hass.async_block_till_done()

# Check that we can't start duplicate reauth flows
# without blocking between flows
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
await hass.async_block_till_done()
assert len(hass.config_entries.flow.async_progress()) == 1


async def test_get_active_flows(hass: HomeAssistant) -> None:
"""Test the async_get_active_flows helper."""
Expand Down
Loading