Skip to content

Commit

Permalink
feat: partition stop times tables (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarrco committed Nov 9, 2023
2 parents 0ebb0ee + 20b035f commit 40c5b43
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 93 deletions.
80 changes: 80 additions & 0 deletions alembic/versions/7c12f6bfe3c6_merge_trip_into_stoptime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Merge Trip into StopTime
Revision ID: 7c12f6bfe3c6
Revises: 3e63dbd74ceb
Create Date: 2023-11-05 09:16:46.640362
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '7c12f6bfe3c6'
down_revision = '3e63dbd74ceb'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Fold the ``trips`` columns into ``stop_times`` and drop ``trips``.

    The trip attributes are copied onto every stop time through the existing
    ``stop_times.trip_id`` link, after which the link column and the
    ``trips`` table are removed.
    """
    table = 'stop_times'

    op.drop_constraint('stop_times_trip_id_fkey', table, type_='foreignkey')

    # The new columns start out nullable so they can be added to a populated
    # table; they are tightened to NOT NULL once the data has been copied.
    merged_columns = (
        sa.Column('orig_id', sa.String(), nullable=True),
        sa.Column('dest_text', sa.String(), nullable=True),
        sa.Column('number', sa.Integer(), nullable=True),
        sa.Column('orig_dep_date', sa.Date(), nullable=True),
        sa.Column('route_name', sa.String(), nullable=True),
        sa.Column('source', sa.String(), server_default='treni', nullable=True),
    )
    for column in merged_columns:
        op.add_column(table, column)

    # Copy the trip data onto each stop time via stop_times.trip_id.
    op.execute('''
UPDATE stop_times
SET
orig_id = trips.orig_id,
dest_text = trips.dest_text,
number = trips.number,
orig_dep_date = trips.orig_dep_date,
route_name = trips.route_name,
source = trips.source
FROM trips
WHERE stop_times.trip_id = trips.id
''')

    # Every row is populated now, so the columns can become mandatory.
    for column_name in ('orig_id', 'dest_text', 'number', 'orig_dep_date', 'route_name', 'source'):
        op.alter_column(table, column_name, nullable=False)

    # The link to trips and the trips table itself are no longer needed.
    op.drop_column(table, 'trip_id')
    op.drop_table('trips')


def downgrade() -> None:
    """Split the trip columns back out of ``stop_times`` into ``trips``.

    ``trips`` is recreated and repopulated with one row per
    ``(source, number, orig_dep_date)`` found in ``stop_times``; every stop
    time is then linked to its trip through a restored ``trip_id`` column
    before the denormalised columns are dropped.
    """
    op.create_table('trips',
                    sa.Column('id', sa.Integer(), nullable=False),
                    sa.Column('orig_id', sa.String(), autoincrement=False, nullable=False),
                    sa.Column('dest_text', sa.String(), autoincrement=False, nullable=False),
                    sa.Column('number', sa.Integer(), autoincrement=False, nullable=False),
                    sa.Column('orig_dep_date', sa.Date(), autoincrement=False, nullable=False),
                    sa.Column('route_name', sa.String(), autoincrement=False, nullable=False),
                    sa.Column('source', sa.String(), server_default='treni', autoincrement=False, nullable=False),
                    sa.PrimaryKeyConstraint('id'),
                    sa.UniqueConstraint('source', 'number', 'orig_dep_date',
                                        name='trips_source_number_orig_dep_date_key')
                    )

    # Rebuild the trips from the denormalised stop_times columns. DISTINCT ON
    # keeps exactly one row per unique-constraint key so the insert cannot
    # violate trips_source_number_orig_dep_date_key.
    op.execute('''
        INSERT INTO trips (orig_id, dest_text, number, orig_dep_date, route_name, source)
        SELECT DISTINCT ON (source, number, orig_dep_date)
            orig_id, dest_text, number, orig_dep_date, route_name, source
        FROM stop_times
        ORDER BY source, number, orig_dep_date
    ''')

    # Adding a NOT NULL column without a default fails on a populated table,
    # so the column starts nullable, is backfilled, and is tightened afterwards.
    op.add_column('stop_times', sa.Column('trip_id', sa.INTEGER(), autoincrement=False, nullable=True))
    op.execute('''
        UPDATE stop_times
        SET trip_id = trips.id
        FROM trips
        WHERE stop_times.source = trips.source
          AND stop_times.number = trips.number
          AND stop_times.orig_dep_date = trips.orig_dep_date
    ''')
    op.alter_column('stop_times', 'trip_id', nullable=False)
    op.create_foreign_key('stop_times_trip_id_fkey', 'stop_times', 'trips', ['trip_id'], ['id'], ondelete='CASCADE')

    # NOTE(review): any unique constraint on (trip_id, stop_id) that existed
    # before the upgrade is not recreated here - confirm whether it is needed.

    op.drop_column('stop_times', 'source')
    op.drop_column('stop_times', 'route_name')
    op.drop_column('stop_times', 'orig_dep_date')
    op.drop_column('stop_times', 'number')
    op.drop_column('stop_times', 'dest_text')
    op.drop_column('stop_times', 'orig_id')
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Convert stop_times to partitioned table
Revision ID: d55702afa188
Revises: 7c12f6bfe3c6
Create Date: 2023-10-29 16:09:44.815425
"""
from alembic import op


# revision identifiers, used by Alembic.
revision = 'd55702afa188'
down_revision = '7c12f6bfe3c6'
branch_labels = None
depends_on = None


# Define the migration
def upgrade():
    """Replace ``stop_times`` with a table range-partitioned by ``orig_dep_date``.

    The existing table is kept under the name ``stop_times_reg``; its
    foreign key to ``stops`` is dropped first so the constraint name can be
    reused by the new partitioned table. No rows are copied into the new
    table by this migration.
    """
    # remove foreign key "stop_times_stop_id_fkey"
    op.drop_constraint('stop_times_stop_id_fkey', 'stop_times', type_='foreignkey')

    # rename table "stop_times" to "stop_times_reg"
    op.rename_table('stop_times', 'stop_times_reg')

    # create the partitioned table "stop_times" for field "orig_dep_date".
    # "source" is NOT NULL DEFAULT 'treni', matching the column definition the
    # previous revision gave stop_times; a nullable source would also let rows
    # with NULL source bypass the unique index below.
    op.execute("""
CREATE TABLE stop_times (
id SERIAL NOT NULL,
stop_id character varying NOT NULL,
sched_arr_dt timestamp without time zone,
sched_dep_dt timestamp without time zone,
platform character varying,
orig_dep_date date NOT NULL,
orig_id character varying NOT NULL,
dest_text character varying NOT NULL,
number integer NOT NULL,
route_name character varying NOT NULL,
source character varying DEFAULT 'treni' NOT NULL,
CONSTRAINT stop_times_stop_id_fkey FOREIGN key(stop_id) REFERENCES stops(id)
) PARTITION BY RANGE (orig_dep_date);
CREATE UNIQUE INDEX stop_times_unique_idx ON stop_times(stop_id, number, source, orig_dep_date);
""")


def downgrade():
    """Discard the partitioned ``stop_times`` and restore the regular table.

    The partitioned table is dropped, ``stop_times_reg`` takes the
    ``stop_times`` name back, and its foreign key to ``stops`` is recreated.
    """
    partitioned_name = 'stop_times'
    regular_name = 'stop_times_reg'

    # Drop the partitioned table first so its name becomes available again.
    op.drop_table(partitioned_name)

    # Give the preserved regular table its original name back.
    op.rename_table(regular_name, partitioned_name)

    # Re-add the foreign key that the upgrade removed from the regular table.
    op.create_foreign_key('stop_times_stop_id_fkey', partitioned_name, 'stops', ['stop_id'], ['id'])
24 changes: 23 additions & 1 deletion save_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
from datetime import date, timedelta

from sqlalchemy import inspect

from server.GTFS import GTFS
from server.sources import session
from server.base.models import StopTime
from server.sources import engine, session
from server.trenitalia import Trenitalia
from server.typesense import connect_to_typesense

Expand All @@ -26,6 +30,24 @@ def run():
Trenitalia(session, typesense, force_update_stations=force_update_stations),
]

