Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing errors #33

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mtapi/mtapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import urllib, contextlib, datetime, copy
import urllib.request, urllib.error
import contextlib, datetime, copy
from collections import defaultdict
from itertools import islice
from operator import itemgetter
Expand All @@ -8,10 +9,13 @@
import google.protobuf.message
from mtaproto.feedresponse import FeedResponse, Trip, TripStop, TZ
from mtapi._mtapithreader import _MtapiThreader
from typing import TypeAlias

logger = logging.getLogger(__name__)

def distance(p1, p2):
point: TypeAlias = tuple[float, float] | list[float]

def distance(p1: point, p2: point) -> float:
return math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2)

class Mtapi(object):
Expand Down
70 changes: 34 additions & 36 deletions mtaproto/feedresponse.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,57 @@
from mtaproto import nyct_subway_pb2
from pytz import timezone
import datetime

from pytz import timezone

from . import nyct_subway_pb2
from . import gtfs_realtime_pb2

TZ = timezone('US/Eastern')

class FeedResponse(object):

def __init__(self, response_string):
self._pb_data = nyct_subway_pb2.gtfs__realtime__pb2.FeedMessage()
gtfs_realtime_pb2.FeedMessage()
self._pb_data = gtfs_realtime_pb2.FeedMessage()
self._pb_data.ParseFromString(response_string)

def __getattr__(self, name):

if name == 'timestamp':
return datetime.datetime.fromtimestamp(self._pb_data.header.timestamp, TZ)


return getattr(self._pb_data, name)

@property
def timestamp(self):
return datetime.datetime.fromtimestamp(self._pb_data.header.timestamp, TZ)

@property
def entity(self):
return self._pb_data.entity

class Trip(object):
def __init__(self, pb_data):
def __init__(self, pb_data: gtfs_realtime_pb2.FeedEntity):
self._pb_data = pb_data

def __getattr__(self, name):

if name == 'direction':
return self._direction()
elif name == 'route_id':
if self._pb_data.trip_update.trip.route_id == 'GS':
return 'S'
else:
return self._pb_data.trip_update.trip.route_id


return getattr(self._pb_data, name)

def _direction(self):
@property
def direction(self):
trip_meta = self._pb_data.trip_update.trip.Extensions[nyct_subway_pb2.nyct_trip_descriptor]
return nyct_subway_pb2.NyctTripDescriptor.Direction.Name(trip_meta.direction)

@property
def route_id(self):
if self._pb_data.trip_update.trip.route_id == 'GS':
return 'S'
else:
return self._pb_data.trip_update.trip.route_id

def is_valid(self):
return bool(self._pb_data.trip_update)


class TripStop(object):
def __init__(self, pb_data):
def __init__(self, pb_data: gtfs_realtime_pb2.TripUpdate.StopTimeUpdate):
self._pb_data = pb_data

def __getattr__(self, name):

if name == 'time':
raw_time = self._pb_data.arrival.time or self._pb_data.departure.time
return datetime.datetime.fromtimestamp(raw_time, TZ)
elif name == 'stop_id':
return str(self._pb_data.stop_id[:3])

return getattr(self._pb_data, name)
@property
def time(self):
raw_time = self._pb_data.arrival.time or self._pb_data.departure.time
return datetime.datetime.fromtimestamp(raw_time, TZ)

@property
def stop_id(self):
return str(self._pb_data.stop_id[:3])