Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add product filtering feature to Trafikverket Train #86343

Merged
merged 14 commits into from
Aug 9, 2023
12 changes: 10 additions & 2 deletions homeassistant/components/trafikverket_train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .coordinator import TVDataUpdateCoordinator


Expand All @@ -36,7 +36,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
f" {entry.data[CONF_TO]}. Error: {error} "
) from error

coordinator = TVDataUpdateCoordinator(hass, entry, to_station, from_station)
coordinator = TVDataUpdateCoordinator(
hass, entry, to_station, from_station, entry.options.get(CONF_FILTER_PRODUCT)
)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator

Expand All @@ -49,6 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
entry.async_on_unload(entry.add_update_listener(update_listener))

return True

Expand All @@ -57,3 +60,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Trafikverket Weatherstation config entry."""

return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)


async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)
58 changes: 53 additions & 5 deletions homeassistant/components/trafikverket_train/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from homeassistant import config_entries
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
Expand All @@ -32,11 +32,15 @@
)
import homeassistant.util.dt as dt_util

from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import create_unique_id, next_departuredate

_LOGGER = logging.getLogger(__name__)

OPTION_SCHEMA = {
vol.Optional(CONF_FILTER_PRODUCT, default=""): TextSelector(),
}

DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_API_KEY): TextSelector(),
Expand All @@ -52,7 +56,7 @@
)
),
}
)
).extend(OPTION_SCHEMA)
DATA_SCHEMA_REAUTH = vol.Schema(
{
vol.Required(CONF_API_KEY): cv.string,
Expand All @@ -67,6 +71,7 @@ async def validate_input(
train_to: str,
train_time: str | None,
weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]:
"""Validate input from user input."""
errors: dict[str, str] = {}
Expand All @@ -87,9 +92,13 @@ async def validate_input(
from_station = await train_api.async_get_train_station(train_from)
to_station = await train_api.async_get_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(from_station, to_station, when)
await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else:
await train_api.async_get_next_train_stop(from_station, to_station, when)
await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
except InvalidAuthentication:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
Expand Down Expand Up @@ -117,6 +126,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):

entry: config_entries.ConfigEntry | None

@staticmethod
@callback
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> TVTrainOptionsFlowHandler:
"""Get the options flow for this handler."""
return TVTrainOptionsFlowHandler(config_entry)

async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle re-authentication with Trafikverket."""

