In [None]:
import os
os.environ["CALITP_BQ_MAX_BYTES"] = str(1_000_000_000_000) # 10**12 bytes = 1 TB (decimal); presumably the calitp BigQuery bytes-scanned cap -- TODO confirm

In [None]:
import geopandas as gpd
import pandas as pd
from siuba import *
import numpy as np

from segment_speed_utils import helpers, gtfs_schedule_wrangling
from shared_utils import rt_dates, gtfs_utils_v2
import folium
import itertools

In [None]:
from update_vars import (analysis_date, AM_PEAK, PM_PEAK, EXPORT_PATH, GCS_FILE_PATH, PROJECT_CRS,
SEGMENT_BUFFER_METERS, AM_PEAK, PM_PEAK, HQ_TRANSIT_THRESHOLD, MS_TRANSIT_THRESHOLD)

In [None]:
import sjoin_stops_to_segments

In [None]:
import create_aggregate_stop_frequencies

In [None]:
import importlib
importlib.reload(create_aggregate_stop_frequencies)

In [None]:
#  whole clock hours in each peak window: start hour included, end hour excluded
am_peak_hrs = [*range(AM_PEAK[0].hour, AM_PEAK[1].hour)]
pm_peak_hrs = [*range(PM_PEAK[0].hour, PM_PEAK[1].hour)]
both_peaks_hrs = [*am_peak_hrs, *pm_peak_hrs]

In [None]:
analysis_date

In [None]:
stop_times = helpers.import_scheduled_stop_times(
    analysis_date,
    get_pandas = True,
)

In [None]:
stop_times = create_aggregate_stop_frequencies.add_route_dir(stop_times, analysis_date)

In [None]:
# stop_times >> head(2)

In [None]:
st_prepped = stop_times.pipe(create_aggregate_stop_frequencies.prep_stop_times)

## multi logic

In [None]:
multi_test2 = create_aggregate_stop_frequencies.stop_times_aggregation_max_by_stop(st_prepped, analysis_date, single_route_dir=False)

## single logic

In [None]:
single_test2 = create_aggregate_stop_frequencies.stop_times_aggregation_max_by_stop(st_prepped, analysis_date, single_route_dir=True)

## create count of shared stops between each route_dir

In [None]:
min_freq = min([HQ_TRANSIT_THRESHOLD, MS_TRANSIT_THRESHOLD])

In [None]:
def get_explode_multiroute_only(
    single_route_aggregation: pd.DataFrame,
    multi_route_aggregation: pd.DataFrame,
    min_freqency: int
) -> pd.DataFrame:
    '''
    Narrow down the stops that need the compute-intensive collinearity screen.

    Rows of each aggregation are kept only when both am_max_trips_hr and
    pm_max_trips_hr are at least min_freqency. Of the multi-route rows kept,
    only those whose (schedule_gtfs_dataset_key, stop_id) has no match among
    the single-route rows are retained; these are the stops that need
    further processing.

    Prints the number of stops retained and returns them exploded on
    route_dir, with columns schedule_gtfs_dataset_key, stop_id and route_dir,
    sorted by those three columns.
    '''
    stop_cols = ['schedule_gtfs_dataset_key', 'stop_id']

    def frequent_in_both_peaks(aggregation: pd.DataFrame) -> pd.DataFrame:
        #  both peak periods must reach the minimum frequency
        return aggregation >> filter(
            _.am_max_trips_hr >= min_freqency,
            _.pm_max_trips_hr >= min_freqency,
        )

    qualifies_single = frequent_in_both_peaks(single_route_aggregation)
    qualifies_multi = frequent_in_both_peaks(multi_route_aggregation)
    multi_only = qualifies_multi >> anti_join(_, qualifies_single, on=stop_cols)
    print(f'{multi_only.shape[0]} stops may qualify with multi-route aggregation')

    exploded = multi_only[stop_cols + ['route_dir']].explode('route_dir')
    #  sorting crucial for next step
    return exploded.sort_values(stop_cols + ['route_dir'])

