Skip to content

Commit

Permalink
Refactor re-auth flow post review
Browse files Browse the repository at this point in the history
  • Loading branch information
sdb9696 committed Nov 13, 2023
1 parent 790433c commit 5e9f5f0
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 62 deletions.
81 changes: 33 additions & 48 deletions homeassistant/components/ring/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,68 +49,45 @@ class RingConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
user_pass: dict[str, Any] = {}
reauth_entry: ConfigEntry | None = None

def _show_user_form(self, errors: dict[str, str]) -> FlowResult:
"""Show the user form."""
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
)

async def async_step_user(self, user_input=None):
"""Handle the initial step."""
errors = {}
if user_input is not None:
try:
token = await validate_input(self.hass, user_input)
await self.async_set_unique_id(user_input["username"])

return self.async_create_entry(
title=user_input["username"],
data={"username": user_input["username"], "token": token},
)
except Require2FA:
self.user_pass = user_input

return await self.async_step_2fa_user()

return await self.async_step_2fa()
except InvalidAuth:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"

return self._show_user_form(errors)

async def async_step_2fa_user(self, user_input=None):
"""Handle 2fa step."""
if user_input:
return await self.async_step_user({**self.user_pass, **user_input})
else:
await self.async_set_unique_id(user_input["username"])
return self.async_create_entry(
title=user_input["username"],
data={"username": user_input["username"], "token": token},
)

return self.async_show_form(
step_id="2fa_user",
data_schema=vol.Schema({vol.Required("2fa"): str}),
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
)

async def async_step_2fa_reauth(self, user_input=None):
async def async_step_2fa(self, user_input=None):
"""Handle 2fa step."""
if user_input:
return await self.async_step_reauth_confirm(
{**self.user_pass, **user_input}
)
if self.reauth_entry:
return await self.async_step_reauth_confirm(
{**self.user_pass, **user_input}
)

return self.async_show_form(
step_id="2fa_reauth",
data_schema=vol.Schema({vol.Required("2fa"): str}),
)
return await self.async_step_user({**self.user_pass, **user_input})

def _show_reauth_form(self, errors: dict[str, str]) -> FlowResult:
"""Show the reauth form."""
return self.async_show_form(
step_id="reauth_confirm",
data_schema=STEP_REAUTH_DATA_SCHEMA,
errors=errors,
description_placeholders={
CONF_USERNAME: self.reauth_entry.data[CONF_USERNAME] # type: ignore[union-attr]
},
step_id="2fa",
data_schema=vol.Schema({vol.Required("2fa"): str}),
)

async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
Expand All @@ -131,6 +108,15 @@ async def async_step_reauth_confirm(
user_input[CONF_USERNAME] = self.reauth_entry.data[CONF_USERNAME]
try:
token = await validate_input(self.hass, user_input)
except Require2FA:
self.user_pass = user_input
return await self.async_step_2fa()
except InvalidAuth:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
data = {
CONF_USERNAME: user_input[CONF_USERNAME],
"token": token,
Expand All @@ -140,16 +126,15 @@ async def async_step_reauth_confirm(
)
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
return self.async_abort(reason="reauth_successful")
except Require2FA:
self.user_pass = user_input
return await self.async_step_2fa_reauth()
except InvalidAuth:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"

return self._show_reauth_form(errors)
return self.async_show_form(
step_id="reauth_confirm",
data_schema=STEP_REAUTH_DATA_SCHEMA,
errors=errors,
description_placeholders={
CONF_USERNAME: self.reauth_entry.data[CONF_USERNAME]
},
)


class Require2FA(exceptions.HomeAssistantError):
Expand Down
8 changes: 1 addition & 7 deletions homeassistant/components/ring/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"password": "[%key:common::config_flow::data::password%]"
}
},
"2fa_user": {
"2fa": {
"title": "Two-factor authentication",
"data": {
"2fa": "Two-factor code"
Expand All @@ -20,12 +20,6 @@
"data": {
"password": "[%key:common::config_flow::data::password%]"
}
},
"2fa_reauth": {
"title": "Two-factor authentication",
"data": {
"2fa": "Two-factor code"
}
}
},
"error": {
Expand Down
31 changes: 26 additions & 5 deletions tests/components/ring/test_config_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test the Ring config flow."""
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, Mock

import pytest
import ring_doorbell
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_form_2fa(
)

assert result2["type"] == FlowResultType.FORM
assert result2["step_id"] == "2fa_user"
assert result2["step_id"] == "2fa"
mock_ring_auth.fetch_token.reset_mock(side_effect=True)
mock_ring_auth.fetch_token.return_value = "new-foobar"
result3 = await hass.config_entries.flow.async_configure(
Expand Down Expand Up @@ -141,7 +141,7 @@ async def test_reauth(
"foo@bar.com", "other_fake_password", None
)
assert result2["type"] == FlowResultType.FORM
assert result2["step_id"] == "2fa_reauth"
assert result2["step_id"] == "2fa"
mock_ring_auth.fetch_token.reset_mock(side_effect=True)
mock_ring_auth.fetch_token.return_value = "new-foobar"
result3 = await hass.config_entries.flow.async_configure(
Expand Down Expand Up @@ -190,13 +190,34 @@ async def test_reauth_error(
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
CONF_PASSWORD: "other_fake_password",
CONF_PASSWORD: "error_fake_password",
},
)
await hass.async_block_till_done()

mock_ring_auth.fetch_token.assert_called_once_with(
"foo@bar.com", "other_fake_password", None
"foo@bar.com", "error_fake_password", None
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {"base": errors_msg}

# Now test reauth can go on to succeed
mock_ring_auth.fetch_token.reset_mock(side_effect=True)
mock_ring_auth.fetch_token.return_value = "new-foobar"
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_PASSWORD: "other_fake_password",
},
)

mock_ring_auth.fetch_token.assert_called_once_with(
"foo@bar.com", "other_fake_password", None
)
assert result3["type"] == FlowResultType.ABORT
assert result3["reason"] == "reauth_successful"
assert mock_added_config_entry.data == {
"username": "foo@bar.com",
"token": "new-foobar",
}
assert len(mock_setup_entry.mock_calls) == 1
4 changes: 2 additions & 2 deletions tests/components/ring/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ async def test_error_on_global_update(
],
ids=["timeout-error", "other-error"],
)
async def test_error_on_device_update(hass: HomeAssistant,
async def test_error_on_device_update(
hass: HomeAssistant,
requests_mock: requests_mock.Mocker,
mock_config_entry: MockConfigEntry,
caplog,
Expand All @@ -190,4 +191,3 @@ async def test_error_on_device_update(hass: HomeAssistant,
record.message for record in caplog.records if record.levelname == "WARNING"
]
assert mock_config_entry.entry_id in hass.data[DOMAIN]

0 comments on commit 5e9f5f0

Please sign in to comment.