In [1]:
# IMPORTS
import requests

import gtfs_realtime_NYCT_pb2
import gtfs_realtime_pb2
import polars as pl
from polars import col
import re
from PIL import Image
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import pandas as pd
from PIL import Image
import pyarrow
import json
import bisect
import math

## Read in Flat Files

In [2]:
# READ IN MAP
import plotly.io as pio

# Load the plot from an HTML file
with open("map_plot.json", "r") as f:
    fig_json = json.load(f)

fig = go.Figure(fig_json)

In [3]:
# FLAT FILE IMPORT
stops = pl.read_csv(
    "stops.txt",
    separator=",",
    has_header=True,
    schema_overrides={"parent_station": pl.String},
)

shapes = pl.read_csv(
    "shapes.txt",
    separator=",",
    has_header=True,
)

colors = pl.read_csv("MTA_Colors_20240623.csv", separator=",", has_header=True)

## Constants

In [4]:
STOP_STATUS = {"0": "Incoming At", "1": "Stopped At", "2": "In Transit To"}

In [5]:
API_ENDPOINTS = {
    "ACE": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-ace",
    "BDFM": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm",
    "G": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-g",
    "JZ": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-jz",
    "NQRW": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-nqrw",
    "L": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-l",
    "1234567": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs",
    "SI": r"https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-si",
}

In [6]:
response = requests.get(API_ENDPOINTS["1234567"])
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(response.content)

245830

## Functions

In [7]:
def haversine(lat1, lon1, lat2, lon2):
    R = 6371  # Earth radius in kilometers
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (
        math.sin(dlat / 2) ** 2
        + math.cos(math.radians(lat1))
        * math.cos(math.radians(lat2))
        * math.sin(dlon / 2) ** 2
    )
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    distance = R * c
    return distance

In [8]:
# Define a function to calculate distance using lead coordinates
def calculate_distance_within_line(df):

    df = df.sort("shape_pt_sequence")

    # Create lead columns for lat and lon
    df = df.with_columns(
        [
            df["shape_pt_lat"].shift(1).alias("lag_lat"),
            df["shape_pt_lon"].shift(1).alias("lag_lon"),
        ]
    )

    # Apply the Haversine function to each row and add the result as a new column
    return df.with_columns(
        [
            pl.concat_list(["shape_pt_lat", "shape_pt_lon", "lag_lat", "lag_lon"])
            .map_elements(
                lambda row: (
                    haversine(row[0], row[1], row[2], row[3])
                    if row[2] is not None and row[3] is not None
                    else None
                ),
                return_dtype=pl.Float64,
            )
            .alias("distance")
        ]
    )

In [9]:
def linear_distance(lon1, lat1, lon2, lat2, fraction):
    lat = lat1 + (lat2 - lat1) * fraction
    lon = lon1 + (lon2 - lon1) * fraction
    return lat, lon

In [10]:
def calculate_position(api_time, departure, arrival, df: pl.DataFrame, incoming=False):
    trip_time = arrival - departure
    since_departure = api_time - departure
    proportion_traveled = since_departure / trip_time
    # TODO: Calculate proportions in case a train is skipping a stop

    loc = bisect.bisect_left(
        (cum_sum := df["cum_sum"].fill_null(0).to_list()), proportion_traveled
    )

    if incoming:
        return linear_distance(
            temp["lag_lon"][0],
            temp["lag_lat"][0],
            temp["shape_pt_lon"][0],
            temp["shape_pt_lat"][0],
            0.9,
        )
    elif proportion_traveled in cum_sum:
        return df[loc, :].select(["shape_pt_lat", "shape_pt_lon"]).row(0)
    else:
        temp = df[loc, :][0].to_dict(as_series=False)
        return linear_distance(
            temp["lag_lon"][0],
            temp["lag_lat"][0],
            temp["shape_pt_lon"][0],
            temp["shape_pt_lat"][0],
            (proportion_traveled - (cum_sum[loc - 1] if loc != 1 else 0))
            / (cum_sum[loc] - cum_sum[loc - 1]),
        )

In [11]:
def plot_map(coordinates, fig, trip_id):
    fig.add_trace(
        go.Scattermapbox(
            mode="markers",
            name=trip_id,
            lon=coordinates[0],
            lat=coordinates[1],
            text=trip_id,
            marker={"size": 8, "color": "blue"},
            showlegend=False,
        )
    )

## Color cleaning

In [12]:
colors = colors.filter(col("Operator") == "New York City Subway")
colors = colors.with_columns(
    col("Service").str.split(",")
)  # Split the comma-delimited values into lists
colors = colors.explode("Service")  # Explode the lists into separate rows

