In [57]:
# IMPORTS
import requests

import gtfs_realtime_NYCT_pb2
import gtfs_realtime_pb2
import polars as pl
from polars import col
import re
from PIL import Image
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import pandas as pd
from PIL import Image
import pyarrow
import json
import bisect
import math
from datetime import datetime, timedelta
import plotly.io as pio

## Read in Flat Files

In [58]:
# FLAT FILE IMPORT
stops = pl.read_csv(
    "stops.txt",
    separator=",",
    has_header=True,
    schema_overrides={"parent_station": pl.String},
)

shapes = pl.read_csv(
    "shapes.txt",
    separator=",",
    has_header=True,
)

colors = pl.read_csv("MTA_Colors_20240623.csv", separator=",", has_header=True)

In [59]:
# READ IN MAP

# Load the plot from an HTML file
with open("map_plot.json", "r") as f:
    fig_json = json.load(f)

fig = go.Figure(fig_json)

## Constants

In [60]:
STOP_STATUS = {"0": "Incoming At", "1": "Stopped At", "2": "In Transit To"}

In [61]:
API_ENDPOINTS = {
    "ACE": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-ace",
    "BDFM": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm",
    "G": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-g",
    "JZ": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-jz",
    "NQRW": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-nqrw",
    "L": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-l",
    "1234567": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs",
    "SI": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-si",
}

In [62]:
response = requests.get(API_ENDPOINTS["1234567"])
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(response.content)

90997

## Functions

In [63]:
def plot_map(
    coordinates,
    fig,
    trip_id,
    line,
    # hovertext,
    marker_size=15,
    marker_color="#03fca9",
):

    fig.add_trace(
        go.Scattermapbox(
            mode="markers+text",
            name=trip_id,
            lon=[coordinates[0]],
            lat=[coordinates[1]],
            text=line,
            textfont=dict(color="#ffffff"),
            marker={
                "size": marker_size,
                "color": color_lookup[line],
                # "symbol": "triangle-up",
            },
            showlegend=False,
            hoverinfo="text",
            hovertext=f"<b>Line {line}<b><br>{coordinates}",
        )
    )
    # fig.add_annotation(
    #     go.layout.Annotation(
    #         text=trip_id,
    #         showarrow=False,
    #         x=coordinates[0],
    #         y=coordinates[1],
    #         xref="x",
    #         yref="y",
    #         font=dict(size=12, color="black"),
    #         xanchor="center",
    #         yanchor="middle",
    #     )
    # )

In [64]:
def haversine(lat1, lon1, lat2, lon2):
    R = 6371  # Earth radius in kilometers
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (
        math.sin(dlat / 2) ** 2
        + math.cos(math.radians(lat1))
        * math.cos(math.radians(lat2))
        * math.sin(dlon / 2) ** 2
    )
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    distance = R * c
    return distance

In [65]:
# Define a function to calculate distance using lead coordinates
def calculate_distance_within_line(df):

    df = df.sort("shape_pt_sequence")

    # Create lead columns for lat and lon
    df = df.with_columns(
        [
            df["shape_pt_lat"].shift(1).alias("lag_lat"),
            df["shape_pt_lon"].shift(1).alias("lag_lon"),
        ]
    )

    # Apply the Haversine function to each row and add the result as a new column
    return df.with_columns(
        [
            pl.concat_list(["shape_pt_lat", "shape_pt_lon", "lag_lat", "lag_lon"])
            .map_elements(
                lambda row: (
                    haversine(row[0], row[1], row[2], row[3])
                    if row[2] is not None and row[3] is not None
                    else None
                ),
                return_dtype=pl.Float64,
            )
            .alias("distance")
        ]
    )

In [66]:
def linear_distance(lon1, lat1, lon2, lat2, fraction):
    lat = lat1 + (lat2 - lat1) * fraction
    lon = lon1 + (lon2 - lon1) * fraction
    return lat, lon

