diff --git a/homeassistant/components/trafikverket_train/__init__.py b/homeassistant/components/trafikverket_train/__init__.py index 8f11590c487d06..a7defa2956a60c 100644 --- a/homeassistant/components/trafikverket_train/__init__.py +++ b/homeassistant/components/trafikverket_train/__init__.py @@ -15,7 +15,7 @@ from homeassistant.helpers import entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession -from .const import CONF_FROM, CONF_TO, DOMAIN, PLATFORMS +from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TO, DOMAIN, PLATFORMS from .coordinator import TVDataUpdateCoordinator @@ -36,7 +36,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: f" {entry.data[CONF_TO]}. Error: {error} " ) from error - coordinator = TVDataUpdateCoordinator(hass, entry, to_station, from_station) + coordinator = TVDataUpdateCoordinator( + hass, entry, to_station, from_station, entry.options.get(CONF_FILTER_PRODUCT) + ) await coordinator.async_config_entry_first_refresh() hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator @@ -49,6 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + entry.async_on_unload(entry.add_update_listener(update_listener)) return True @@ -57,3 +60,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Trafikverket Weatherstation config entry.""" return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + + +async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: + """Handle options update.""" + await hass.config_entries.async_reload(entry.entry_id) diff --git a/homeassistant/components/trafikverket_train/config_flow.py b/homeassistant/components/trafikverket_train/config_flow.py index f500085175599a..b7808dc38b29a4 100644 --- a/homeassistant/components/trafikverket_train/config_flow.py +++ b/homeassistant/components/trafikverket_train/config_flow.py @@ -19,7 +19,7 @@ from homeassistant import config_entries from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv @@ -32,11 +32,15 @@ ) import homeassistant.util.dt as dt_util -from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN +from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN from .util import create_unique_id, next_departuredate _LOGGER = logging.getLogger(__name__) +OPTION_SCHEMA = { + vol.Optional(CONF_FILTER_PRODUCT, default=""): TextSelector(), +} + DATA_SCHEMA = vol.Schema( { vol.Required(CONF_API_KEY): TextSelector(), @@ -52,7 +56,7 @@ ) ), } -) +).extend(OPTION_SCHEMA) DATA_SCHEMA_REAUTH = vol.Schema( { vol.Required(CONF_API_KEY): cv.string, @@ -67,6 +71,7 @@ async def validate_input( train_to: str, train_time: str | None, weekdays: list[str], + product_filter: str | None, ) -> dict[str, str]: """Validate input from user input.""" errors: dict[str, str] = {} @@ -87,9 +92,13 @@ async def validate_input( from_station = await train_api.async_get_train_station(train_from) to_station = await train_api.async_get_train_station(train_to) if train_time: - await train_api.async_get_train_stop(from_station, to_station, when) + await train_api.async_get_train_stop( + from_station, to_station, when, product_filter + ) else: - await train_api.async_get_next_train_stop(from_station, to_station, when) + await train_api.async_get_next_train_stop( + from_station, to_station, when, product_filter + ) except InvalidAuthentication: errors["base"] = "invalid_auth" except NoTrainStationFound: @@ -117,6 +126,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): entry: config_entries.ConfigEntry | None + @staticmethod + @callback + def async_get_options_flow( + config_entry: config_entries.ConfigEntry, + ) -> TVTrainOptionsFlowHandler: + """Get the options flow for this handler.""" + return TVTrainOptionsFlowHandler(config_entry) + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: """Handle re-authentication with Trafikverket.""" @@ -140,6 +157,7 @@ async def async_step_reauth_confirm( self.entry.data[CONF_TO], self.entry.data.get(CONF_TIME), self.entry.data[CONF_WEEKDAY], + self.entry.options.get(CONF_FILTER_PRODUCT), ) if not errors: self.hass.config_entries.async_update_entry( @@ -170,6 +188,10 @@ async def async_step_user( train_to: str = user_input[CONF_TO] train_time: str | None = user_input.get(CONF_TIME) train_days: list = user_input[CONF_WEEKDAY] + filter_product: str | None = user_input[CONF_FILTER_PRODUCT] + + if filter_product == "": + filter_product = None name = f"{train_from} to {train_to}" if train_time: @@ -182,6 +204,7 @@ async def async_step_user( train_to, train_time, train_days, + filter_product, ) if not errors: unique_id = create_unique_id( @@ -199,6 +222,7 @@ async def async_step_user( CONF_TIME: train_time, CONF_WEEKDAY: train_days, }, + options={CONF_FILTER_PRODUCT: filter_product}, ) return self.async_show_form( @@ -208,3 +232,27 @@ async def async_step_user( ), errors=errors, ) + + +class TVTrainOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): + """Handle Trafikverket Train options.""" + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Manage Trafikverket Train options.""" + errors: dict[str, Any] = {} + + if user_input: + if not (_filter := user_input.get(CONF_FILTER_PRODUCT)) or _filter == "": + user_input[CONF_FILTER_PRODUCT] = None + return self.async_create_entry(data=user_input) + + return self.async_show_form( + step_id="init", + data_schema=self.add_suggested_values_to_schema( + vol.Schema(OPTION_SCHEMA), + user_input or self.options, + ), + errors=errors, + ) diff --git a/homeassistant/components/trafikverket_train/const.py b/homeassistant/components/trafikverket_train/const.py index 253383b4b5a048..e1852ce9ada78b 100644 --- a/homeassistant/components/trafikverket_train/const.py +++ b/homeassistant/components/trafikverket_train/const.py @@ -8,3 +8,4 @@ CONF_FROM = "from" CONF_TO = "to" CONF_TIME = "time" +CONF_FILTER_PRODUCT = "filter_product" diff --git a/homeassistant/components/trafikverket_train/coordinator.py b/homeassistant/components/trafikverket_train/coordinator.py index fac1c418b0966a..ea852ab7fdf1a6 100644 --- a/homeassistant/components/trafikverket_train/coordinator.py +++ b/homeassistant/components/trafikverket_train/coordinator.py @@ -39,6 +39,7 @@ class TrainData: actual_time: datetime | None other_info: str | None deviation: str | None + product_filter: str | None _LOGGER = logging.getLogger(__name__) @@ -68,6 +69,7 @@ def __init__( entry: ConfigEntry, to_station: StationInfo, from_station: StationInfo, + filter_product: str | None, ) -> None: """Initialize the Trafikverket coordinator.""" super().__init__( @@ -83,6 +85,7 @@ def __init__( self.to_station: StationInfo = to_station self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME]) self._weekdays: list[str] = entry.data[CONF_WEEKDAY] + self._filter_product = filter_product async def _async_update_data(self) -> TrainData: """Fetch data from Trafikverket.""" @@ -99,11 +102,11 @@ async def _async_update_data(self) -> TrainData: try: if self._time: state = await self._train_api.async_get_train_stop( - self.from_station, self.to_station, when + self.from_station, self.to_station, when, self._filter_product ) else: state = await self._train_api.async_get_next_train_stop( - self.from_station, self.to_station, when + self.from_station, self.to_station, when, self._filter_product ) except InvalidAuthentication as error: raise ConfigEntryAuthFailed from error @@ -134,6 +137,7 @@ async def _async_update_data(self) -> TrainData: actual_time=_get_as_utc(state.time_at_location), other_info=_get_as_joined(state.other_information), deviation=_get_as_joined(state.deviations), + product_filter=self._filter_product, ) return states diff --git a/homeassistant/components/trafikverket_train/sensor.py b/homeassistant/components/trafikverket_train/sensor.py index 97d7a6b34fa602..b5f993073a5cbb 100644 --- a/homeassistant/components/trafikverket_train/sensor.py +++ b/homeassistant/components/trafikverket_train/sensor.py @@ -1,9 +1,10 @@ """Train information for departures and delays, provided by Trafikverket.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass from datetime import datetime +from typing import Any from homeassistant.components.sensor import ( SensorDeviceClass, @@ -22,6 +23,8 @@ from .const import ATTRIBUTION, DOMAIN from .coordinator import TrainData, TVDataUpdateCoordinator +ATTR_PRODUCT_FILTER = "product_filter" + @dataclass class TrafikverketRequiredKeysMixin: @@ -158,3 +161,10 @@ def _update_attr(self) -> None: def _handle_coordinator_update(self) -> None: self._update_attr() return super()._handle_coordinator_update() + + @property + def extra_state_attributes(self) -> Mapping[str, Any] | None: + """Return additional attributes for Trafikverket Train sensor.""" + if self.coordinator.data.product_filter: + return {ATTR_PRODUCT_FILTER: self.coordinator.data.product_filter} + return None diff --git a/homeassistant/components/trafikverket_train/strings.json b/homeassistant/components/trafikverket_train/strings.json index aabab0907abfc3..78d69c880ae744 100644 --- a/homeassistant/components/trafikverket_train/strings.json +++ b/homeassistant/components/trafikverket_train/strings.json @@ -20,10 +20,12 @@ "to": "To station", "from": "From station", "time": "Time (optional)", - "weekday": "Days" + "weekday": "Days", + "filter_product": "Filter by product description" }, "data_description": { - "time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure" + "time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure", + "filter_product": "To filter by product description add the phrase here to match" } }, "reauth_confirm": { @@ -33,6 +35,18 @@ } } }, + "options": { + "step": { + "init": { + "data": { + "filter_product": "[%key:component::trafikverket_train::config::step::user::data::filter_product%]" + }, + "data_description": { + "filter_product": "[%key:component::trafikverket_train::config::step::user::data_description::filter_product%]" + } + } + } + }, "selector": { "weekday": { "options": { @@ -49,7 +63,12 @@ "entity": { "sensor": { "departure_time": { - "name": "Departure time" + "name": "Departure time", + "state_attributes": { + "product_filter": { + "name": "Train filtering" + } + } }, "departure_state": { "name": "Departure state", @@ -57,28 +76,68 @@ "on_time": "On time", "delayed": "Delayed", "canceled": "Cancelled" + }, + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } } }, "cancelled": { - "name": "Cancelled" + "name": "Cancelled", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } }, "delayed_time": { - "name": "Delayed time" + "name": "Delayed time", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } }, "planned_time": { - "name": "Planned time" + "name": "Planned time", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } }, "estimated_time": { - "name": "Estimated time" + "name": "Estimated time", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } }, "actual_time": { - "name": "Actual time" + "name": "Actual time", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } }, "other_info": { - "name": "Other information" + "name": "Other information", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } }, "deviation": { - "name": "Deviation" + "name": "Deviation", + "state_attributes": { + "product_filter": { + "name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]" + } + } } } } diff --git a/tests/components/trafikverket_train/test_config_flow.py b/tests/components/trafikverket_train/test_config_flow.py index a3b449755c72ea..3493e031669490 100644 --- a/tests/components/trafikverket_train/test_config_flow.py +++ b/tests/components/trafikverket_train/test_config_flow.py @@ -66,6 +66,7 @@ async def test_form(hass: HomeAssistant) -> None: "time": "10:00", "weekday": ["mon", "fri"], } + assert result["options"] == {"filter_product": None} assert len(mock_setup_entry.mock_calls) == 1 assert result["result"].unique_id == "{}-{}-{}-{}".format( "stockholmc", "uppsalac", "10:00", "['mon', 'fri']" @@ -448,3 +449,55 @@ async def test_reauth_flow_error_departures( "time": "10:00", "weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"], } + + +async def test_options_flow(hass: HomeAssistant) -> None: + """Test a reauthentication flow.""" + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_API_KEY: "1234567890", + CONF_NAME: "Stockholm C to Uppsala C at 10:00", + CONF_FROM: "Stockholm C", + CONF_TO: "Uppsala C", + CONF_TIME: "10:00", + CONF_WEEKDAY: WEEKDAYS, + }, + unique_id=f"stockholmc-uppsalac-10:00-{WEEKDAYS}", + ) + entry.add_to_hass(hass) + + with patch( + "homeassistant.components.trafikverket_train.async_setup_entry", + return_value=True, + ): + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + result = await hass.config_entries.options.async_init(entry.entry_id) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "init" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={"filter_product": "SJ Regionaltåg"}, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["data"] == {"filter_product": "SJ Regionaltåg"} + + result = await hass.config_entries.options.async_init(entry.entry_id) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "init" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={"filter_product": ""}, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["data"] == {"filter_product": None}