In [None]:
multi_only_explode = get_explode_multiroute_only(single_test2, multi_test2, min([HQ_TRANSIT_THRESHOLD, MS_TRANSIT_THRESHOLD]))

In [None]:
multi_only_explode >> head(3)

In [None]:
def accumulate_share_count(route_dir_exploded: pd.DataFrame, counts: dict = None):
    '''
    For use via pd.DataFrame.groupby.apply, one call per (feed, stop) group.

    For every ordered pair of route_dirs in the group that belong to
    different routes, add one to
    counts['<schedule_gtfs_dataset_key>__<route_dir>__<other_dir>'].
    Both orderings of a pair are counted, so the tally is symmetric.

    route_dir_exploded: rows for a single stop; must have the columns
        schedule_gtfs_dataset_key and route_dir (one route_dir per row).
    counts: dictionary to accumulate into. When omitted, the notebook-level
        share_counts dictionary is used, which must already exist.

    Returns None; the result is the mutation of the counts dictionary.
    '''
    if counts is None:
        counts = share_counts  #  notebook-level accumulator (original behaviour)
    feed = route_dir_exploded.schedule_gtfs_dataset_key.iloc[0]
    route_dirs = list(route_dir_exploded.route_dir)
    #  assumes route_dir is '<route_id>_<direction_id>' -- TODO confirm against add_route_dir.
    #  Split on the LAST underscore so route ids that themselves contain
    #  underscores (e.g. 'A_1' vs 'A_2') are not collapsed into one route.
    routes = [route_dir.rsplit('_', 1)[0] for route_dir in route_dirs]
    for route_dir, route in zip(route_dirs, routes):
        for other_dir, other_route in zip(route_dirs, routes):
            #  don't compare opposite dirs of same route, leads to edge cases like AC Transit 45
            if other_dir == route_dir or other_route == route:
                continue
            key = feed + '__' + route_dir + '__' + other_dir
            counts[key] = counts.get(key, 0) + 1

In [None]:
#  start from an empty accumulator so re-running this cell does not add to earlier counts
share_counts = {}
#  accumulate_share_count fills share_counts as a side effect for each (feed, stop) group
multi_only_explode.groupby(['schedule_gtfs_dataset_key', 'stop_id']).apply(accumulate_share_count)

In [None]:
# share_counts

In [None]:
s = pd.Series(share_counts.values())

In [None]:
(s[s<11]).hist()

### Which threshold?

* 8 catches Muni 48 and 66, which are somewhat marginal but not an edge case per se

In [None]:
# qualify

## lookup function/filtering steps

1. If a feed has no route_direction pairs qualifying, by definition no stops will qualify. Can exclude feed from next steps.
1. Get a list of unique feeds where at least one route_directions pair qualifies to evaluate.
1. Get stop_times filtered to that feed, and filter that to stops that only qualify with multiple routes, and route directions that pair with at least one other route_direction.
1. After that filtering, check again if stop_times includes the minimum frequency to qualify at each stop. Exclude stops where it doesn't.
1. Then... evaluate which route_directions can be aggregated at each remaining stop. From the full list of route_directions (sorted by frequency) serving the stop, use `list(itertools.combinations(this_stop_route_dirs, 2))` to get each unique pair of route_directions. Check each of those unique pairs to see if it meets the `SHARED_STOP_THRESHOLD`. If they all do, keep all stop_times entries for that stop: its route_directions can all be aggregated together at that stop. If any do not, remove the least frequent route_direction and try again, until a subset passes (only keep stop_times for that subset) or until all are eliminated. Currently implemented recursively as below:

    ```
    attempting ['103_1', '101_1', '102_1', '104_1']... subsetting...
    attempting ['103_1', '101_1', '102_1']... subsetting...
    attempting ['103_1', '101_1']... matched!

    attempting ['103_1', '101_0', '101_1', '103_0']... subsetting...
    attempting ['103_1', '101_0', '101_1']... subsetting...
    attempting ['103_1', '101_0']... subsetting...
    exhausted!
    ```