Expand All @@ -140,6 +157,7 @@ async def async_step_reauth_confirm(
self.entry.data[CONF_TO],
self.entry.data.get(CONF_TIME),
self.entry.data[CONF_WEEKDAY],
self.entry.options.get(CONF_FILTER_PRODUCT),
)
if not errors:
self.hass.config_entries.async_update_entry(
Expand Down Expand Up @@ -170,6 +188,10 @@ async def async_step_user(
train_to: str = user_input[CONF_TO]
train_time: str | None = user_input.get(CONF_TIME)
train_days: list = user_input[CONF_WEEKDAY]
filter_product: str | None = user_input[CONF_FILTER_PRODUCT]

if filter_product == "":
filter_product = None

name = f"{train_from} to {train_to}"
if train_time:
Expand All @@ -182,6 +204,7 @@ async def async_step_user(
train_to,
train_time,
train_days,
filter_product,
)
if not errors:
unique_id = create_unique_id(
Expand All @@ -199,6 +222,7 @@ async def async_step_user(
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)

return self.async_show_form(
Expand All @@ -208,3 +232,27 @@ async def async_step_user(
),
errors=errors,
)


class TVTrainOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
"""Handle Trafikverket Train options."""

async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage Trafikverket Train options."""
errors: dict[str, Any] = {}

if user_input:
if not (_filter := user_input.get(CONF_FILTER_PRODUCT)) or _filter == "":
user_input[CONF_FILTER_PRODUCT] = None
return self.async_create_entry(data=user_input)

return self.async_show_form(
step_id="init",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(OPTION_SCHEMA),
user_input or self.options,
),
errors=errors,
)
1 change: 1 addition & 0 deletions homeassistant/components/trafikverket_train/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
CONF_FROM = "from"
CONF_TO = "to"
CONF_TIME = "time"
CONF_FILTER_PRODUCT = "filter_product"
8 changes: 6 additions & 2 deletions homeassistant/components/trafikverket_train/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class TrainData:
actual_time: datetime | None
other_info: str | None
deviation: str | None
product_filter: str | None


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
entry: ConfigEntry,
to_station: StationInfo,
from_station: StationInfo,
filter_product: str | None,
) -> None:
"""Initialize the Trafikverket coordinator."""
super().__init__(
Expand All @@ -83,6 +85,7 @@ def __init__(
self.to_station: StationInfo = to_station
self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME])
self._weekdays: list[str] = entry.data[CONF_WEEKDAY]
self._filter_product = filter_product

async def _async_update_data(self) -> TrainData:
"""Fetch data from Trafikverket."""
Expand All @@ -99,11 +102,11 @@ async def _async_update_data(self) -> TrainData:
try:
if self._time:
state = await self._train_api.async_get_train_stop(
self.from_station, self.to_station, when
self.from_station, self.to_station, when, self._filter_product
)
else:
state = await self._train_api.async_get_next_train_stop(
self.from_station, self.to_station, when
self.from_station, self.to_station, when, self._filter_product
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
Expand Down Expand Up @@ -134,6 +137,7 @@ async def _async_update_data(self) -> TrainData:
actual_time=_get_as_utc(state.time_at_location),
other_info=_get_as_joined(state.other_information),
deviation=_get_as_joined(state.deviations),
product_filter=self._filter_product,
)

return states
12 changes: 11 additions & 1 deletion homeassistant/components/trafikverket_train/sensor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Train information for departures and delays, provided by Trafikverket."""
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from homeassistant.components.sensor import (
SensorDeviceClass,
Expand All @@ -22,6 +23,8 @@
from .const import ATTRIBUTION, DOMAIN
from .coordinator import TrainData, TVDataUpdateCoordinator

ATTR_PRODUCT_FILTER = "product_filter"


@dataclass
class TrafikverketRequiredKeysMixin:
Expand Down Expand Up @@ -158,3 +161,10 @@ def _update_attr(self) -> None:
def _handle_coordinator_update(self) -> None:
self._update_attr()
return super()._handle_coordinator_update()

@property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return additional attributes for Trafikverket Train sensor."""
if self.coordinator.data.product_filter:
return {ATTR_PRODUCT_FILTER: self.coordinator.data.product_filter}
return None
79 changes: 69 additions & 10 deletions homeassistant/components/trafikverket_train/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
"to": "To station",
"from": "From station",
"time": "Time (optional)",
"weekday": "Days"
"weekday": "Days",
"filter_product": "Filter by product description"
},
"data_description": {
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure"
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure",
"filter_product": "To filter by product description add the phrase here to match"
}
},
"reauth_confirm": {
Expand All @@ -33,6 +35,18 @@
}
}
},
"options": {
"step": {
"init": {
"data": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data::filter_product%]"
},
"data_description": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data_description::filter_product%]"
}
}
}
},
"selector": {
"weekday": {
"options": {
Expand All @@ -49,36 +63,81 @@
"entity": {
"sensor": {
"departure_time": {
"name": "Departure time"
"name": "Departure time",
"state_attributes": {
"product_filter": {
"name": "Train filtering"
}
}
},
"departure_state": {
"name": "Departure state",
"state": {
"on_time": "On time",
"delayed": "Delayed",
"canceled": "Cancelled"
},
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"cancelled": {
"name": "Cancelled"
"name": "Cancelled",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"delayed_time": {
"name": "Delayed time"
"name": "Delayed time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"planned_time": {
"name": "Planned time"
"name": "Planned time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"estimated_time": {
"name": "Estimated time"
"name": "Estimated time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"actual_time": {
"name": "Actual time"
"name": "Actual time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"other_info": {
"name": "Other information"
"name": "Other information",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"deviation": {
"name": "Deviation"
"name": "Deviation",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}
}
}
Expand Down