From 8f5a69870673e2fd1098a917c24b0787ef1b0990 Mon Sep 17 00:00:00 2001 From: G Johansson Date: Sun, 1 Oct 2023 21:29:54 +0200 Subject: [PATCH] Add trainstops --- .gitignore | 4 +- pytrafikverket/trafikverket_train.py | 71 ++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0314240..72520b7 100644 --- a/.gitignore +++ b/.gitignore @@ -367,4 +367,6 @@ venv/ ferry_test.py # Mypy cache -.mypy_cache \ No newline at end of file +.mypy_cache + +test.py diff --git a/pytrafikverket/trafikverket_train.py b/pytrafikverket/trafikverket_train.py index 99e9eec..9df9960 100644 --- a/pytrafikverket/trafikverket_train.py +++ b/pytrafikverket/trafikverket_train.py @@ -268,6 +268,77 @@ async def async_get_train_stop( train_announcement = train_announcements[0] return TrainStop.from_xml_node(train_announcement) + + async def async_get_next_train_stops( + self, + from_station: StationInfo, + to_station: StationInfo, + after_time: datetime, + product_description: str | None = None, + exclude_canceled: bool = False, + number_of_stops: int = 1, + ) -> list[TrainStop]: + """Enable retreival of next departures.""" + date_as_text = after_time.strftime(Trafikverket.date_time_format) + + filters = [ + FieldFilter(FilterOperation.EQUAL, "ActivityType", "Avgang"), + FieldFilter( + FilterOperation.EQUAL, "LocationSignature", from_station.signature + ), + FieldFilter( + FilterOperation.GREATER_THAN_EQUAL, + "AdvertisedTimeAtLocation", + date_as_text, + ), + OrFilter( + [ + FieldFilter( + FilterOperation.EQUAL, + "ViaToLocation.LocationName", + to_station.signature, + ), + FieldFilter( + FilterOperation.EQUAL, + "ToLocation.LocationName", + to_station.signature, + ), + ] + ), + ] + + if product_description: + filters.append( + FieldFilter( + FilterOperation.LIKE, + "ProductInformation.Description", + product_description, + ) + ) + + if exclude_canceled: + filters.append(FieldFilter(FilterOperation.EQUAL, "Canceled", "false")) + + sorting = [FieldSort("AdvertisedTimeAtLocation", SortOrder.ASCENDING)] + train_announcements = await self._api.async_make_request( + "TrainAnnouncement", + "1.8", + TRAIN_STOP_REQUIRED_FIELDS, + filters, + 1, + sorting, + ) + + if len(train_announcements) == 0: + raise NoTrainAnnouncementFound("No TrainAnnouncement found") + + if len(train_announcements) > number_of_stops: + raise MultipleTrainAnnouncementFound("Multiple TrainAnnouncements found") + + stops = [] + for announcement in train_announcements: + stops.append(TrainStop.from_xml_node(announcement)) + return stops async def async_get_next_train_stop( self,