Skip to content

Commit

Permalink
Add option to GTFS static parser:
Browse files Browse the repository at this point in the history
  • Loading branch information
jamespfennell committed Jun 3, 2020
1 parent 45d6404 commit fcafc31
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 1 deletion.
75 changes: 75 additions & 0 deletions tests/unit/parse/test_gtfsstatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,78 @@ def _create_zip(file_name, file_content):
zip_file.close()
buff.seek(0)
return buff.read()


@pytest.mark.parametrize(
"input_blob,expected",
[
[None, gtfsstaticparser._TransfersConfig()],
[
{"strategy": "group_stations"},
gtfsstaticparser._TransfersConfig(
default_strategy=gtfsstaticparser._TransfersStrategy.GROUP_STATIONS
),
],
[
{"exceptions": [{"strategy": "group_stations", "stop_ids": ["A", "B"]}]},
gtfsstaticparser._TransfersConfig(
exceptions=[
gtfsstaticparser._TransfersConfigException(
strategy=gtfsstaticparser._TransfersStrategy.GROUP_STATIONS,
stop_ids={"A", "B"},
)
]
),
],
],
)
def test_transfers_config__load(input_blob, expected):
actual = gtfsstaticparser._TransfersConfig.load_from_options_blob(input_blob)

assert expected == actual


@pytest.mark.parametrize(
"input_blob",
[
{"unexpected_key": "value"},
{"exceptions": "not_a_list"},
{"exceptions": [{"stop_ids": ["a"]}]},
{"exceptions": [{"strategy": "default"}]},
{"strategy": "unknown"},
{"exceptions": [{"strategy": "unknown", "stop_ids": ["a"]}]},
{
"exceptions": [
{"strategy": "default", "stop_ids": ["a"], "unexpected_key": "value"}
]
},
],
)
def test_transfers_config__load_error(input_blob):
with pytest.raises(Exception):
gtfsstaticparser._TransfersConfig.load_from_options_blob(input_blob)


@pytest.mark.parametrize(
"stop_1_id,stop_2_id,expected",
[
["A", "B", gtfsstaticparser._TransfersStrategy.GROUP_STATIONS],
["A", "C", gtfsstaticparser._TransfersStrategy.GROUP_STATIONS],
["B", "A", gtfsstaticparser._TransfersStrategy.GROUP_STATIONS],
["B", "C", gtfsstaticparser._TransfersStrategy.DEFAULT],
["C", "A", gtfsstaticparser._TransfersStrategy.GROUP_STATIONS],
["C", "B", gtfsstaticparser._TransfersStrategy.DEFAULT],
],
)
def test_transfers_config__get_strategy(stop_1_id, stop_2_id, expected):
config = gtfsstaticparser._TransfersConfig(
default_strategy=gtfsstaticparser._TransfersStrategy.GROUP_STATIONS,
exceptions=[
gtfsstaticparser._TransfersConfigException(
strategy=gtfsstaticparser._TransfersStrategy.DEFAULT,
stop_ids={"B", "C"},
)
],
)

assert config.get_strategy(stop_1_id, stop_2_id) == expected
66 changes: 65 additions & 1 deletion transiter/parse/gtfsstatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import typing
import uuid
import zipfile

import dataclasses
from transiter.parse import types as parse
from transiter.parse.parser import TransiterParser

Expand All @@ -20,6 +20,15 @@ class GtfsStaticParser(TransiterParser):

gtfs_static_file = None

def __init__(self):
super().__init__()
self._transfers_config = _TransfersConfig()

def load_options(self, options_blob: typing.Optional[dict]) -> None:
self._transfers_config = _TransfersConfig.load_from_options_blob(
options_blob.get("transfers")
)

def load_content(self, content: bytes) -> None:
self.gtfs_static_file = _GtfsStaticFile(content)

Expand Down Expand Up @@ -48,6 +57,61 @@ def get_scheduled_services(self) -> typing.Iterable[parse.ScheduledService]:
yield from _parse_schedule(self.gtfs_static_file)


class _TransfersStrategy(enum.Enum):
DEFAULT = 0
GROUP_STATIONS = 1


@dataclasses.dataclass
class _TransfersConfigException:
strategy: _TransfersStrategy
stop_ids: typing.Set[str]


@dataclasses.dataclass
class _TransfersConfig:
default_strategy: _TransfersStrategy = _TransfersStrategy.DEFAULT
exceptions: typing.List[_TransfersConfigException] = dataclasses.field(
default_factory=list
)

@classmethod
def load_from_options_blob(cls, options_blob):
config = cls()
if options_blob is None:
return config
config.default_strategy = _TransfersStrategy[
options_blob.pop("strategy", "DEFAULT").upper()
]
for exception_blob in options_blob.pop("exceptions", []):
config.exceptions.append(
_TransfersConfigException(
strategy=_TransfersStrategy[exception_blob.pop("strategy").upper()],
stop_ids=set(exception_blob.pop("stop_ids")),
)
)
if len(exception_blob) > 0:
raise ValueError(
"Unrecognized transfers.exceptions sub-options: {}".format(
exception_blob
)
)
if len(options_blob) > 0:
raise ValueError(
"Unrecognized transfers sub-options: {}".format(options_blob)
)
return config

def get_strategy(self, stop_1_id, stop_2_id) -> _TransfersStrategy:
for exception in self.exceptions:
if stop_1_id not in exception.stop_ids:
continue
if stop_2_id not in exception.stop_ids:
continue
return exception.strategy
return self.default_strategy


class _GtfsStaticFile:
class _InternalFileName(enum.Enum):
AGENCY = "agency.txt"
Expand Down

0 comments on commit fcafc31

Please sign in to comment.