Skip to content

Commit

Permalink
Add product filtering feature to Trafikverket Train (#86343)
Browse files Browse the repository at this point in the history
  • Loading branch information
gjohansson-ST committed Aug 9, 2023
1 parent 0317afe commit 4c03077
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 20 deletions.
12 changes: 10 additions & 2 deletions homeassistant/components/trafikverket_train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .coordinator import TVDataUpdateCoordinator


Expand All @@ -36,7 +36,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
f" {entry.data[CONF_TO]}. Error: {error} "
) from error

coordinator = TVDataUpdateCoordinator(hass, entry, to_station, from_station)
coordinator = TVDataUpdateCoordinator(
hass, entry, to_station, from_station, entry.options.get(CONF_FILTER_PRODUCT)
)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator

Expand All @@ -49,6 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
entry.async_on_unload(entry.add_update_listener(update_listener))

return True

Expand All @@ -57,3 +60,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Trafikverket Weatherstation config entry."""

return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)


async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)
58 changes: 53 additions & 5 deletions homeassistant/components/trafikverket_train/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from homeassistant import config_entries
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
Expand All @@ -32,11 +32,15 @@
)
import homeassistant.util.dt as dt_util

from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import create_unique_id, next_departuredate

_LOGGER = logging.getLogger(__name__)

OPTION_SCHEMA = {
vol.Optional(CONF_FILTER_PRODUCT, default=""): TextSelector(),
}

DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_API_KEY): TextSelector(),
Expand All @@ -52,7 +56,7 @@
)
),
}
)
).extend(OPTION_SCHEMA)
DATA_SCHEMA_REAUTH = vol.Schema(
{
vol.Required(CONF_API_KEY): cv.string,
Expand All @@ -67,6 +71,7 @@ async def validate_input(
train_to: str,
train_time: str | None,
weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]:
"""Validate input from user input."""
errors: dict[str, str] = {}
Expand All @@ -87,9 +92,13 @@ async def validate_input(
from_station = await train_api.async_get_train_station(train_from)
to_station = await train_api.async_get_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(from_station, to_station, when)
await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else:
await train_api.async_get_next_train_stop(from_station, to_station, when)
await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
except InvalidAuthentication:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
Expand Down Expand Up @@ -117,6 +126,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):

entry: config_entries.ConfigEntry | None

@staticmethod
@callback
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> TVTrainOptionsFlowHandler:
"""Get the options flow for this handler."""
return TVTrainOptionsFlowHandler(config_entry)

async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle re-authentication with Trafikverket."""