In [67]:
def calculate_position(api_time, departure, arrival, df: pl.DataFrame, incoming=False):
    trip_time = arrival - departure
    since_departure = api_time - departure
    proportion_traveled = since_departure / trip_time
    # TODO: Calculate proportions in case a train is skipping a stop

    loc = bisect.bisect_left(
        (cum_sum := df["cum_sum"].fill_null(0).to_list()), proportion_traveled
    )

    temp = df[loc, :][0].to_dict(as_series=False)
    if incoming:
        return linear_distance(
            temp["lag_lon"][0],
            temp["lag_lat"][0],
            temp["shape_pt_lon"][0],
            temp["shape_pt_lat"][0],
            0.9,
        )
    elif proportion_traveled in cum_sum:
        return df[loc, :].select(["shape_pt_lat", "shape_pt_lon"]).row(0)
    else:
        return linear_distance(
            temp["lag_lon"][0],
            temp["lag_lat"][0],
            temp["shape_pt_lon"][0],
            temp["shape_pt_lat"][0],
            (proportion_traveled - (cum_sum[loc - 1] if loc != 1 else 0))
            / (cum_sum[loc] - cum_sum[loc - 1]),
        )

## Color cleaning

In [68]:
colors = colors.filter(col("Operator") == "New York City Subway")
colors = colors.with_columns(
    col("Service").str.split(",")
)  # Split the comma-delimited values into lists
colors = colors.explode("Service")  # Explode the lists into separate rows

In [69]:
colors

Operator,Service,Hex color,CMYK
str,str,str,str
"""New York City Subway""","""A""","""#0039A6""","""(100,56,0,0)"""
"""New York City Subway""","""C""","""#0039A6""","""(100,56,0,0)"""
"""New York City Subway""","""E""","""#0039A6""","""(100,56,0,0)"""
"""New York City Subway""","""B""","""#FF6319""","""(0,60,100,0)"""
"""New York City Subway""","""D""","""#FF6319""","""(0,60,100,0)"""
…,…,…,…
"""New York City Subway""","""3""","""#EE352E""","""(0,91,76,0)"""
"""New York City Subway""","""4""","""#00933C""","""(100,0,91,6)"""
"""New York City Subway""","""5""","""#00933C""","""(100,0,91,6)"""
"""New York City Subway""","""6""","""#00933C""","""(100,0,91,6)"""


In [70]:
color_lookup = colors.select(["Service", "Hex color"]).to_dict(as_series=False)
color_lookup = {x: y for (x, y) in zip(*color_lookup.values())}

## Shape Cleaning

In [71]:
shape_unpack_re = re.compile(r"^(\w{1}).*\.+(\w+?)([XR]).*$")


def shape_unpack(shape):
    m = re.match(shape_unpack_re, shape)
    return m.group(1), m.group(2)


shapes_clean = shapes.with_columns(
    [
        shapes["shape_id"]
        .map_elements(lambda x: shape_unpack(x)[0], return_dtype=str)
        .alias("Line"),
        shapes["shape_id"]
        .map_elements(lambda x: shape_unpack(x)[1], return_dtype=str)
        .alias("Line_Variation"),
    ]
)

## Stops Cleaning

In [72]:
stop_removal_re = r".*[NS]$"

stops = stops.filter(~stops["stop_id"].str.contains(stop_removal_re))

In [73]:
stop_unpack_re = re.compile(r"^(\w{1})(\d{2})")


def stop_unpack(stop):
    m = re.match(stop_unpack_re, stop)
    return m.group(1), m.group(2)


stops_clean = stops[["stop_id", "stop_name", "stop_lat", "stop_lon"]].with_columns(
    [
        stops["stop_id"]
        .map_elements(lambda x: stop_unpack(x)[0], return_dtype=str)
        .alias("Line"),
        stops["stop_id"]
        .map_elements(lambda x: stop_unpack(x)[1], return_dtype=str)
        .alias("Order"),
    ]
)