1. With that filtered stop_times, recalculate stop-level frequencies as before. Only keep stops meeting the minimum frequency threshold for a major stop or HQ corridor.
1. Finally, once again apply the `SHARED_STOP_THRESHOLD` after aggregation (by ensuring at least one route_dir at each stop has >= `SHARED_STOP_THRESHOLD` frequent stops). Exclude stops that don't meet this criterion.

### edge cases:

[AC Transit 45](https://www.actransit.org/sites/default/files/timetable_files/45-2023_12_03.pdf) _Opposite directions share a same-direction loop._ __Solved__ by preventing the same route from being compared with itself in the opposite direction.

[SDMTS 944/945](https://www.sdmts.com/sites/default/files/routes/pdf/944.pdf) _Shared frequent stops are few, and these routes are isolated._ __Solved__ by once again applying the `SHARED_STOP_THRESHOLD` after aggregation (by ensuring at least one route_dir at each stop has >= `SHARED_STOP_THRESHOLD` frequent stops). Complex typology including a loop route, each pair of [944, 945, 945A(946)] has >= threshold... but not actually in the same spots!

In [None]:
#  current rec: smallest share count at which a route_dir pair is kept
SHARED_STOP_THRESHOLD = 8
qualify = {
    pair_key: n_shared
    for pair_key, n_shared in share_counts.items()
    if n_shared >= SHARED_STOP_THRESHOLD
}

In [None]:
feeds_to_filter = np.unique([key.split('__')[0] for key in qualify.keys()])

In [None]:
feeds_no_qualify = np.unique([key.split('__')[0] for key in share_counts.keys() if key.split('__')[0] not in feeds_to_filter])

In [None]:
from calitp_data_analysis.tables import tbls

In [None]:
feeds_no_qualify = tbls.mart_transit_database.dim_gtfs_service_data() >> filter(_.gtfs_dataset_key.isin(feeds_no_qualify)) >> distinct(_.name, _.gtfs_dataset_key) >> collect()

In [None]:
feed_names = (tbls.mart_transit_database.dim_gtfs_service_data() >> filter(_.gtfs_dataset_key.isin(feeds_to_filter))
 >> distinct(_.name, _.gtfs_dataset_key)
 >> collect()
)

In [None]:
feed_names_filtered = feed_names >> filter(_.name.str.contains('Long'))
display(feed_names_filtered)
gtfs_dataset_key = feed_names_filtered.gtfs_dataset_key.iloc[0]

In [None]:
# dataset_key = '015d67d5b75b5cf2b710bbadadfb75f5' #  Marin
# dataset_key = '3c62ad6ee589d56eca915ce291a5df0a' #  Yolobus 42A and 42B share 5+ stops so they match, which isn't desirable.
# dataset_key = '70c8a8b71c815224299523bf2115924a' #  SacRT
# dataset_key = '63029a23cb0e73f2a5d98a345c5e2e40' #  Elk Grove
# dataset_key = 'f1b35a50955aeb498533c1c6fdafbe44' #  LBT

In [None]:
def feed_level_filter(
    gtfs_dataset_key: str,
    multi_only_explode: pd.DataFrame,
    qualify_dict: dict,
    st_prepped: pd.DataFrame
) -> tuple:
    '''
    For a single feed, filter potential stop_times to evaluate based on if their route_dir
    appears at all in the qualifying route_dir dict, then recheck if there's any chance
    those stops could qualify. Further shrinks problem space for check_stop lookup step.

    gtfs_dataset_key: schedule_gtfs_dataset_key of the feed to process.
    multi_only_explode: stops to evaluate, one row per stop and route_dir.
    qualify_dict: share counts keyed by '<feed>__<route_dir>__<other_dir>'.
    st_prepped: stop_times with schedule_gtfs_dataset_key, stop_id and route_dir columns.

    Returns a tuple (st_could_qual, qualify_pairs):
      * st_could_qual: stop_times rows for this feed at stops that still have at
        least min_freq * len(both_peaks_hrs) rows after filtering.
      * qualify_pairs: list of (route_dir, other_dir) tuples qualifying in this feed.
    When the feed has no qualifying pairs, an empty slice of st_prepped and an
    empty list are returned.

    Reads the notebook-level names min_freq and both_peaks_hrs, and prints the
    shape of the stop_times frame after each filtering step.
    '''
    #  keys are '<feed>__<route_dir>__<other_dir>'; everything after the feed is the pair
    split_keys = (key.split('__') for key in qualify_dict)
    qualify_pairs = [tuple(parts[1:]) for parts in split_keys if parts[0] == gtfs_dataset_key]
    if not qualify_pairs:
        #  no route_dir pair qualifies in this feed, so by definition no stop can
        return st_prepped.iloc[:0], qualify_pairs
    #  every route_dir that takes part in at least one qualifying pair
    any_appearance = sorted({route_dir for pair in qualify_pairs for route_dir in pair})

    #  only need to check stops that qualify as multi-route only
    stops_to_eval = multi_only_explode >> filter(_.schedule_gtfs_dataset_key == gtfs_dataset_key) >> distinct(_.stop_id)
    feed_st = st_prepped >> filter(_.schedule_gtfs_dataset_key == gtfs_dataset_key,
                                   _.stop_id.isin(stops_to_eval.stop_id),
                                  )
    print(f'{feed_st.shape}')
    st_to_eval = feed_st >> filter(_.route_dir.isin(any_appearance))
    print(f'{st_to_eval.shape}')
    #  cut down problem space by checking if stops still could qual after filtering for any appearance
    min_rows = min_freq * len(both_peaks_hrs)
    st_could_qual = (st_to_eval >> group_by(_.stop_id)
     >> mutate(could_qualify = _.shape[0] >= min_rows)
     >> ungroup()
     >> filter(_.could_qualify)
    )
    print(f'{st_could_qual.shape}')
    return st_could_qual, qualify_pairs

In [None]:
# st_could_qual, qualify_pairs = feed_level_filter(gtfs_dataset_key, multi_only_explode, qualify, st_prepped)

In [None]:
def check_stop(this_stop_route_dirs, qualify_pairs):
    '''
    Find the longest leading subset of route_dirs in which every pair qualifies.

    this_stop_route_dirs: route_dirs serving one stop, in priority order; the
        last entry is dropped first when a subset fails. The input is copied,
        never modified.
    qualify_pairs: collection of (route_dir, other_dir) tuples that may be
        aggregated together.

    Returns the surviving list of route_dirs, or an empty list once only one
    route_dir is left. Progress is printed for each attempt.
    '''
    candidates = list(this_stop_route_dirs)
    if len(candidates) == 1:
        #  a single route_dir has nothing left to pair with
        print('exhausted!')
        return []
    print(f'attempting {candidates}... ', end='')
    #  every 2-combination (in list order) has to be a qualifying pair
    every_pair_ok = all(
        pair in qualify_pairs for pair in itertools.combinations(candidates, 2)
    )
    if not every_pair_ok:
        print('subsetting...')
        #  retry without the last route_dir
        return check_stop(candidates[:-1], qualify_pairs)
    print('matched!')
    return candidates

In [None]:
# check_stop(['no', 'nyet', 'bazz', 'fizz', 'buzz'], qualify_pairs)

In [None]:
def filter_qualifying_stops(one_stop_df, qualify_pairs):
    '''
    Keep only the stop_times rows of one stop whose route_dir passes check_stop.

    one_stop_df: stop_times rows for a single stop (as handed over by
        groupby('stop_id').apply); needs a route_dir column.
    qualify_pairs: (route_dir, other_dir) tuples that may be aggregated together.

    Returns the rows, sorted by descending route_dir_count (rows per route_dir
    at this stop), restricted to the route_dirs check_stop returns.
    '''
    #  number of rows each route_dir contributes at this stop
    counted = (one_stop_df >> group_by(_.route_dir)
               >> mutate(route_dir_count = _.shape[0])
               >> ungroup())
    most_frequent_first = counted >> arrange(-_.route_dir_count)
    #  distinct preserves sort order, so check_stop drops the least frequent route_dir first
    ranked = most_frequent_first >> distinct(_.route_dir, _.route_dir_count)
    ranked_route_dirs = ranked.route_dir.to_numpy()
    keep_route_dirs = check_stop(ranked_route_dirs, qualify_pairs)
    return most_frequent_first >> filter(_.route_dir.isin(keep_route_dirs))

## unify function, try looping over all feeds?

In [None]:
def collinear_filter_feed(
    gtfs_dataset_key: str,
    multi_only_explode: pd.DataFrame,
    qualify_dict: dict,
    st_prepped: pd.DataFrame
):
    '''
    Run the full collinearity screen for one feed.

    Steps:
      1. feed_level_filter narrows stop_times to stops and route_dirs that could qualify.
      2. Per stop, filter_qualifying_stops keeps only route_dirs that may be aggregated together.
      3. Stop-level frequencies are recalculated (multi-route aggregation) and stops below
         min_freq in either peak are dropped.
      4. Stops where every route_dir appears at fewer than SHARED_STOP_THRESHOLD of the
         remaining stops are dropped.

    gtfs_dataset_key: schedule_gtfs_dataset_key of the feed to process.
    multi_only_explode: stops to evaluate, one row per stop and route_dir.
    qualify_dict: share counts keyed by '<feed>__<route_dir>__<other_dir>'.
    st_prepped: prepared stop_times for all feeds.

    Returns the surviving stop-level aggregation with all_short and feed_key columns
    added, or None when no stop_times survive step 2.

    Reads the notebook-level names analysis_date, min_freq and SHARED_STOP_THRESHOLD,
    and prints/displays diagnostics about short routes.
    '''
    #  use the qualify_dict argument (not the notebook global) so callers control the pairs
    st_could_qual, qualify_pairs = feed_level_filter(gtfs_dataset_key, multi_only_explode, qualify_dict, st_prepped)
    st_qual_filter_1 = (st_could_qual.groupby('stop_id')
                        .apply(filter_qualifying_stops, qualify_pairs=qualify_pairs)
                        .reset_index(drop=True))
    if st_qual_filter_1.empty:
        return None
    feed_key = st_qual_filter_1.feed_key.iloc[0]
    trips_per_peak_qual_1 = create_aggregate_stop_frequencies.stop_times_aggregation_max_by_stop(st_qual_filter_1, analysis_date, single_route_dir=False)
    trips_per_peak_qual_1 = trips_per_peak_qual_1 >> filter(_.am_max_trips_hr >= min_freq, _.pm_max_trips_hr >= min_freq)
    #  route_dirs left with fewer than SHARED_STOP_THRESHOLD frequent stops after aggregation
    short_routes = trips_per_peak_qual_1.explode('route_dir') >> count(_.route_dir) >> filter(_.n < SHARED_STOP_THRESHOLD)
    print('short routes, all_short stops:')
    display(short_routes)
    short_route_dirs = set(short_routes.route_dir)  #  built once instead of once per stop
    all_short = trips_per_peak_qual_1.route_dir.map(
        lambda route_dirs: all(route_dir in short_route_dirs for route_dir in route_dirs))
    #  assign returns a new frame, so the filtered frame above is never written into;
    #  astype(bool) keeps the mask boolean even when there are no rows
    trips_per_peak_qual_1 = trips_per_peak_qual_1.assign(all_short=all_short.astype(bool))
    display(trips_per_peak_qual_1 >> filter(_.all_short)) #  stops where _every_ shared route has less than SHARED_STOP_THRESHOLD frequent stops (even after aggregation)
    trips_per_peak_qual_2 = trips_per_peak_qual_1 >> filter(~_.all_short)
    trips_per_peak_qual_2 = trips_per_peak_qual_2.assign(feed_key=feed_key) #  for mapping in dev, can get rid of

    return trips_per_peak_qual_2

In [None]:
# muni_final = collinear_filter_feed(dataset_key, multi_only_explode, qualify, st_prepped)

In [None]:
# lbt = collinear_filter_feed(gtfs_dataset_key, multi_only_explode, qualify, st_prepped)

In [None]:
# %%time 40 seconds (on default user) is not too bad! 
#  collect per-feed results and concatenate once, instead of re-copying a growing frame each iteration
collinear_by_feed = []
for gtfs_dataset_key in feeds_to_filter:
    df = collinear_filter_feed(gtfs_dataset_key, multi_only_explode, qualify, st_prepped)
    if df is not None: #  collinear_filter_feed returns None when nothing survives for a feed
        collinear_by_feed.append(df)
#  most recently processed feed first; empty frame when no feed produced a result
all_collinear = pd.concat(collinear_by_feed[::-1]) if collinear_by_feed else pd.DataFrame()

## Map single result

In [None]:
stops = helpers.import_scheduled_stops(
    analysis_date,
    get_pandas = True,
    crs = PROJECT_CRS
)

stops = stops >> inner_join(_, stop_times>>distinct(_.feed_key, _.schedule_gtfs_dataset_key), on='feed_key')

In [None]:
# gdf = stops >> inner_join(_, lbt, on = ['feed_key', 'stop_id']) >> distinct(_.stop_id, _.geometry)

In [None]:
# gdf.explore()

## Map overall results

In [None]:
min_freqency = min([HQ_TRANSIT_THRESHOLD, MS_TRANSIT_THRESHOLD])

In [None]:
single_qual = single_test2 >> filter(_.am_max_trips_hr >= min_freqency, _.pm_max_trips_hr >= min_freqency)

In [None]:
multi_qual = multi_test2 >> filter(_.am_max_trips_hr >= min_freqency, _.pm_max_trips_hr >= min_freqency, _.route_dir_count > 1)

In [None]:
multi_only = multi_qual >> anti_join(_, single_qual, on=['schedule_gtfs_dataset_key', 'stop_id'])

In [None]:
gdf = (stops >> inner_join(_, multi_only, on = ['stop_id', 'schedule_gtfs_dataset_key'])
       >> mutate(route_dir = _.route_dir.astype(str))
       >> distinct(_.stop_id, _.route_dir, _.am_max_trips_hr,
            _.pm_max_trips_hr, _.geometry)
       
      )

In [None]:
gdf2 = (stops >> inner_join(_, all_collinear, on = ['stop_id', 'schedule_gtfs_dataset_key'])
       >> mutate(route_dir = _.route_dir.astype(str))
       >> distinct(_.stop_id, _.route_dir, _.am_max_trips_hr,
            _.pm_max_trips_hr, _.geometry)
       
      )

In [None]:
gdf3 = (stops >> inner_join(_, single_qual, on = ['stop_id', 'schedule_gtfs_dataset_key'])
       >> mutate(route_dir = _.route_dir.astype(str))
       >> distinct(_.stop_id, _.route_dir, _.am_max_trips_hr,
            _.pm_max_trips_hr, _.geometry)
       
      )

In [None]:
# gdf.explore()

m = gdf.explore(color='orange')

In [None]:
m = gdf3.explore(color='blue', m=m)

In [None]:
m = gdf2.explore(m = m, color='red')

In [None]:
folium.LayerControl().add_to(m);

In [None]:
m #  8 threshold

In [None]:
gdf.shape

In [None]:
gdf2.shape

In [None]:
gdf3.shape

In [None]:
all_collinear

In [None]:
single_qual.min()

In [None]:
all_collinear.am_max_trips_hr.min()

In [None]:
all_collinear.pm_max_trips_hr.min()