Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store preferred border agent ID for each thread dataset #98384

Merged
merged 3 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions homeassistant/components/otbr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
import aiohttp
import python_otbr_api

from homeassistant.components.thread import (
async_add_dataset,
async_get_preferred_border_agent_id,
async_get_preferred_dataset,
async_set_preferred_border_agent_id,
)
from homeassistant.components.thread import async_add_dataset
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
Expand Down Expand Up @@ -50,21 +45,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) as err:
raise ConfigEntryNotReady("Unable to connect") from err
if dataset_tlvs:
await update_issues(hass, otbrdata, dataset_tlvs)
await async_add_dataset(hass, DOMAIN, dataset_tlvs.hex())
# If this OTBR's dataset is the preferred one, and there is no preferred router,
# make this the preferred router
border_agent_id: bytes | None = None
border_agent_id: str | None = None
with contextlib.suppress(
HomeAssistantError, aiohttp.ClientError, asyncio.TimeoutError
):
border_agent_id = await otbrdata.get_border_agent_id()
if (
await async_get_preferred_dataset(hass) == dataset_tlvs.hex()
and await async_get_preferred_border_agent_id(hass) is None
and border_agent_id
):
await async_set_preferred_border_agent_id(hass, border_agent_id.hex())
border_agent_bytes = await otbrdata.get_border_agent_id()
if border_agent_bytes:
border_agent_id = border_agent_bytes.hex()
await update_issues(hass, otbrdata, dataset_tlvs)
await async_add_dataset(
hass,
DOMAIN,
dataset_tlvs.hex(),
preferred_border_agent_id=border_agent_id,
)

entry.async_on_unload(entry.add_update_listener(async_reload_entry))

Expand Down
4 changes: 0 additions & 4 deletions homeassistant/components/thread/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,16 @@
DatasetEntry,
async_add_dataset,
async_get_dataset,
async_get_preferred_border_agent_id,
async_get_preferred_dataset,
async_set_preferred_border_agent_id,
)
from .websocket_api import async_setup as async_setup_ws_api

__all__ = [
"DOMAIN",
"DatasetEntry",
"async_add_dataset",
"async_get_preferred_border_agent_id",
"async_get_dataset",
"async_get_preferred_dataset",
"async_set_preferred_border_agent_id",
]

CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
Expand Down
59 changes: 27 additions & 32 deletions homeassistant/components/thread/dataset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class DatasetPreferredError(HomeAssistantError):
class DatasetEntry:
"""Dataset store entry."""

preferred_border_agent_id: str | None
source: str
tlv: str

