Skip to content

Commit

Permalink
Simplify API, adjust OTBR
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery committed Aug 14, 2023
1 parent 1820ee6 commit f653ef8
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 57 deletions.
30 changes: 12 additions & 18 deletions homeassistant/components/otbr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
import aiohttp
import python_otbr_api

from homeassistant.components.thread import (
async_add_dataset,
async_get_preferred_border_agent_id,
async_get_preferred_dataset,
async_set_preferred_border_agent_id,
)
from homeassistant.components.thread import async_add_dataset
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
Expand Down Expand Up @@ -50,21 +45,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) as err:
raise ConfigEntryNotReady("Unable to connect") from err
if dataset_tlvs:
await update_issues(hass, otbrdata, dataset_tlvs)
await async_add_dataset(hass, DOMAIN, dataset_tlvs.hex())
# If this OTBR's dataset is the preferred one, and there is no preferred router,
# make this the preferred router
border_agent_id: bytes | None = None
border_agent_id: str | None = None
with contextlib.suppress(
HomeAssistantError, aiohttp.ClientError, asyncio.TimeoutError
):
border_agent_id = await otbrdata.get_border_agent_id()
if (
await async_get_preferred_dataset(hass) == dataset_tlvs.hex()
and await async_get_preferred_border_agent_id(hass) is None
and border_agent_id
):
await async_set_preferred_border_agent_id(hass, border_agent_id.hex())
border_agent_bytes = await otbrdata.get_border_agent_id()
if border_agent_bytes:
border_agent_id = border_agent_bytes.hex()
await update_issues(hass, otbrdata, dataset_tlvs)
await async_add_dataset(
hass,
DOMAIN,
dataset_tlvs.hex(),
preferred_border_agent_id=border_agent_id,
)

entry.async_on_unload(entry.add_update_listener(async_reload_entry))

Expand Down
2 changes: 0 additions & 2 deletions homeassistant/components/thread/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
async_add_dataset,
async_get_dataset,
async_get_preferred_dataset,
async_set_preferred_dataset_preferred_border_agent_id,
)
from .websocket_api import async_setup as async_setup_ws_api

Expand All @@ -22,7 +21,6 @@
"async_add_dataset",
"async_get_dataset",
"async_get_preferred_dataset",
"async_set_preferred_dataset_preferred_border_agent_id",
]

CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
Expand Down
30 changes: 14 additions & 16 deletions homeassistant/components/thread/dataset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def __init__(self, hass: HomeAssistant) -> None:
)

@callback
def async_add(self, source: str, tlv: str) -> None:
def async_add(
self, source: str, tlv: str, preferred_border_agent_id: str | None
) -> None:
"""Add dataset, does nothing if it already exists."""
# Make sure the tlv is valid
dataset = tlv_parser.parse_tlv(tlv)
Expand Down Expand Up @@ -248,7 +250,9 @@ def async_add(self, source: str, tlv: str) -> None:
self.async_schedule_save()
return

entry = DatasetEntry(preferred_border_agent_id=None, source=source, tlv=tlv)
entry = DatasetEntry(
preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv
)
self.datasets[entry.id] = entry
# Set to preferred if there is no preferred dataset
if self._preferred_dataset is None:
Expand Down Expand Up @@ -337,10 +341,16 @@ async def async_get_store(hass: HomeAssistant) -> DatasetStore:
return store


async def async_add_dataset(hass: HomeAssistant, source: str, tlv: str) -> None:
async def async_add_dataset(
hass: HomeAssistant,
source: str,
tlv: str,
*,
preferred_border_agent_id: str | None = None,
) -> None:
"""Add a dataset."""
store = await async_get_store(hass)
store.async_add(source, tlv)
store.async_add(source, tlv, preferred_border_agent_id)


async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None:
Expand All @@ -359,15 +369,3 @@ async def async_get_preferred_dataset(hass: HomeAssistant) -> str | None:
) is None:
return None
return entry.tlv


async def async_set_preferred_dataset_preferred_border_agent_id(
hass: HomeAssistant, border_agent_id: str
) -> None:
"""Set the preferred border agent ID of the preferred dataset."""
store = await async_get_store(hass)
if (preferred_dataset_id := store.preferred_dataset) is None or (
store.async_get(preferred_dataset_id)
) is None:
raise HomeAssistantError("UnknownDataset")
store.async_set_preferred_border_agent_id(preferred_dataset_id, border_agent_id)
6 changes: 3 additions & 3 deletions tests/components/assist_pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def load_homeassistant(hass) -> None:
assert await async_setup_component(hass, "homeassistant", {})


