In [None]:
%%capture
import warnings
warnings.filterwarnings('ignore')

import branca
import calitp_data_analysis.magics
import geopandas as gpd
import intake
import pandas as pd

from IPython.display import Markdown, HTML
from typing import Literal

from shared_utils import geography_utils, portfolio_utils
from bus_service_utils import better_bus_utils

# Intake catalog built from the bus_service_increase YAML files.
# NOTE(review): `catalog` is not referenced in any visible cell — confirm it is still needed.
catalog = intake.open_catalog("../bus_service_increase/*.yml")

In [None]:
# parameters cell
# Default district for this report; presumably overridden per district when the
# notebook is run parameterized — TODO confirm against the build config.
# Must match a `caltrans_district` label exactly, because get_data() filters on it.
district = "03 - Marysville"

In [None]:
def get_data(district: str, speed: float = 25, trips: float = 5):
    """
    Return the highway corridor segments for one Caltrans district.

    Segments are first screened with `better_bus_utils.select_highway_corridors`
    using the `speed` and `trips` thresholds. If no screened segment falls in
    `district`, this falls back to every segment from
    `better_bus_utils.get_sorted_highway_corridors()` in that district.

    Parameters
    ----------
    district: str
        Caltrans district label, matched exactly against the
        `caltrans_district` column (e.g. "03 - Marysville").
    speed: float, default 25
        Passed as `speed_dict["mean_speed_mph_trip_weighted"]`.
        NOTE(review): presumably a maximum average speed in mph — confirm
        in better_bus_utils.
    trips: float, default 5
        Passed as `trip_dict["trips_all_day_per_mi"]`.
        NOTE(review): presumably a minimum number of daily trips per mile — confirm.

    Returns
    -------
    The district's segments with a fresh 0..n-1 index. The result can still be
    empty if the district has no corridors at all; later cells that call
    `.iloc[0]` on it would then raise IndexError.
    """
    gdf = better_bus_utils.select_highway_corridors(
        speed_dict = {"mean_speed_mph_trip_weighted": speed},
        trip_dict = {"trips_all_day_per_mi": trips}
    )

    district_df = gdf[gdf.caltrans_district == district].reset_index(drop=True)

    # Nothing in this district passed the screen: fall back to the full sorted
    # corridor list so the report still has segments to show.
    if len(district_df) == 0:
        gdf = better_bus_utils.get_sorted_highway_corridors()
        district_df = gdf[gdf.caltrans_district == district].reset_index(drop=True)

    return district_df


# Corridor segments for the district being reported on
gdf = get_data(district=district)

In [None]:
%%capture_parameters
# Derive the district number and full district label from the data itself.
# NOTE(review): the calitp `%%capture_parameters` magic presumably records the
# names listed in the cell's last line — keep that line as the bare tuple. Confirm.
# NOTE(review): raises IndexError if `gdf` is empty (get_data found no corridors
# for `district`).
district_num = str(gdf.District.iloc[0])
district = gdf.caltrans_district.iloc[0]
district_num, district

In [None]:
# Columns retained for the highway-corridor maps
keep_cols = [
    'Route', 'County', 'RouteType',
    'trips_peak', 'trips_all_day',
    'trips_all_day_per_mi',
    'mean_speed_mph_trip_weighted',
    'geometry', 'District', 'caltrans_district',
]

# NOTE(review): `plot_df` is not used by any visible cell; plot_highway_corridor
# builds its own copy.
plot_df = (
    gdf[keep_cols]
    .reindex(columns = keep_cols)
)

# Existing Transit on the State Highway Network (SHN)

In [None]:
# Basemap tiles used by every folium map in this notebook
TILES = "CartoDB positron"

# Speed colorscale: like rt_utils.ZERO_THIRTY_COLORSCALE, but stretched to 0-65 mph
# so highway speeds are not clipped at the top of the scale.
ZERO_SIXTY_COLORSCALE = branca.colormap.step.RdYlGn_11.scale(vmin=0, vmax=65)
ZERO_SIXTY_COLORSCALE.caption = "Speed (miles per hour)"