Expand Down Expand Up @@ -73,6 +74,7 @@ def to_json(self) -> dict[str, Any]:
return {
"created": self.created.isoformat(),
"id": self.id,
"preferred_border_agent_id": self.preferred_border_agent_id,
"source": self.source,
"tlv": self.tlv,
}
Expand All @@ -97,6 +99,7 @@ async def _async_migrate_func(
entry = DatasetEntry(
created=created,
id=dataset["id"],
preferred_border_agent_id=None,
source=dataset["source"],
tlv=dataset["tlv"],
)
Expand Down Expand Up @@ -160,7 +163,8 @@ async def _async_migrate_func(
}
if old_minor_version < 3:
# Add border agent ID
data.setdefault("preferred_border_agent_id", None)
for dataset in data["datasets"]:
dataset.setdefault("preferred_border_agent_id", None)

return data

Expand All @@ -172,7 +176,6 @@ def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the dataset store."""
self.hass = hass
self.datasets: dict[str, DatasetEntry] = {}
self._preferred_border_agent_id: str | None = None
self._preferred_dataset: str | None = None
self._store: Store[dict[str, Any]] = DatasetStoreStore(
hass,
Expand All @@ -183,7 +186,9 @@ def __init__(self, hass: HomeAssistant) -> None:
)

@callback
def async_add(self, source: str, tlv: str) -> None:
def async_add(
self, source: str, tlv: str, preferred_border_agent_id: str | None
) -> None:
"""Add dataset, does nothing if it already exists."""
# Make sure the tlv is valid
dataset = tlv_parser.parse_tlv(tlv)
Expand Down Expand Up @@ -245,7 +250,9 @@ def async_add(self, source: str, tlv: str) -> None:
self.async_schedule_save()
return

entry = DatasetEntry(source=source, tlv=tlv)
entry = DatasetEntry(
preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv
)
self.datasets[entry.id] = entry
# Set to preferred if there is no preferred dataset
if self._preferred_dataset is None:
Expand All @@ -266,14 +273,13 @@ def async_get(self, dataset_id: str) -> DatasetEntry | None:
return self.datasets.get(dataset_id)

@callback
def async_get_preferred_border_agent_id(self) -> str | None:
"""Get preferred border agent id."""
return self._preferred_border_agent_id

@callback
def async_set_preferred_border_agent_id(self, border_agent_id: str) -> None:
"""Set preferred border agent id."""
self._preferred_border_agent_id = border_agent_id
def async_set_preferred_border_agent_id(
self, dataset_id: str, border_agent_id: str
) -> None:
"""Set preferred border agent id of a dataset."""
self.datasets[dataset_id] = dataclasses.replace(
self.datasets[dataset_id], preferred_border_agent_id=border_agent_id
)
self.async_schedule_save()

@property
Expand All @@ -296,7 +302,6 @@ async def async_load(self) -> None:
data = await self._store.async_load()

datasets: dict[str, DatasetEntry] = {}
preferred_border_agent_id: str | None = None
preferred_dataset: str | None = None

if data is not None:
Expand All @@ -305,14 +310,13 @@ async def async_load(self) -> None:
datasets[dataset["id"]] = DatasetEntry(
created=created,
id=dataset["id"],
preferred_border_agent_id=dataset["preferred_border_agent_id"],
source=dataset["source"],
tlv=dataset["tlv"],
)
preferred_border_agent_id = data["preferred_border_agent_id"]
preferred_dataset = data["preferred_dataset"]

self.datasets = datasets
self._preferred_border_agent_id = preferred_border_agent_id
self._preferred_dataset = preferred_dataset

@callback
Expand All @@ -325,7 +329,6 @@ def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
"""Return data of datasets to store in a file."""
data: dict[str, Any] = {}
data["datasets"] = [dataset.to_json() for dataset in self.datasets.values()]
data["preferred_border_agent_id"] = self._preferred_border_agent_id
data["preferred_dataset"] = self._preferred_dataset
return data

Expand All @@ -338,10 +341,16 @@ async def async_get_store(hass: HomeAssistant) -> DatasetStore:
return store


async def async_add_dataset(hass: HomeAssistant, source: str, tlv: str) -> None:
async def async_add_dataset(
hass: HomeAssistant,
source: str,
tlv: str,
*,
preferred_border_agent_id: str | None = None,
) -> None:
"""Add a dataset."""
store = await async_get_store(hass)
store.async_add(source, tlv)
store.async_add(source, tlv, preferred_border_agent_id)


async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None:
Expand All @@ -352,20 +361,6 @@ async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None:
return entry.tlv


async def async_get_preferred_border_agent_id(hass: HomeAssistant) -> str | None:
"""Get the preferred border agent ID."""
store = await async_get_store(hass)
return store.async_get_preferred_border_agent_id()


async def async_set_preferred_border_agent_id(
hass: HomeAssistant, border_agent_id: str
) -> None:
"""Get the preferred border agent ID."""
store = await async_get_store(hass)
store.async_set_preferred_border_agent_id(border_agent_id)


async def async_get_preferred_dataset(hass: HomeAssistant) -> str | None:
"""Get the preferred dataset."""
store = await async_get_store(hass)
Expand Down
22 changes: 5 additions & 17 deletions homeassistant/components/thread/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, ws_discover_routers)
websocket_api.async_register_command(hass, ws_get_dataset)
websocket_api.async_register_command(hass, ws_list_datasets)
websocket_api.async_register_command(hass, ws_get_preferred_border_agent_id)
websocket_api.async_register_command(hass, ws_set_preferred_border_agent_id)
websocket_api.async_register_command(hass, ws_set_preferred_dataset)

Expand Down Expand Up @@ -52,25 +51,11 @@ async def ws_add_dataset(
connection.send_result(msg["id"])


@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "thread/get_preferred_border_agent_id",
}
)
@websocket_api.async_response
async def ws_get_preferred_border_agent_id(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Get the preferred border agent ID."""
border_agent_id = await dataset_store.async_get_preferred_border_agent_id(hass)
connection.send_result(msg["id"], {"border_agent_id": border_agent_id})


@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "thread/set_preferred_border_agent_id",
vol.Required("dataset_id"): str,
vol.Required("border_agent_id"): str,
}
)
Expand All @@ -79,8 +64,10 @@ async def ws_set_preferred_border_agent_id(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Set the preferred border agent ID."""
dataset_id = msg["dataset_id"]
border_agent_id = msg["border_agent_id"]
await dataset_store.async_set_preferred_border_agent_id(hass, border_agent_id)
store = await dataset_store.async_get_store(hass)
store.async_set_preferred_border_agent_id(dataset_id, border_agent_id)
connection.send_result(msg["id"])


Expand Down Expand Up @@ -186,6 +173,7 @@ async def ws_list_datasets(
"network_name": dataset.network_name,
"pan_id": dataset.pan_id,
"preferred": dataset.id == preferred_dataset,
"preferred_border_agent_id": dataset.preferred_border_agent_id,
"source": dataset.source,
}
)
Expand Down
10 changes: 5 additions & 5 deletions tests/components/otbr/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
async def test_import_dataset(hass: HomeAssistant) -> None:
"""Test the active dataset is imported at setup."""
issue_registry = ir.async_get(hass)
assert await thread.async_get_preferred_border_agent_id(hass) is None
assert await thread.async_get_preferred_dataset(hass) is None

config_entry = MockConfigEntry(
Expand All @@ -54,8 +53,9 @@ async def test_import_dataset(hass: HomeAssistant) -> None:
):
assert await hass.config_entries.async_setup(config_entry.entry_id)

dataset_store = await thread.dataset_store.async_get_store(hass)
assert (
await thread.async_get_preferred_border_agent_id(hass)
list(dataset_store.datasets.values())[0].preferred_border_agent_id
== TEST_BORDER_AGENT_ID.hex()
)
assert await thread.async_get_preferred_dataset(hass) == DATASET_CH16.hex()
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_import_share_radio_channel_collision(
) as mock_add:
assert await hass.config_entries.async_setup(config_entry.entry_id)

mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex(), None)
assert issue_registry.async_get_issue(
domain=otbr.DOMAIN,
issue_id=f"otbr_zha_channel_collision_{config_entry.entry_id}",
Expand Down Expand Up @@ -127,7 +127,7 @@ async def test_import_share_radio_no_channel_collision(
) as mock_add:
assert await hass.config_entries.async_setup(config_entry.entry_id)

mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex(), None)
assert not issue_registry.async_get_issue(
domain=otbr.DOMAIN,
issue_id=f"otbr_zha_channel_collision_{config_entry.entry_id}",
Expand Down Expand Up @@ -158,7 +158,7 @@ async def test_import_insecure_dataset(hass: HomeAssistant, dataset: bytes) -> N
) as mock_add:
assert await hass.config_entries.async_setup(config_entry.entry_id)

mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex(), None)
assert issue_registry.async_get_issue(
domain=otbr.DOMAIN, issue_id=f"insecure_thread_network_{config_entry.entry_id}"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/components/otbr/test_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def test_create_network(
assert set_enabled_mock.mock_calls[0][1][0] is False
assert set_enabled_mock.mock_calls[1][1][0] is True
get_active_dataset_tlvs_mock.assert_called_once()
mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex())
mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex(), None)


async def test_create_network_no_entry(
Expand Down
Loading
Loading