async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
async def test_load_pipelines(hass: HomeAssistant, init_components) -> None:
"""Make sure that we can load/save data correctly."""

pipelines = [
Expand Down Expand Up @@ -92,10 +92,10 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()


async def test_loading_datasets_from_storage(
async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test loading stored datasets on start."""
"""Test loading stored pipelines on start."""
hass_storage[STORAGE_KEY] = {
"version": 1,
"minor_version": 1,
Expand Down
10 changes: 5 additions & 5 deletions tests/components/otbr/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
async def test_import_dataset(hass: HomeAssistant) -> None:
"""Test the active dataset is imported at setup."""
issue_registry = ir.async_get(hass)
assert await thread.async_get_preferred_border_agent_id(hass) is None
assert await thread.async_get_preferred_dataset(hass) is None

config_entry = MockConfigEntry(
Expand All @@ -54,8 +53,9 @@ async def test_import_dataset(hass: HomeAssistant) -> None:
):
assert await hass.config_entries.async_setup(config_entry.entry_id)

dataset_store = await thread.dataset_store.async_get_store(hass)
assert (
await thread.async_get_preferred_border_agent_id(hass)
list(dataset_store.datasets.values())[0].preferred_border_agent_id
== TEST_BORDER_AGENT_ID.hex()
)
assert await thread.async_get_preferred_dataset(hass) == DATASET_CH16.hex()
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_import_share_radio_channel_collision(
) as mock_add:
assert await hass.config_entries.async_setup(config_entry.entry_id)

mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex(), None)
assert issue_registry.async_get_issue(
domain=otbr.DOMAIN,
issue_id=f"otbr_zha_channel_collision_{config_entry.entry_id}",
Expand Down Expand Up @@ -127,7 +127,7 @@ async def test_import_share_radio_no_channel_collision(
) as mock_add:
assert await hass.config_entries.async_setup(config_entry.entry_id)

mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex(), None)
assert not issue_registry.async_get_issue(
domain=otbr.DOMAIN,
issue_id=f"otbr_zha_channel_collision_{config_entry.entry_id}",
Expand Down Expand Up @@ -158,7 +158,7 @@ async def test_import_insecure_dataset(hass: HomeAssistant, dataset: bytes) -> N
) as mock_add:
assert await hass.config_entries.async_setup(config_entry.entry_id)

mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex(), None)
assert issue_registry.async_get_issue(
domain=otbr.DOMAIN, issue_id=f"insecure_thread_network_{config_entry.entry_id}"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/components/otbr/test_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def test_create_network(
assert set_enabled_mock.mock_calls[0][1][0] is False
assert set_enabled_mock.mock_calls[1][1][0] is True
get_active_dataset_tlvs_mock.assert_called_once()
mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex(), None)


async def test_create_network_no_entry(
Expand Down
17 changes: 5 additions & 12 deletions tests/components/thread/test_dataset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ async def test_load_datasets(hass: HomeAssistant) -> None:

store1 = await dataset_store.async_get_store(hass)
for dataset in datasets:
store1.async_add(dataset["source"], dataset["tlv"])
store1.async_add(dataset["source"], dataset["tlv"], None)
assert len(store1.datasets) == 3

for dataset in store1.datasets.values():
Expand Down Expand Up @@ -543,19 +543,12 @@ async def test_migrate_set_default_border_agent_id(

async def test_set_preferred_border_agent_id(hass: HomeAssistant) -> None:
"""Test set the preferred border agent ID of a dataset."""
with pytest.raises(HomeAssistantError):
await dataset_store.async_set_preferred_dataset_preferred_border_agent_id(
hass, "blah"
)
assert await dataset_store.async_get_preferred_dataset(hass) is None

await dataset_store.async_add_dataset(hass, "source", DATASET_1)
await dataset_store.async_add_dataset(
hass, "source", DATASET_1, preferred_border_agent_id="blah"
)

store = await dataset_store.async_get_store(hass)
assert len(store.datasets) == 1
assert list(store.datasets.values())[0].preferred_border_agent_id is None

await dataset_store.async_set_preferred_dataset_preferred_border_agent_id(
hass, "blah"
)

assert list(store.datasets.values())[0].preferred_border_agent_id == "blah"

0 comments on commit f653ef8

Please sign in to comment.