def plot_highway_corridor(
    gdf: gpd.GeoDataFrame, 
    metric: Literal["avg_speed", "daily_trips"]):
    """
    Map highway corridor segments colored by the chosen metric.

    Parameters
    ----------
    gdf: gpd.GeoDataFrame
        Highway corridor segments; must contain every column in `keep_cols`.
    metric: {"avg_speed", "daily_trips"}
        - "avg_speed": color by `mean_speed_mph_trip_weighted` on a 0-65 mph
          scale. Segments without a speed are dropped.
        - "daily_trips": color by `trips_all_day_per_mi`, labelled
          "Daily Trips per Mile".

    Returns
    -------
    folium.Map or None
        The map from `GeoDataFrame.explore()`. Returns None, after printing a
        message, when there is no data for the requested metric.

    Raises
    ------
    ValueError
        If `metric` is not a supported value. Previously this surfaced as an
        UnboundLocalError on `m`.
    """
    valid_metrics = ("avg_speed", "daily_trips")
    if metric not in valid_metrics:
        raise ValueError(
            f"metric must be one of {valid_metrics}, got {metric!r}")

    keep_cols = ['Route', 'County', 'RouteType',
                 'trips_peak', 'trips_all_day', 
                 'trips_all_day_per_mi', 
                 'mean_speed_mph_trip_weighted', 
                 'geometry', 'District', 'caltrans_district',
    ]

    # Column selection already fixes the order; the former .reindex was a no-op.
    plot_df = gdf[keep_cols]

    if metric == "avg_speed":
        has_speed = plot_df.mean_speed_mph_trip_weighted.notna()
        if not has_speed.any():
            # `district_num` is the notebook-level global set by the parameters cell.
            print(f"No trip data available for transit on highways in District {district_num}.")
            return None
        return plot_df[has_speed].explore(
            "mean_speed_mph_trip_weighted", 
            # switch out colormap to allow higher speeds
            cmap = ZERO_SIXTY_COLORSCALE, 
            categorical=False, tiles = TILES
        )

    # metric == "daily_trips"
    if (plot_df.trips_all_day_per_mi == 0).all():
        print(f"No transit trips on highways in District {district_num}.")
        return None
    return plot_df.rename(
        columns = {"trips_all_day_per_mi": "Daily Trips per Mile"}
    ).explore(
        "Daily Trips per Mile", 
        cmap = "viridis_r",
        categorical=False, tiles = TILES)

## Average Speed on Highways

In [None]:
# Trip-weighted average speed by segment (renders nothing when the map is None)
m = plot_highway_corridor(gdf=gdf, metric="avg_speed")
m

## Daily Trips on Highways

In [None]:
# Daily trips per mile by segment (renders nothing when the map is None)
m = plot_highway_corridor(gdf=gdf, metric="daily_trips")
m

## Highway Aggregated Stats 

In [None]:
def aggregate_to_highway_across_segments(gdf: gpd.GeoDataFrame):
    """
    Roll segment-level stats up to one row per (Route, County, RouteType).

    The per-mile trip and stop-arrival columns are passed as `sum_cols`.
    `mean_speed_mph_trip_weighted` is passed as a `mean_cols` column only when
    at least one segment has a speed; otherwise it is left out of the output.

    TODO(review): the speed is a plain mean of per-segment trip-weighted means.
    That equals a highway-level trip-weighted mean only if segments carry equal
    trip counts (segments are equal length, but trips differ). Confirm this is intended.
    """
    aggregation = {
        "sum_cols": ["trips_all_day_per_mi", "trips_peak_per_mi",
                     "stop_arrivals_all_day_per_mi"],
    }
    if gdf.mean_speed_mph_trip_weighted.notna().any():
        aggregation["mean_cols"] = ["mean_speed_mph_trip_weighted"]

    return geography_utils.aggregate_by_geography(
        gdf,
        ['Route', 'County', 'RouteType'],
        **aggregation,
    )

In [None]:
highway_avg = aggregate_to_highway_across_segments(gdf)

# Keep highways that have any daily trips and, when a speed column was
# produced, a non-missing speed as well.
keep_mask = highway_avg.trips_all_day_per_mi > 0
if 'mean_speed_mph_trip_weighted' in highway_avg.columns:
    keep_mask = keep_mask & highway_avg.mean_speed_mph_trip_weighted.notna()
highway_avg = highway_avg[keep_mask]

In [None]:
# Show one styled row per highway when any highway has trips. The speed column
# is included only when the aggregation produced it.
has_speed_col = 'mean_speed_mph_trip_weighted' in highway_avg.columns
has_trips = not (highway_avg.trips_all_day_per_mi == 0).all()

if has_trips:
    rename_cols_dict = {
        'stop_arrivals_all_day_per_mi': "Daily Stop Arrivals per Mi",
        'trips_all_day_per_mi': 'Daily Trips per Mi', 
        'trips_peak_per_mi': 'Peak Trips per Mi',
    }
    decimal_kwargs = {
        "one_decimal_cols": ['Daily Stop Arrivals per Mi', 
                             'Daily Trips per Mi', 'Peak Trips per Mi'],
    }
    if has_speed_col:
        rename_cols_dict['mean_speed_mph_trip_weighted'] = 'Avg Speed (mph), trip-weighted'
        decimal_kwargs["two_decimal_cols"] = ['Avg Speed (mph), trip-weighted']

    table = portfolio_utils.style_table(
        highway_avg,
        rename_cols = rename_cols_dict,
        display_table = True,
        **decimal_kwargs,
    )
else:
    display(
        Markdown(
            f"No trips or speed data available for District {district_num}.")
        )