In [74]:
# stops_clean = stops_clean.join(line_points, left_on="Line", right_on="Line", how="left")
stops_clean = (
    stops_clean.join(colors, left_on="Line", right_on="Service", how="left")
    .with_columns(pl.col("Hex color").fill_null("#858585"))
    .select(
        ["stop_name", "stop_id", "stop_lat", "stop_lon", "Line", "Order", "Hex color"]
    )
)





In [75]:
stop_lookup = stops_clean.select(
    ["stop_id", "stop_lon", "stop_lat", "stop_name"]
).to_dict(as_series=False)
stop_lookup = {x: [(y, z), n] for (x, y, z, n) in zip(*stop_lookup.values())}

In [76]:
stop_lookup

{'101': [(-73.898583, 40.889248), 'Van Cortlandt Park-242 St'],
 '103': [(-73.90087, 40.884667), '238 St'],
 '104': [(-73.904834, 40.878856), '231 St'],
 '106': [(-73.909831, 40.874561), 'Marble Hill-225 St'],
 '107': [(-73.915279, 40.869444), '215 St'],
 '108': [(-73.918822, 40.864621), '207 St'],
 '109': [(-73.925536, 40.860531), 'Dyckman St'],
 '110': [(-73.929412, 40.855225), '191 St'],
 '111': [(-73.933596, 40.849505), '181 St'],
 '112': [(-73.940133, 40.840556), '168 St-Washington Hts'],
 '113': [(-73.94489, 40.834041), '157 St'],
 '114': [(-73.95036, 40.826551), '145 St'],
 '115': [(-73.953676, 40.822008), '137 St-City College'],
 '116': [(-73.958372, 40.815581), '125 St'],
 '117': [(-73.96411, 40.807722), '116 St-Columbia University'],
 '118': [(-73.966847, 40.803967), 'Cathedral Pkwy (110 St)'],
 '119': [(-73.968379, 40.799446), '103 St'],
 '120': [(-73.972323, 40.793919), '96 St'],
 '121': [(-73.976218, 40.788644), '86 St'],
 '122': [(-73.979917, 40.783934), '79 St'],
 '123':

## Merge Shaps and Stops

In [77]:
shapes_final = shapes_clean.join(
    stops_clean.select(["stop_lon", "stop_lat", "stop_name", "stop_id"]),
    left_on=("shape_pt_lon", "shape_pt_lat"),
    right_on=("stop_lon", "stop_lat"),
    how="left",
)





## Master train info

In [78]:
def departure_time(updates):
    try:
        return updates[0][0].departure.time
    except IndexError:
        return None


def get_stop(updates):
    try:
        return updates[0][0].stop_id
    except IndexError:
        return None

In [79]:
trains_tracked = {}
problems_log = []

In [80]:
def initialize_train_table(trains_tracked):
    for trip_id in set(
        [x.vehicle.trip.trip_id for x in feed.entity if x.vehicle.trip.trip_id]
    ):

        vehicle = [
            x.vehicle
            for x in feed.entity
            if x.HasField("vehicle") and x.vehicle.trip.trip_id == trip_id
        ]
        updates = [
            x.trip_update.stop_time_update
            for x in feed.entity
            if x.HasField("trip_update") and x.trip_update.trip.trip_id == trip_id
        ]
        trip_details = [
            x.trip_update.trip
            for x in feed.entity
            if x.HasField("trip_update") and x.trip_update.trip.trip_id == trip_id
        ]

        # Avoids trains with missing information
        if len(vehicle) == 0 and len(updates) == 0:
            problems_log[trip_id] = "Both"
            continue
        elif len(vehicle) == 0:
            problems_log[trip_id] = "Vehicle"
            continue
        elif len(updates) == 0:
            problems_log[trip_id] = "Updates"
            continue

        vehicle = vehicle[0]
        updates = updates[0]
        trip_details = trip_details[0]

        updates_dict = {
            x.stop_id: {"arrival": x.arrival.time, "departure": x.departure.time}
            for x in updates
        }
        current_status = vehicle.current_status
        current_timestamp = vehicle.timestamp
        current_stop = vehicle.stop_id
        number_stop = re.compile(r"^(\d+)")
        current_stop = number_stop.match(current_stop).groups(1)[0]
        if current_status == 1:
            if len(updates_dict) == 1:
                trains_tracked[trip_id] = {
                    "prev_departure_time": current_timestamp,
                    "prev_departure_station": current_stop,
                    "planned_next_station": None,
                    "current_station": current_stop,
                    "current_schedule": updates_dict,
                    "current_status": current_status,
                    "current_timestamp": current_timestamp,
                    "current_direction": list(updates_dict.keys())[0][-1],
                    "line": trip_details.route_id,
                }
            else:
                trains_tracked[trip_id] = {
                    "prev_departure_time": current_timestamp,
                    "prev_departure_station": current_stop,
                    "planned_next_station": number_stop.match(
                        list(updates_dict.keys())[1]
                    ).groups(1)[0],
                    "current_station": current_stop,
                    "current_schedule": updates_dict,
                    "current_status": current_status,
                    "current_timestamp": current_timestamp,
                    "current_direction": list(updates_dict.keys())[0][-1],
                    "line": trip_details.route_id,
                }
        elif current_status in (0, 2):
            # TODO: Implement if previous stop is not found, symbol appears red if previosly plotted
            if (
                trip_id not in trains_tracked
                or trains_tracked[trip_id]["planned_next_station"] != current_stop
            ):
                trains_tracked[trip_id] = {
                    "prev_departure_time": None,
                    "prev_departure_station": None,
                    "planned_next_station": current_stop,
                    "current_station": current_stop,
                    "current_schedule": updates_dict,
                    "current_status": current_status,
                    "current_timestamp": current_timestamp,
                    "current_direction": list(updates_dict.keys())[0][-1],
                    "line": trip_details.route_id,
                }
            else:
                trains_tracked[trip_id]["current_timestamp"] = current_timestamp
                trains_tracked[trip_id]["current_station"] = None
        else:
            if get_stop(updates) != trains_tracked[trip_id]["planned_next_station"]:
                trains_tracked[trip_id]["prev_departure_station"] = trains_tracked[
                    trip_id
                ]["next_station"]
                trains_tracked[trip_id]["next_station"] = get_stop(updates)
            trains_tracked[trip_id]["current_schedule"] = updates


def update_train_table(trains_tracked):
    for trip_id in set(
        [x.vehicle.trip.trip_id for x in feed.entity if x.vehicle.trip.trip_id]
    ):
        vehicle = [
            x.vehicle
            for x in feed.entity
            if x.HasField("vehicle") and x.vehicle.trip.trip_id == trip_id
        ]
        updates = [
            x.trip_update.stop_time_update
            for x in feed.entity
            if x.HasField("trip_update") and x.trip_update.trip.trip_id == trip_id
        ]

In [81]:
trains_tracked = {}
problems_log = []
initialize_train_table(trains_tracked)

In [82]:
trains_tracked

{'122050_6..N01R': {'prev_departure_time': 1721002884,
  'prev_departure_station': '639',
  'planned_next_station': '638',
  'current_station': '639',
  'current_schedule': {'639N': {'arrival': 1721002914,
    'departure': 1721002914},
   '638N': {'arrival': 1721003004, 'departure': 1721003004},
   '637N': {'arrival': 1721003064, 'departure': 1721003064},
   '636N': {'arrival': 1721003154, 'departure': 1721003154},
   '635N': {'arrival': 1721003244, 'departure': 1721003244},
   '634N': {'arrival': 1721003364, 'departure': 1721003364},
   '633N': {'arrival': 1721003424, 'departure': 1721003424},
   '632N': {'arrival': 1721003484, 'departure': 1721003484},
   '631N': {'arrival': 1721003604, 'departure': 1721003604},
   '630N': {'arrival': 1721003694, 'departure': 1721003694},
   '629N': {'arrival': 1721003784, 'departure': 1721003784},
   '628N': {'arrival': 1721003904, 'departure': 1721003904},
   '627N': {'arrival': 1721003994, 'departure': 1721003994},
   '626N': {'arrival': 172100408

In [83]:
stop_schedule = {}
stop_schedule = {stop_id: [] for stop_id in stops_clean["stop_id"].to_list()}

In [84]:
for trip, v in trains_tracked.items():
    train_stops = {
        stop[:-1]: dict(
            direction=v["current_direction"], line=v["line"], arrival=times["arrival"]
        )
        for stop, times in v["current_schedule"].items()
    }
    for stop, schedule in train_stops.items():
        stop_schedule[stop].append(schedule)

In [85]:
stop_schedule

{'101': [],
 '103': [],
 '104': [],
 '106': [],
 '107': [],
 '108': [],
 '109': [],
 '110': [],
 '111': [],
 '112': [],
 '113': [],
 '114': [],
 '115': [{'direction': 'S', 'line': '1', 'arrival': 0},
  {'direction': 'N', 'line': '1', 'arrival': 1721003725},
  {'direction': 'N', 'line': '1', 'arrival': 1721004567},
  {'direction': 'N', 'line': '1', 'arrival': 1721005178},
  {'direction': 'N', 'line': '1', 'arrival': 1721004181},
  {'direction': 'N', 'line': '1', 'arrival': 1721003380}],
 '116': [{'direction': 'S', 'line': '1', 'arrival': 1721003130},
  {'direction': 'N', 'line': '1', 'arrival': 1721003605},
  {'direction': 'N', 'line': '1', 'arrival': 1721004447},
  {'direction': 'N', 'line': '1', 'arrival': 1721005058},
  {'direction': 'N', 'line': '1', 'arrival': 1721004061},
  {'direction': 'N', 'line': '1', 'arrival': 1721003260}],
 '117': [{'direction': 'S', 'line': '1', 'arrival': 1721003280},
  {'direction': 'N', 'line': '1', 'arrival': 1721003455},
  {'direction': 'N', 'line': '

In [86]:
stop_strings = {}
for stop, arrivals in stop_schedule.items():
    stop_string = ""
    lines = set([a["line"] for a in arrivals])
    for line in sorted(lines):
        arrivals_line = [
            arrival
            for arrival in arrivals
            if arrival["line"] == line
            and datetime.now() <datetime.fromtimestamp(arrival["arrival"])
            < datetime.now() + timedelta(minutes=30)
        ]
        arrivals_line_sorted = sorted(arrivals_line, key=lambda x: x["arrival"])
        stop_string += f"<b>{line.upper()}<b><br>"
        for arrival in arrivals_line:
            stop_string += f"{datetime.fromtimestamp(arrival['arrival']).strftime("%I:%M")}<br>"
        stop_strings[stop] = stop_string

In [87]:
stop_strings["127"]

'<b>1<b><br>08:42<br>08:33<br>08:29<br>08:39<br>08:23<br>08:28<br><b>2<b><br>08:35<br>08:47<br>08:24<br>08:34<br><b>3<b><br>08:49<br>08:25<br>08:37<br>08:28<br>08:37<br>'

In [88]:
stop_schedule = {
    stop[:-1]: dict(
        direction=v["current_direction"], trip_id=trip_id, arrival=times["arrival"]
    )
    for trip_id, v in trains_tracked.items()
    for stop, times in v["current_schedule"].items()
}

## Testing

In [89]:
def route_to_shape(trip_id):
    # exact_route_to_shape = re.compile(r"^.*_(.*)$")
    route_to_shape = re.compile(r"^.*_(.*?)([RX]).*$")
    simple_route_to_shape = re.compile(r"^.*_(.*?\.{1,2}[NS]).*")
    shape = route_to_shape.search(trip_id)
    if shape:
        train_route = shapes_final.filter(pl.col("shape_id") == shape.groups(1)[0])
        if not train_route.is_empty():
            return train_route, shape[0]
    simple_shape = simple_route_to_shape.match(trip_id).groups(1)[0]
    partial_shape = [
        x
        for x in shapes_final["shape_id"].unique().to_list()
        if re.search(rf".*{simple_shape}.*", x)
    ][0]
    train_route = shapes_final.filter(pl.col("shape_id") == partial_shape)
    return train_route, partial_shape[0]

In [90]:
for trip_id in trains_tracked.keys():
    train_route, line = route_to_shape(trip_id)
    match trains_tracked[trip_id]["current_status"]:

        case 0:
            incoming = True
        case 1:
            plot_map(
                stop_lookup[trains_tracked[trip_id]["current_station"]][0],
                fig,
                trip_id,
                line,
            )
            continue

    # TODO: Implement beginning of trip versus mismatch of stations (symbol blinks red)
    if trains_tracked[trip_id]["prev_departure_station"] is None:
        continue
    prev_stop = train_route.select(
        [
            pl.arg_where(
                pl.col("stop_id") == trains_tracked[trip_id]["prev_departure_station"]
            ),
        ]
    )[0, 0]

    next_stop = train_route.select(
        [
            pl.arg_where(
                pl.col("stop_id") == trains_tracked[trip_id]["planned_next_station"]
            ),
        ]
    )[0, 0]
    in_route = train_route[prev_stop : next_stop + 1]
    result = in_route.group_by("Line", maintain_order=True).map_groups(
        calculate_distance_within_line
    )

    sums = result.select(pl.col("distance").drop_nulls()).to_series().sum()
    result = result.with_columns((result["distance"] / sums).alias("proportion"))
    result = result.with_columns(
        pl.col("proportion").cum_sum().round(7).alias("cum_sum")
    )
    api_time = trains_tracked[trip_id]["current_loc_info"].timestamp
    departure = trains_tracked[trip_id]["prev_departure_time"]
    arrival = trains_tracked[trip_id]["current_schedule"][0][0].arrival.time
    position = calculate_position(api_time, departure, arrival, result, incoming)
    plot_map(position, fig, trip_id, line)
    break

# Plotting Train

In [91]:
for stop in stop_lookup.keys():
    try:
        fig.update_traces(selector=dict(name=stop), text=stop_strings[str(stop)])
    except KeyError:
        continue
for trip in trains_tracked.keys():
    if next_stop := trains_tracked[trip]["planned_next_station"]:
        fig.update_traces(
            selector=dict(name=trip),
            hovertext=f"Next Stop: {stop_lookup[next_stop][1]}",
        )

In [92]:
app = dash.Dash(__name__)

app.layout = html.Div([dcc.Graph(id="live-map", figure=fig)])

if __name__ == "__main__":
    app.run_server(debug=True)

In [93]:
trains_tracked

{'122050_6..N01R': {'prev_departure_time': 1721002884,
  'prev_departure_station': '639',
  'planned_next_station': '638',
  'current_station': '639',
  'current_schedule': {'639N': {'arrival': 1721002914,
    'departure': 1721002914},
   '638N': {'arrival': 1721003004, 'departure': 1721003004},
   '637N': {'arrival': 1721003064, 'departure': 1721003064},
   '636N': {'arrival': 1721003154, 'departure': 1721003154},
   '635N': {'arrival': 1721003244, 'departure': 1721003244},
   '634N': {'arrival': 1721003364, 'departure': 1721003364},
   '633N': {'arrival': 1721003424, 'departure': 1721003424},
   '632N': {'arrival': 1721003484, 'departure': 1721003484},
   '631N': {'arrival': 1721003604, 'departure': 1721003604},
   '630N': {'arrival': 1721003694, 'departure': 1721003694},
   '629N': {'arrival': 1721003784, 'departure': 1721003784},
   '628N': {'arrival': 1721003904, 'departure': 1721003904},
   '627N': {'arrival': 1721003994, 'departure': 1721003994},
   '626N': {'arrival': 172100408