## Shape Cleaning

In [13]:
shape_unpack_re = re.compile(r"^(\w{1}).*\.+(\w+)$")


def shape_unpack(shape):
    m = re.match(shape_unpack_re, shape)
    return m.group(1), m.group(2)


shapes_clean = shapes.with_columns(
    [
        shapes["shape_id"]
        .map_elements(lambda x: shape_unpack(x)[0], return_dtype=str)
        .alias("Line"),
        shapes["shape_id"]
        .map_elements(lambda x: shape_unpack(x)[1], return_dtype=str)
        .alias("Line_Variation"),
    ]
)

## Stops Cleaning

In [14]:
stop_removal_re = r".*[NS]$"

stops = stops.filter(~stops["stop_id"].str.contains(stop_removal_re))

In [15]:
stop_unpack_re = re.compile(r"^(\w{1})(\d{2})")


def stop_unpack(stop):
    m = re.match(stop_unpack_re, stop)
    return m.group(1), m.group(2)


stops_clean = stops[["stop_id", "stop_name", "stop_lat", "stop_lon"]].with_columns(
    [
        stops["stop_id"]
        .map_elements(lambda x: stop_unpack(x)[0], return_dtype=str)
        .alias("Line"),
        stops["stop_id"]
        .map_elements(lambda x: stop_unpack(x)[1], return_dtype=str)
        .alias("Order"),
    ]
)

In [16]:
# stops_clean = stops_clean.join(line_points, left_on="Line", right_on="Line", how="left")
stops_clean = (
    stops_clean.join(colors, left_on="Line", right_on="Service", how="left")
    .with_columns(pl.col("Hex color").fill_null("#858585"))
    .select(
        ["stop_name", "stop_id", "stop_lat", "stop_lon", "Line", "Order", "Hex color"]
    )
)

In [17]:
stop_lookup = stops_clean.select(['stop_id', 'stop_lat', 'stop_lon']).to_dict(as_series=False)
stop_lookup = {x: (y,z) for (x,y,z) in zip(*stop_lookup.values())}


## Merge Shaps and Stops

In [18]:
shapes_final = shapes_clean.join(
    stops_clean.select(["stop_lon", "stop_lat", "stop_name", "stop_id"]),
    left_on=("shape_pt_lon", "shape_pt_lat"),
    right_on=("stop_lon", "stop_lat"),
    how="left",
)

## Master train info

In [19]:
def departure_time(updates):
    try:
        return updates[0][0].departure.time
    except IndexError:
        return None


def get_stop(updates):
    try:
        return updates[0][0].stop_id
    except IndexError:
        return None

In [20]:
trains_tracked = {}
def initialize_train_table(trains_tracked):
    for trip_id in set(
        [x.vehicle.trip.trip_id for x in feed.entity if x.vehicle.trip.trip_id]
    ):
            
        vehicle = [
            x.vehicle
            for x in feed.entity
            if x.HasField("vehicle") and x.vehicle.trip.trip_id == trip_id
        ]
        updates = [
            x.trip_update.stop_time_update
            for x in feed.entity
            if x.HasField("trip_update") and x.trip_update.trip.trip_id == trip_id
        ]

        if trip_id not in trains_tracked:
            trains_tracked[trip_id] = {
                "prev_departure_time": departure_time(updates),
                "prev_departure_station": get_stop(updates),
                "next_station": get_stop(updates),
                "current_schedule": updates,
                "current_loc_info": vehicle,
            }
        else:
            if get_stop(updates) != trains_tracked[trip_id]['next station']:
                trains_tracked[trip_id]['prev_departure_station'] = trains_tracked[trip_id]['next_station']
                trains_tracked[trip_id]['next_station'] = get_stop(updates)
            trains_tracked[trip_id]["current_schedule"] = updates
        
def update_train_table(trains_tracked):
    for trip_id in set(
            [x.vehicle.trip.trip_id for x in feed.entity if x.vehicle.trip.trip_id]
    ):
        vehicle = [
            x.vehicle
            for x in feed.entity
            if x.HasField("vehicle") and x.vehicle.trip.trip_id == trip_id
        ]
        updates = [
            x.trip_update.stop_time_update
            for x in feed.entity
            if x.HasField("trip_update") and x.trip_update.trip.trip_id == trip_id
        ]


In [28]:
initialize_train_table(trains_tracked)

KeyError: 'next station'

## Testing

In [22]:
train_route = shapes_final.filter(pl.col("shape_id") == "4..S01R")

In [23]:
prev_stop = train_route.select(
    [
        pl.arg_where(pl.col("stop_id") == "142"),
    ]
)[0, 0]

