Skip to content

Commit

Permalink
Add Google Generative AI reauth flow (#118096)
Browse files Browse the repository at this point in the history
* Add reauth flow

* address comments
  • Loading branch information
tronikos authored May 26, 2024
1 parent b85cf36 commit 0972b29
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from __future__ import annotations

from asyncio import timeout
from functools import partial
import mimetypes
from pathlib import Path

from google.api_core.exceptions import ClientError
from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPICallError
import google.generativeai as genai
import google.generativeai.types as genai_types
import voluptuous as vol
Expand All @@ -20,11 +19,16 @@
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
ConfigEntryNotReady,
HomeAssistantError,
)
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType

from .const import CONF_PROMPT, DOMAIN, LOGGER
from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL

SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"
Expand Down Expand Up @@ -101,13 +105,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
genai.configure(api_key=entry.data[CONF_API_KEY])

try:
async with timeout(5.0):
next(await hass.async_add_executor_job(partial(genai.list_models)), None)
except (ClientError, TimeoutError) as err:
await hass.async_add_executor_job(
partial(
genai.get_model,
entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
request_options={"timeout": 5.0},
)
)
except (GoogleAPICallError, ValueError) as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
LOGGER.error("Invalid API key: %s", err)
return False
raise ConfigEntryNotReady(err) from err
raise ConfigEntryAuthFailed(err) from err
if isinstance(err, DeadlineExceeded):
raise ConfigEntryNotReady(err) from err
raise ConfigEntryError(err) from err

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from __future__ import annotations

from collections.abc import Mapping
from functools import partial
import logging
from types import MappingProxyType
from typing import Any

from google.api_core.exceptions import ClientError
from google.api_core.exceptions import ClientError, GoogleAPICallError
import google.generativeai as genai
import voluptuous as vol

Expand All @@ -17,7 +18,7 @@
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
Expand Down Expand Up @@ -54,7 +55,7 @@

_LOGGER = logging.getLogger(__name__)

STEP_USER_DATA_SCHEMA = vol.Schema(
STEP_API_DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_API_KEY): str,
}
Expand All @@ -73,44 +74,86 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
"""
genai.configure(api_key=data[CONF_API_KEY])
await hass.async_add_executor_job(partial(genai.list_models))

def get_first_model():
return next(genai.list_models(request_options={"timeout": 5.0}), None)

await hass.async_add_executor_job(partial(get_first_model))


class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Google Generative AI Conversation."""

VERSION = 1

async def async_step_user(
def __init__(self) -> None:
"""Initialize a new GoogleGenerativeAIConfigFlow."""
self.reauth_entry: ConfigEntry | None = None

async def async_step_api(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step."""
if user_input is None:
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
)

errors = {}

try:
await validate_input(self.hass, user_input)
except ClientError as err:
if err.reason == "API_KEY_INVALID":
errors["base"] = "invalid_auth"
errors: dict[str, str] = {}
if user_input is not None:
try:
await validate_input(self.hass, user_input)
except GoogleAPICallError as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
errors["base"] = "invalid_auth"
else:
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
return self.async_create_entry(
title="Google Generative AI",
data=user_input,
options=RECOMMENDED_OPTIONS,
)
if self.reauth_entry:
return self.async_update_reload_and_abort(
self.reauth_entry,
data=user_input,
)
return self.async_create_entry(
title="Google Generative AI",
data=user_input,
options=RECOMMENDED_OPTIONS,
)
return self.async_show_form(
step_id="api",
data_schema=STEP_API_DATA_SCHEMA,
description_placeholders={
"api_key_url": "https://aistudio.google.com/app/apikey"
},
errors=errors,
)

async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step."""
return await self.async_step_api()

async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Handle configuration by re-auth."""
self.reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()

async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Dialog that informs the user that reauth is required."""
if user_input is not None:
return await self.async_step_api()
assert self.reauth_entry
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
step_id="reauth_confirm",
description_placeholders={
CONF_NAME: self.reauth_entry.title,
CONF_API_KEY: self.reauth_entry.data.get(CONF_API_KEY, ""),
},
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
{
"config": {
"step": {
"user": {
"api": {
"data": {
"api_key": "[%key:common::config_flow::data::api_key%]",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
}
"api_key": "[%key:common::config_flow::data::api_key%]"
},
"description": "Get your API key from [here]({api_key_url})."
},
"reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]",
"description": "Your current API key: {api_key} is no longer valid. Please enter a new valid API key."
}
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
},
"abort": {
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
}
},
"options": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
def mock_genai():
"""Mock the genai call in async_setup_entry."""
with patch(
"homeassistant.components.google_generative_ai_conversation.genai.list_models",
return_value=iter([]),
"homeassistant.components.google_generative_ai_conversation.genai.get_model"
):
yield

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from unittest.mock import Mock, patch

from google.api_core.exceptions import ClientError
from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo
import pytest

Expand Down Expand Up @@ -69,7 +69,7 @@ async def test_form(hass: HomeAssistant) -> None:
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] is None
assert not result["errors"]

with (
patch(
Expand Down Expand Up @@ -186,13 +186,16 @@ async def test_options_switching(
("side_effect", "error"),
[
(
ClientError(message="some error"),
ClientError("some error"),
"cannot_connect",
),
(
DeadlineExceeded("deadline exceeded"),
"cannot_connect",
),
(
ClientError(
message="invalid api key",
error_info=ErrorInfo(reason="API_KEY_INVALID"),
"invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
),
"invalid_auth",
),
Expand All @@ -218,3 +221,51 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:

assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": error}


async def test_reauth_flow(hass: HomeAssistant) -> None:
"""Test the reauth flow."""
hass.config.components.add("google_generative_ai_conversation")
mock_config_entry = MockConfigEntry(
domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini"
)
mock_config_entry.add_to_hass(hass)
mock_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()

flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
result = flows[0]
assert result["step_id"] == "reauth_confirm"
assert result["context"]["source"] == "reauth"
assert result["context"]["title_placeholders"] == {"name": "Gemini"}

result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "api"
assert "api_key" in result["data_schema"].schema
assert not result["errors"]

with (
patch(
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
),
patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",
return_value=True,
) as mock_setup_entry,
patch(
"homeassistant.components.google_generative_ai_conversation.async_unload_entry",
return_value=True,
) as mock_unload_entry,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"api_key": "1234"}
)
await hass.async_block_till_done()

assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert hass.config_entries.async_entries(DOMAIN)[0].data == {"api_key": "1234"}
assert len(mock_unload_entry.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
Loading

0 comments on commit 0972b29

Please sign in to comment.