session.commit()

today = date.today()

for i in range(3):
day: date = today + timedelta(days=i)
partition = StopTime.create_partition(day)
if not inspect(engine).has_table(partition.__table__.name):
partition.__table__.create(bind=engine)

while True:
i = -2
day = today + timedelta(days=i)
try:
StopTime.detach_partition(day)
except Exception:
break

for source in sources:
try:
source.save_data()
Expand Down
13 changes: 6 additions & 7 deletions server/GTFS/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sqlalchemy import select, func
from tqdm import tqdm

from server.base import Source, Station, Stop, TripStopTime
from server.base import Source, Station, Stop, TripStopTime, StopTime
from .clustering import get_clusters_of_stops, get_loc_from_stop_and_cluster
from .models import CStop

Expand Down Expand Up @@ -309,13 +309,12 @@ def get_sqlite_stop_times(self, day: date, start_time: time, end_time: time, lim

def search_lines(self, name):
today = date.today()
from server.base import Trip
trips = self.session.execute(
select(func.max(Trip.number), Trip.dest_text)\
.filter(Trip.orig_dep_date == today)\
.filter(Trip.route_name == name)\
.group_by(Trip.dest_text)\
.order_by(func.count(Trip.id).desc()))\
select(func.max(StopTime.number), StopTime.dest_text) \
.filter(StopTime.orig_dep_date == today) \
.filter(StopTime.route_name == name) \
.group_by(StopTime.dest_text) \
.order_by(func.count(StopTime.number).desc())) \
.all()