Expand All @@ -140,6 +157,7 @@ async def async_step_reauth_confirm(
self.entry.data[CONF_TO],
self.entry.data.get(CONF_TIME),
self.entry.data[CONF_WEEKDAY],
self.entry.options.get(CONF_FILTER_PRODUCT),
)
if not errors:
self.hass.config_entries.async_update_entry(
Expand Down Expand Up @@ -170,6 +188,10 @@ async def async_step_user(
train_to: str = user_input[CONF_TO]
train_time: str | None = user_input.get(CONF_TIME)
train_days: list = user_input[CONF_WEEKDAY]
filter_product: str | None = user_input[CONF_FILTER_PRODUCT]

if filter_product == "":
filter_product = None

name = f"{train_from} to {train_to}"
if train_time:
Expand All @@ -182,6 +204,7 @@ async def async_step_user(
train_to,
train_time,
train_days,
filter_product,
)
if not errors:
unique_id = create_unique_id(
Expand All @@ -199,6 +222,7 @@ async def async_step_user(
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)

return self.async_show_form(
Expand All @@ -208,3 +232,27 @@ async def async_step_user(
),
errors=errors,
)


class TVTrainOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
"""Handle Trafikverket Train options."""

async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage Trafikverket Train options."""
errors: dict[str, Any] = {}

if user_input:
if not (_filter := user_input.get(CONF_FILTER_PRODUCT)) or _filter == "":
user_input[CONF_FILTER_PRODUCT] = None
return self.async_create_entry(data=user_input)

return self.async_show_form(
step_id="init",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(OPTION_SCHEMA),
user_input or self.options,
),
errors=errors,
)
1 change: 1 addition & 0 deletions homeassistant/components/trafikverket_train/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
CONF_FROM = "from"
CONF_TO = "to"
CONF_TIME = "time"
CONF_FILTER_PRODUCT = "filter_product"
8 changes: 6 additions & 2 deletions homeassistant/components/trafikverket_train/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class TrainData:
actual_time: datetime | None
other_info: str | None
deviation: str | None
product_filter: str | None


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
entry: ConfigEntry,
to_station: StationInfo,
from_station: StationInfo,
filter_product: str | None,
) -> None:
"""Initialize the Trafikverket coordinator."""
super().__init__(
Expand All @@ -83,6 +85,7 @@ def __init__(
self.to_station: StationInfo = to_station
self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME])
self._weekdays: list[str] = entry.data[CONF_WEEKDAY]
self._filter_product = filter_product

async def _async_update_data(self) -> TrainData:
"""Fetch data from Trafikverket."""
Expand All @@ -99,11 +102,11 @@ async def _async_update_data(self) -> TrainData:
try:
if self._time:
state = await self._train_api.async_get_train_stop(
self.from_station, self.to_station, when
self.from_station, self.to_station, when, self._filter_product
)
else:
state = await self._train_api.async_get_next_train_stop(
self.from_station, self.to_station, when
self.from_station, self.to_station, when, self._filter_product
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
Expand Down Expand Up @@ -134,6 +137,7 @@ async def _async_update_data(self) -> TrainData:
actual_time=_get_as_utc(state.time_at_location),
other_info=_get_as_joined(state.other_information),
deviation=_get_as_joined(state.deviations),
product_filter=self._filter_product,
)

return states
12 changes: 11 additions & 1 deletion homeassistant/components/trafikverket_train/sensor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Train information for departures and delays, provided by Trafikverket."""
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from homeassistant.components.sensor import (
SensorDeviceClass,
Expand All @@ -22,6 +23,8 @@
from .const import ATTRIBUTION, DOMAIN
from .coordinator import TrainData, TVDataUpdateCoordinator

ATTR_PRODUCT_FILTER = "product_filter"


@dataclass
class TrafikverketRequiredKeysMixin:
Expand Down Expand Up @@ -158,3 +161,10 @@ def _update_attr(self) -> None:
def _handle_coordinator_update(self) -> None:
self._update_attr()
return super()._handle_coordinator_update()

@property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return additional attributes for Trafikverket Train sensor."""
if self.coordinator.data.product_filter:
return {ATTR_PRODUCT_FILTER: self.coordinator.data.product_filter}
return None
79 changes: 69 additions & 10 deletions homeassistant/components/trafikverket_train/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
"to": "To station",
"from": "From station",
"time": "Time (optional)",
"weekday": "Days"
"weekday": "Days",
"filter_product": "Filter by product description"
},
"data_description": {
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure"
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure",
"filter_product": "To filter by product description add the phrase here to match"
}
},
"reauth_confirm": {
Expand All @@ -33,6 +35,18 @@
}
}
},
"options": {
"step": {
"init": {
"data": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data::filter_product%]"
},
"data_description": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data_description::filter_product%]"
}
}
}
},
"selector": {
"weekday": {
"options": {
Expand All @@ -49,36 +63,81 @@
"entity": {
"sensor": {
"departure_time": {
"name": "Departure time"
"name": "Departure time",
"state_attributes": {
"product_filter": {
"name": "Train filtering"
}
}
},
"departure_state": {
"name": "Departure state",
"state": {
"on_time": "On time",
"delayed": "Delayed",
"canceled": "Cancelled"
},
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"cancelled": {
"name": "Cancelled"
"name": "Cancelled",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"delayed_time": {
"name": "Delayed time"
"name": "Delayed time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"planned_time": {
"name": "Planned time"
"name": "Planned time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"estimated_time": {
"name": "Estimated time"
"name": "Estimated time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"actual_time": {
"name": "Actual time"
"name": "Actual time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"other_info": {
"name": "Other information"
"name": "Other information",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"deviation": {
"name": "Deviation"
"name": "Deviation",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}
}
}
Expand Down

0 comments on commit 4c03077

Please sign in to comment.