Skip to content

Commit

Permalink
more work
Browse files Browse the repository at this point in the history
  • Loading branch information
jamespfennell committed May 28, 2020
1 parent 2476500 commit c8e1705
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 3 deletions.
63 changes: 63 additions & 0 deletions tests/db/import_/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,66 @@ def test_alert__route_linking(
assert getattr(persisted_alert, entity_type) == [entity]
else:
assert getattr(persisted_alert, entity_type) == []


@pytest.fixture
def trip_for_vehicle(add_model, system_1, route_1_1, stop_1_1, stop_1_2, stop_1_3):
return add_model(
models.Trip(
id="trip_id",
route=route_1_1,
stop_times=[
models.TripStopTime(stop_sequence=1, stop=stop_1_1, future=False),
models.TripStopTime(stop_sequence=2, stop=stop_1_2, future=True),
models.TripStopTime(stop_sequence=3, stop=stop_1_3, future=True),
],
)
)


@pytest.mark.parametrize(
"provide_stop_id,provide_stop_sequence",
[[True, True], [True, False], [False, True]],
)
def test_vehicle__set_stop_simple_case(
db_session,
current_update,
trip_for_vehicle,
stop_1_3,
provide_stop_id,
provide_stop_sequence,
):
vehicle = parse.Vehicle(
id="vehicle_id",
trip_id="trip_id",
current_stop_id=stop_1_3.id if provide_stop_id else None,
current_stop_sequence=3 if provide_stop_sequence else None,
)

importdriver.run_import(current_update.pk, ParserForTesting([vehicle]))

persisted_vehicle = db_session.query(models.Vehicle).all()[0]

assert persisted_vehicle.current_stop == stop_1_3
assert persisted_vehicle.current_stop_sequence == 3


def test_vehicle__add_trip_relationship(
db_session, current_update, trip_for_vehicle, stop_1_3,
):
vehicle = parse.Vehicle(id="vehicle_id", trip_id="trip_id",)

importdriver.run_import(current_update.pk, ParserForTesting([vehicle]))

persisted_vehicle = db_session.query(models.Vehicle).all()[0]

assert persisted_vehicle.trip == trip_for_vehicle


# TODO: vehicle test cases
"""
- Ensuring the history of a trip is changed correctly
- Deleting a trip that has a vehicle in the FK!
- Move vehicle between trips
- Delete vehicle assigned to trip
"""
5 changes: 3 additions & 2 deletions transiter/db/models/vehicle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class Vehicle(Base):

source = relationship("FeedUpdate", cascade="none")
system = relationship("System", back_populates="vehicles", cascade="none")
trip = relationship("Trip", back_populates="vehicle", cascade="none")
stop = relationship("Stop", cascade="none")
trip = relationship("Trip", back_populates="vehicle", cascade="none", uselist=False)
current_stop = relationship("Stop", cascade="none")

__table_args__ = (UniqueConstraint(system_pk, id),)

Expand All @@ -67,6 +67,7 @@ def from_parsed_vehicle(vehicle: parse.Vehicle) -> "Vehicle":
id=vehicle.id,
label=vehicle.label,
license_plate=vehicle.license_plate,
current_stop_sequence=vehicle.current_stop_sequence,
current_status=vehicle.current_status,
latitude=vehicle.latitude,
longitude=vehicle.longitude,
Expand Down
1 change: 1 addition & 0 deletions transiter/db/queries/genericqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def get_id_to_pk_map(
Note this method only works with entities that are direct children of the system.
"""
if ids is not None:
ids = list(ids)
id_to_pk = {id_: None for id_ in ids}
else:
id_to_pk = {}
Expand Down
25 changes: 24 additions & 1 deletion transiter/db/queries/tripqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.orm import selectinload, joinedload

from transiter.db import dbconnection, models
import typing


def list_all_from_feed(feed_pk):
Expand Down Expand Up @@ -69,7 +70,7 @@ def get_trip_pk_to_stop_time_data_list(feed_pk) -> Dict[int, List[StopTimeData]]
.order_by(models.TripStopTime.trip_pk, models.TripStopTime.stop_sequence)
)
trip_pk_to_stop_time_data_list = {}
for (trip_pk, stop_time_pk, stop_sequence, stop_pk,) in query.all():
for (trip_pk, stop_time_pk, stop_sequence, stop_pk) in query.all():
if trip_pk not in trip_pk_to_stop_time_data_list:
trip_pk_to_stop_time_data_list[trip_pk] = []
trip_pk_to_stop_time_data_list[trip_pk].append(
Expand All @@ -78,6 +79,28 @@ def get_trip_pk_to_stop_time_data_list(feed_pk) -> Dict[int, List[StopTimeData]]
return trip_pk_to_stop_time_data_list


# TODO: bulkify
def get_trip_stop_time_data(
trip_pk, stop_pk, stop_sequence
) -> typing.Optional[StopTimeData]:
session = dbconnection.get_session()
query = session.query(
models.TripStopTime.trip_pk,
models.TripStopTime.pk,
models.TripStopTime.stop_sequence,
models.TripStopTime.stop_pk,
).filter(models.TripStopTime.trip_pk == trip_pk)
if stop_pk is not None:
query = query.filter(models.TripStopTime.stop_pk == stop_pk)
if stop_sequence is not None:
query = query.filter(models.TripStopTime.stop_sequence == stop_sequence)
for (trip_pk, stop_time_pk, stop_sequence, stop_pk) in query.all():
return StopTimeData(
pk=stop_time_pk, stop_sequence=stop_sequence, stop_pk=stop_pk
)
return None


def list_all_in_route_by_pk(route_pk):
"""
List all of the Trips in a route.
Expand Down
9 changes: 9 additions & 0 deletions transiter/import_/importdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,14 @@ def sync(self, parsed_vehicles: typing.Iterable[parse.Vehicle]):
if trip is None:
continue
vehicle_id_to_trip[parsed_vehicle.id] = trip

stop_time_data = tripqueries.get_trip_stop_time_data(
trip.pk, vehicle.current_stop_pk, vehicle.current_stop_sequence
)
if stop_time_data is not None:
vehicle.current_stop_pk = stop_time_data.stop_pk
vehicle.current_stop_sequence = stop_time_data.stop_sequence

persisted_vehicles, num_added, num_updated = self._merge_entities(vehicles)

for persisted_vehicle in persisted_vehicles:
Expand Down Expand Up @@ -732,6 +740,7 @@ def _get_trip_id_to_pk_map(system, parsed_vehicles):
),
)
}
# TODO: consider using this instead
return tripqueries.get_id_to_pk_map_in_system(
system_pk,
(
Expand Down

0 comments on commit c8e1705

Please sign in to comment.