results = [(trip[0], name, trip[1]) for trip in trips]
Expand Down
112 changes: 91 additions & 21 deletions server/base/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import date, datetime
from datetime import date, datetime, timedelta
from typing import Optional

from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import declarative_base, Mapped, mapped_column, relationship
from sqlalchemy import ForeignKey, UniqueConstraint, event
from sqlalchemy.orm import declarative_base, Mapped, mapped_column, relationship, declared_attr
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.sql.ddl import DDL

Base = declarative_base()

Expand All @@ -29,6 +31,7 @@ def as_dict(self):
'source': self.source
}


class Stop(Base):
__tablename__ = 'stops'

Expand All @@ -39,35 +42,102 @@ class Stop(Base):
station_id: Mapped[str] = mapped_column(ForeignKey('stations.id'))
station: Mapped[Station] = relationship('Station', back_populates='stops')
source: Mapped[Optional[str]]
stop_times = relationship('StopTime', back_populates='stop', cascade='all, delete-orphan')
active: Mapped[bool] = mapped_column(server_default='true')


class Trip(Base):
__tablename__ = 'trips'

class PartitionByOrigDepDateMeta(DeclarativeMeta):
    """Declarative metaclass for a model stored in a range-partitioned table.

    A class built with this metaclass receives:

    * ``postgresql_partition_by=RANGE(<partition_by>)`` appended to its
      ``__table_args__``;
    * ``partitions``: an in-memory dict mapping a ``YYYYMMDD`` key to the
      partition class created for that day;
    * ``partitioned_by``: the name of the partitioning column;
    * the ``get_partition_name``, ``create_partition`` and
      ``detach_partition`` class methods defined below.

    The ``partitions`` registry only knows partitions created through
    ``create_partition`` in the current process.
    """

    def __new__(cls, clsname, bases, attrs, *, partition_by):
        @classmethod
        def get_partition_name(cls_, key):
            """Return the table name of the partition identified by ``key``."""
            return f'{cls_.__tablename__}_{key}'

        @classmethod
        def create_partition(cls_, day: date):
            """Return the partition class covering ``day``, declaring it if needed.

            The partition is a sibling mapped class (same bases as the parent)
            whose table is attached to the parent table for the range
            ``[day, day + 1)`` right after the table is created.
            """
            key = day.strftime('%Y%m%d')
            if key not in cls_.partitions:
                Partition = type(
                    f'{clsname}{key}',
                    bases,
                    {'__tablename__': cls_.get_partition_name(key)}
                )

                # The parent table must exist before the partition is attached.
                Partition.__table__.add_is_dependent_on(cls_.__table__)

                day_plus_one = day + timedelta(days=1)
                event.listen(
                    Partition.__table__,
                    'after_create',
                    DDL(
                        f"""
ALTER TABLE {cls_.__tablename__}
ATTACH PARTITION {Partition.__tablename__}
FOR VALUES FROM ('{day}') TO ('{day_plus_one}')
"""
                    )
                )

                cls_.partitions[key] = Partition

            return cls_.partitions[key]

        @classmethod
        def detach_partition(cls_, day: date, bind=None):
            """Detach the partition covering ``day`` from the parent table.

            :param day: the day whose partition should be detached.
            :param bind: optional SQLAlchemy ``Engine``; when given, the
                ``DETACH PARTITION`` statement is executed on it inside a
                transaction. When omitted nothing is sent to the database
                and the caller is expected to execute the returned statement.
            :return: the ``DDL`` construct for the detach statement.
            :raises Exception: if no partition for ``day`` was registered
                through ``create_partition`` in this process.
            """
            key = day.strftime('%Y%m%d')
            if key not in cls_.partitions:
                raise Exception(f'Partition {key} does not exist')

            # Reuse the registered class: declaring a second class with the
            # same __tablename__ on the same metadata is rejected by SQLAlchemy.
            Partition = cls_.partitions[key]

            # The partition table already exists, so an 'after_create'
            # listener would never fire; the statement is built and executed
            # directly instead.
            statement = DDL(
                f"""
ALTER TABLE {cls_.__tablename__}
DETACH PARTITION {Partition.__tablename__}
"""
            )
            if bind is not None:
                with bind.begin() as connection:
                    connection.execute(statement)

            return statement

        attrs.update(
            {
                '__table_args__': attrs.get('__table_args__', ())
                + (dict(postgresql_partition_by=f'RANGE({partition_by})'),),
                'partitions': {},
                'partitioned_by': partition_by,
                'get_partition_name': get_partition_name,
                'create_partition': create_partition,
                'detach_partition': detach_partition
            }
        )

        return super().__new__(cls, clsname, bases, attrs)