next_stop = train_route.select(
    [
        pl.arg_where(pl.col("stop_id") == "139"),
    ]
)[0, 0]

In [24]:
in_route = train_route[prev_stop : next_stop + 1]

TypeError: slice indices must be integers or None or have an __index__ method

In [None]:
# Group by 'Line', sort each group by 'line_order', and apply the distance calculation
result = in_route.group_by("Line", maintain_order=True).map_groups(
    calculate_distance_within_line
)


sums = result.select(pl.col("distance").drop_nulls()).to_series().sum()
result = result.with_columns((result["distance"] / sums).alias("proportion"))
result = result.with_columns(pl.col("proportion").cum_sum().round(7).alias("cum_sum"))

In [None]:
result

In [None]:
calculate_position(1720729868, 1720729808, 1720729888, result)

In [None]:
coords = [
    calculate_position(x, 1720729828, 1720729888, result)
    for x in range(1720729828, 1720729889)
]

In [None]:
import plotly.graph_objects as go

# Example list of coordinates (latitude, longitude)


# Separate the list of tuples into two lists: latitudes and longitudes
latitudes, longitudes = zip(*coords)

# Create a scatter plot on a map
fig = go.Figure(
    go.Scattermapbox(
        mode="markers+lines",
        lon=longitudes,
        lat=latitudes,
        marker={"size": 10},
        line=dict(width=2, color="blue"),
    )
)

# Set the layout of the map
fig.update_layout(
    mapbox={
        "style": "open-street-map",
        "center": {
            "lat": sum(latitudes) / len(latitudes),
            "lon": sum(longitudes) / len(longitudes),
        },
        "zoom": 10,
    },
    showlegend=False,
)

# Show the plot
fig.show()

In [None]:
[
    x
    for x in feed.entity
    if any(
        [
            x.vehicle.trip.trip_id == "109300_2..N01R",
            x.trip_update.trip.trip_id == "109300_2..N01R",
        ]
    )
]

In [None]:
train_4 = [
    x
    for x in feed.entity
    if any(
        [
            x.vehicle.trip.trip_id == "059250_5..N66R",
            x.trip_update.trip.trip_id == "059250_5..N66R",
        ]
    )
]

In [None]:
train_4

In [None]:
trains_tracked["059250_5..N66R"]["current_schedule"][0][0].arrival.time

In [None]:
test = [x for x in feed.entity if x.trip_update.trip.trip_id == "059250_5..N66R"]
test[0]

In [None]:
trains_tracked = {}

In [None]:
list(list(trains_tracked.values())[0].values())[2]

In [None]:
[list(x.values())[2] for x in list(trains_tracked.values())]

In [None]:
route_to_shape = re.compile(r"^.*_(.*)$")

In [None]:
stops_clean.head()

In [None]:
for trip_id in trains_tracked.keys():
    current_status = trains_tracked[trip_id]["current_loc_info"][0].current_status
    next_stop = trains_tracked[trip_id]["current_loc_info"][0].stop_id
    next_stop = re.match(r"^(\d+)", next_stop).groups(1)

    match current_status:
        case 0:
            incoming = True
        case 1:
            plot_map(stop_lookup[next_stop])
        case 2:
            next_stop = trains_tracked[trip_id]["current_loc_info"][0].stop_id

    train_route = shapes_final.filter(
        pl.col("shape_id") == route_to_shape.match(trip_id).groups(1)[0]
    )
    # trains_tracked[trip_id]['prev_departure_time']
    prev_stop = train_route.select(
        [
            pl.arg_where(
                pl.col("stop_id") == trains_tracked[trip_id]["prev_departure_station"]
            ),
        ]
    )[0, 0]

    next_stop1 = train_route.select(
        [
            pl.arg_where(pl.col("stop_id") == next_stop),
        ]
    )[0, 0]
    in_route = train_route[prev_stop : next_stop1 + 1]
    result = in_route.group_by("Line", maintain_order=True).map_groups(
        calculate_distance_within_line
    )

    sums = result.select(pl.col("distance").drop_nulls()).to_series().sum()
    result = result.with_columns((result["distance"] / sums).alias("proportion"))
    result = result.with_columns(
        pl.col("proportion").cum_sum().round(7).alias("cum_sum")
    )
    api_time = trains_tracked[trip_id]["current_loc_info"].timestamp
    departure = trains_tracked[trip_id]["prev_departure_time"]
    arrival = trains_tracked[trip_id]["current_schedule"][0][0].arrival.time
    position = calculate_position(api_time, departure, arrival, result, incoming)
    plot_map(position, fig, trip_id)
    break
    

# Plotting Train

In [None]:
fig.show()