class StopTimeMixin:
id: Mapped[int] = mapped_column(primary_key=True)
sched_arr_dt: Mapped[Optional[datetime]]
sched_dep_dt: Mapped[Optional[datetime]]
orig_dep_date: Mapped[date]
platform: Mapped[Optional[str]]
orig_id: Mapped[str]
dest_text: Mapped[str]
number: Mapped[int]
orig_dep_date: Mapped[date]
route_name: Mapped[str]
source: Mapped[str] = mapped_column(server_default='treni')
stop_times = relationship('StopTime', back_populates='trip', cascade='all, delete-orphan', passive_deletes=True)

__table_args__ = (UniqueConstraint('source', 'number', 'orig_dep_date'),)
@declared_attr
def stop_id(self) -> Mapped[str]:
return mapped_column(ForeignKey('stops.id'))

@declared_attr
def stop(self) -> Mapped[Stop]:
return relationship('Stop', foreign_keys=self.stop_id)

class StopTime(Base):
__tablename__ = 'stop_times'

id: Mapped[int] = mapped_column(primary_key=True)
trip_id: Mapped[int] = mapped_column(ForeignKey('trips.id', ondelete='CASCADE'))
trip: Mapped[Trip] = relationship('Trip', back_populates='stop_times')
stop_id: Mapped[str] = mapped_column(ForeignKey('stops.id'))
stop: Mapped[Stop] = relationship('Stop', back_populates='stop_times')
sched_arr_dt: Mapped[Optional[datetime]]
sched_dep_dt: Mapped[Optional[datetime]]
platform: Mapped[Optional[str]]
class StopTime(StopTimeMixin, Base, metaclass=PartitionByOrigDepDateMeta, partition_by='orig_dep_date'):
__tablename__ = 'stop_times'

__table_args__ = (UniqueConstraint('trip_id', 'stop_id'),)
__table_args__ = (UniqueConstraint("stop_id", "number", "source", "orig_dep_date"),)
Loading

0 comments on commit 40c5b43

Please sign in to comment.