In [1]:
import os
os.environ["CALITP_BQ_MAX_BYTES"] = str(12_000_000_000_000)
os.environ['USE_PYGEOS'] = '0'

import pandas as pd
from siuba import *

from calitp_data_analysis.sql import query_sql
from calitp_data_analysis.tables import tbls
import shared_utils
import datetime as dt
import sample_query_materialized_tables as smpl

In [2]:
from segment_speed_utils.project_vars import (PREDICTIONS_GCS, 
                                              analysis_date)

In [3]:
import pytz

import numpy as np

In [4]:
PREDICTIONS_GCS

'gs://calitp-analytics-data/data-analyses/rt_predictions/'

In [7]:
summarized_df = pd.read_parquet(f"{PREDICTIONS_GCS}st_advance_samples_summarized_2023-03-15.parquet")

In [8]:
summarized_df >> head(3)

Unnamed: 0,trip_id,organization_name,route_type,max_advance_min,updates_per_min,sample_period
0,10002011240802-DEC22,Los Angeles County Metropolitan Transportation...,3,44.0,3.0,am
1,10002011240812-DEC22,Los Angeles County Metropolitan Transportation...,3,45.0,2.9,am
2,10002011240822-DEC22,Los Angeles County Metropolitan Transportation...,3,45.0,2.9,am


In [6]:
# analysis_date presumably arrives from project_vars as an ISO date string;
# parse it only once so this cell is safe to re-run
# (datetime.fromisoformat rejects an already-parsed datetime).
if isinstance(analysis_date, str):
    analysis_date = dt.datetime.fromisoformat(analysis_date)
analysis_date

datetime.datetime(2023, 3, 15, 0, 0)

In [None]:
service_levels = smpl.get_service_levels()
tu_datasets = smpl.get_tu_datasets()
all_data_service = smpl.filter_join_datasets_service(tu_datasets, service_levels)
chunks = smpl.chunk_by_svc_hours(all_data_service)

In [None]:
chunks[0]

## Get and filter scheduled trips

In [None]:
sampling_periods = smpl.sampling_periods
sampling_periods

In [None]:
time = sampling_periods['am'][0].time()

In [None]:
def time_to_sec(time):
    """Convert a ``datetime.time`` to seconds since midnight.

    The result is compared against ``trip_first_departure_sec`` in
    ``get_schedule_df``. The ``second`` component is included so that times
    that are not on a whole minute are not truncated.
    """
    return time.hour * 3600 + time.minute * 60 + time.second

In [None]:
time_to_sec(time)

In [None]:
def get_schedule_df(sampling_period, chunk_df, service_date=None):
    """Query scheduled trips that start within a sampling period for one chunk of feeds.

    Parameters
    ----------
    sampling_period : sequence of two datetime-like objects
        ``(start, end)``; only each element's ``.time()`` is used. Trips whose
        ``trip_first_departure_sec`` falls in the half-open window
        ``[start, end)`` are kept.
    chunk_df : pandas.DataFrame
        One chunk from ``smpl.chunk_by_svc_hours``. It must have
        ``associated_schedule_gtfs_dataset_key``, ``tu_base64_url`` and
        ``organization_name`` columns.
    service_date : optional
        Value compared against ``activity_date``. Defaults to the
        notebook-level ``analysis_date``.

    Returns
    -------
    pandas.DataFrame
        Scheduled trips (name, trip_id, activity_date, feed_key, route_type)
        inner-joined to the chunk's trip-updates URL and organization name,
        with ``gtfs_dataset_key`` removed.
    """
    if service_date is None:
        service_date = analysis_date

    start_sec = time_to_sec(sampling_period[0].time())
    end_sec = time_to_sec(sampling_period[1].time())
    # Plain, de-duplicated list keeps the SQL IN (...) clause compact.
    schedule_keys = (chunk_df.associated_schedule_gtfs_dataset_key
                     .dropna().unique().tolist())

    daily_trips = (tbls.mart_gtfs.fct_daily_scheduled_trips()
                   # Half-open window: a trip departing exactly at the period
                   # start belongs to this period, not to neither period.
                   >> filter(_.trip_first_departure_sec >= start_sec,
                             _.trip_first_departure_sec < end_sec)
                   >> filter(_.gtfs_dataset_key.isin(schedule_keys))
                   >> filter(_.activity_date == service_date)
                   >> select(_.name, _.gtfs_dataset_key, _.trip_id,
                             _.activity_date, _.feed_key, _.route_type)
                   >> collect()
                  )

    chunk_tu_url_df = chunk_df >> select(_.tu_base64_url,
                                         _.associated_schedule_gtfs_dataset_key,
                                         _.organization_name)
    daily_trips = daily_trips >> inner_join(
        _, chunk_tu_url_df,
        on={'gtfs_dataset_key': 'associated_schedule_gtfs_dataset_key'})
    return daily_trips >> select(-_.gtfs_dataset_key)
    

In [None]:
am_0_sched = get_schedule_df(sampling_periods['am'], chunks[0])

In [None]:
am_0_sched >> head(3)

In [None]:
am_6_sched = get_schedule_df(sampling_periods['am'], chunks[6])

In [None]:
am_3_sched = get_schedule_df(sampling_periods['am'], chunks[3])

## Read sampled trip updates back in and join them with the schedule

In [None]:
def get_period_chunk(sample_period, chunk, date='2023-03-15', n_chunks=7):
    """Read one chunk of sampled stop-time updates back from GCS.

    Parameters
    ----------
    sample_period : {'am', 'mid', 'pm'}
        Sampling period whose folder to read.
    chunk : int
        Chunk index; must be in ``range(n_chunks)``.
    date : str
        ISO date embedded in the GCS folder name. Defaults to the date this
        notebook was written against.
    n_chunks : int
        Number of chunks written per sampling period.

    Returns
    -------
    pandas.DataFrame
        Contents of ``chunk_{chunk}.parquet`` for that period and date.

    Raises
    ------
    AssertionError
        If ``sample_period`` or ``chunk`` is out of range.
    """
    assert sample_period in ('am', 'mid', 'pm'), \
        f'unknown sample_period: {sample_period!r}'
    assert chunk in range(n_chunks), \
        f'chunk must be in range({n_chunks}), got {chunk!r}'
    path = (f"{PREDICTIONS_GCS}st_updates_{date}_{sample_period}_sample/"
            f"chunk_{chunk}.parquet")
    return pd.read_parquet(path)

In [None]:
am_6 = get_period_chunk('am', 6)

In [None]:
am_3 = get_period_chunk('am', 3)

In [None]:
am_0 = get_period_chunk('am', 0)

In [None]:
am_0 >> head(3)

In [None]:
def join_trips(sched_chunk_df, tu_chunk_df):
    """Attach schedule attributes to sampled trip updates.

    Inner-joins the trip updates to the schedule on
    ``trip_id == trip_id`` and ``base64_url == tu_base64_url``. Only
    trip-update rows with a matching scheduled trip are kept. Each kept row
    gains the schedule's ``feed_key``, ``organization_name`` and
    ``route_type``.
    """
    sched_subset = select(sched_chunk_df,
                          _.trip_id, _.tu_base64_url, _.feed_key,
                          _.organization_name, _.route_type)
    join_keys = {'trip_id': 'trip_id', 'base64_url': 'tu_base64_url'}
    return inner_join(tu_chunk_df, sched_subset, on=join_keys)

In [None]:
tu_sched_joined.columns

In [None]:
def check_subset_first_stop(tu_sched_joined):
    '''
    Keep only each trip's first-stop trip-update rows.

    If any ``stop_sequence`` is missing from the trip updates, query
    ``dim_stop_times`` for this subset of feeds and trips. Fill in only the
    missing values from the schedule's ``stop_sequence``, matched on
    feed_key, trip_id and stop_id. Then, per (base64_url, trip_id), keep the
    rows at the minimum ``stop_sequence``.

    Parameters
    ----------
    tu_sched_joined : pandas.DataFrame
        Output of ``join_trips``. Needs base64_url, feed_key, trip_id,
        stop_id and stop_sequence columns.

    Returns
    -------
    pandas.DataFrame
        All trip-update rows at the first stop_sequence of each trip. Rows
        whose stop_sequence is still missing after the fill are dropped,
        because NaN never equals the group minimum.
    '''
    if tu_sched_joined.stop_sequence.isna().any():
        print('filling in some stop_sequence from schedule')
        dim_st = (tbls.mart_gtfs.dim_stop_times()
                  >> select(_.feed_key, _.trip_id, _.stop_id,
                            _.st_stop_sequence == _.stop_sequence)
                  >> filter(_.feed_key.isin(
                      tu_sched_joined.feed_key.unique().tolist()))
                  >> filter(_.trip_id.isin(
                      tu_sched_joined.trip_id.unique().tolist()))
                  >> collect()
                 )

        # Left join: rows whose stop_id is not in the schedule keep their own
        # stop_sequence instead of being dropped ("fill in only missing").
        # NOTE(review): a stop visited twice in one trip (e.g. loops) matches
        # two schedule rows and duplicates that trip-update row.
        joined = tu_sched_joined >> left_join(_, dim_st,
                                              on={'feed_key': 'feed_key',
                                                  'trip_id': 'trip_id',
                                                  'stop_id': 'stop_id'})

        # important to use sequence from trip updates if present
        joined['stop_sequence'] = joined['stop_sequence'].fillna(joined['st_stop_sequence'])
    else:
        joined = tu_sched_joined

    first_tu_stops = (joined
                      # GTFS trip_id is only unique within a feed, so group by
                      # the same (feed URL, trip) pair used in join_trips.
                      >> group_by(_.base64_url, _.trip_id)
                      >> filter(_.stop_sequence == _.stop_sequence.min())
                      >> ungroup()
                     )
    return first_tu_stops

In [None]:
pacific = pytz.timezone('US/Pacific')

In [None]:
def localize_if_provided(ts):
    """Attach the US/Pacific zone to ``ts``; missing values (None/NaN/NaT) are returned as-is."""
    if pd.isna(ts):
        return ts
    return pacific.localize(ts)

In [None]:
end_ts = pacific.localize(dt.datetime.combine(analysis_date, dt.time(23, 59)))

In [None]:
def add_tz_choose_col(first_stop_df):
    """Express trip-update and predicted times in US/Pacific.

    Drops rows with no ``trip_update_timestamp``. Adds ``tu_ts_pacific``,
    which is ``trip_update_timestamp`` converted with ``astimezone(pacific)``.
    Localizes ``arrival_time_pacific`` and ``departure_time_pacific`` to
    Pacific, leaving missing values missing.

    NOTE(review): assumes trip_update_timestamp is tz-aware and the
    arrival/departure columns are naive (pytz ``localize`` requires naive
    input). Confirm against the upstream table.

    Returns a new DataFrame; the input is not modified.
    """
    # Drop first so astimezone is never called on a missing timestamp
    # (e.g. None has no .astimezone).
    with_ts = first_stop_df.dropna(subset=['trip_update_timestamp'])

    return (with_ts
            >> mutate(tu_ts_pacific=_.trip_update_timestamp.apply(
                lambda x: x.astimezone(pacific)))
            >> mutate(arrival_time_pacific=_.arrival_time_pacific.apply(localize_if_provided))
            >> mutate(departure_time_pacific=_.departure_time_pacific.apply(localize_if_provided))
           )

In [None]:
def calculate_advance_time(first_stop_df, fill_ts=None):
    """Compute how far ahead of the predicted first-stop time each update was sent.

    The predicted time for a row is the earlier of ``departure_time_pacific``
    and ``arrival_time_pacific``, using whichever of the two is present.
    ``time_in_advance`` is that predicted time minus ``tu_ts_pacific``.

    Parameters
    ----------
    first_stop_df : pandas.DataFrame
        Needs arrival_time_pacific, departure_time_pacific and tu_ts_pacific
        columns (as produced by ``add_tz_choose_col``).
    fill_ts : timestamp, optional
        Placeholder used only inside the min, so that a single missing time
        is ignored. Defaults to the notebook-level ``end_ts`` (23:59 Pacific
        on the analysis date).

    Returns
    -------
    pandas.DataFrame
        A copy of ``first_stop_df`` with ``min_arr_dep_pacific`` and
        ``time_in_advance`` added. Rows with both times missing get NaT, so
        the placeholder cannot inflate the max advance. The input frame and
        its arrival/departure columns are left unchanged.
    """
    if fill_ts is None:
        fill_ts = end_ts
    time_cols = ['departure_time_pacific', 'arrival_time_pacific']

    # Work on a copy: the caller's frame (displayed in an earlier cell) must not change.
    df = first_stop_df.copy()
    both_missing = df[time_cols].isna().all(axis=1)

    # Fill a temporary frame only, so a missing value never wins the min
    # and the real prediction columns are not overwritten.
    filled = df[time_cols].fillna(fill_ts)
    df['min_arr_dep_pacific'] = filled.values.min(axis=1)
    df['min_arr_dep_pacific'] = df['min_arr_dep_pacific'].mask(both_missing, pd.NaT)

    df['time_in_advance'] = df['min_arr_dep_pacific'] - df['tu_ts_pacific']
    return df

In [None]:
tu_sched_joined = join_trips(am_0_sched, am_0)

In [None]:
# tu_sched_joined = join_trips(am_3_sched, am_3)

In [None]:
# tu_sched_joined = join_trips(am_6_sched, am_6)

In [None]:
first_stop = check_subset_first_stop(tu_sched_joined)

In [None]:
with_tz = add_tz_choose_col(first_stop)

In [None]:
advance_calculated = calculate_advance_time(with_tz)

In [None]:
# advance_calculated

In [None]:
# Per trip: the largest lead time of any first-stop update (minutes), and
# the number of update rows per minute of that lead time.
df2 = (advance_calculated
       >> group_by(_.trip_id, _.organization_name, _.route_type)
       >> mutate(max_advance = _.time_in_advance.max())
       # total_seconds() keeps days and sign; Timedelta.seconds is only the
       # 0-86399 s component (a -1 min advance would read as 1439 min).
       >> mutate(max_advance_min = _.max_advance.apply(lambda x: x.total_seconds() / 60))
       >> mutate(updates_per_min = _.shape[0] / _.max_advance_min)
       >> summarize(max_advance_min = np.round(_.max_advance_min.max(), 0),
                   updates_per_min = np.round(_.updates_per_min.max(), 1))
      )

In [None]:
df2 >> head(3)

## sandbox -- plot

In [None]:
# The summarize above already produces max_advance_min. Only rebuild it when
# working with an older frame that still carries the raw max_advance column.
if 'max_advance' in df2.columns:
    df2['max_advance_min'] = df2['max_advance'].apply(lambda x: x.total_seconds() / 60)

In [None]:
# Drop the raw timedelta column if present; it is absent after the summarize.
df2 = df2.drop(columns=['max_advance'], errors='ignore')

In [None]:
import altair as alt

In [None]:
df2 >> head(3)

In [None]:
alt.Chart(df2).mark_point().encode(x='max_advance_min')

In [None]:
df2.max_advance_min.median()

In [None]:
alt.Chart(df2).mark_bar().encode(
    alt.X('max_advance_min', bin=True),
    alt.Y('count()'),
    alt.Color('organization_name')
).interactive()

In [None]:
chart = alt.Chart(df2).mark_point().encode(
    x='organization_name',
    y='max_advance_min',
    color='organization_name',
    tooltip='trip_id'
).interactive()

In [None]:
import chart_utils

In [None]:
df2 >> group_by(_.organization_name) >> summarize(med_adv = _.max_advance_min.median())

In [None]:
chart_utils.chart_size(chart)

### sandbox

## Only Caltrain is missing stop_sequence...

In [None]:
for period in ['am', 'mid', 'pm']:
    for chunk in range(7):
        _df = get_period_chunk(period, chunk)
        print(f'{period}, {chunk}')
        print(_df.stop_sequence.isna().value_counts())
        del(_df)

In [None]:
df = get_period_chunk('am', 6)

In [None]:
df.stop_sequence.isna().value_counts()[True]

In [None]:
df = df >> mutate(no_seq = _.stop_sequence.isna())

In [None]:
df >> count(_.no_seq)

In [None]:
df2 = df >> group_by(_.base64_url) >> summarize(any_no_seq = _.no_seq.any()) >> filter(_.any_no_seq)

In [None]:
df2.base64_url.iloc[0]

In [None]:
chunks[6] >> filter(_.tu_base64_url == df2.base64_url.iloc[0])

# Sandbox

## schedule sandbox, match example

In [None]:
import pytz

In [None]:
pacific = pytz.timezone('US/Pacific')

In [None]:
bbb_ix_df = shared_utils.rt_utils.get_speedmaps_ix_df(analysis_date, 300)

In [None]:
bbb_ix_df

In [None]:
bbb_trips = shared_utils.rt_utils.get_trips(bbb_ix_df)

In [None]:
bbb_st = shared_utils.rt_utils.get_st(bbb_ix_df, bbb_trips)

In [None]:
bbb_st.arrival_time.iloc[0].split(':')

In [None]:
shared_utils.rt_utils.show_full_df(bbb_st >> group_by(_.trip_id) >> summarize(min_arrival = _.arrival_time.min()))

In [None]:
# df = (
#     tbls.mart_ad_hoc.fct_stop_time_updates_20230315_to_20230321()
#     >> filter(_.service_date == analysis_date,
#               # _.trip_id == '894836',
#               _.arrival_time_pacific >= pac_ts,
#               _.arrival_time_pacific < pac_ts2,
#              _.base64_url == bbb_base64)
#     >> select(_.arrival_time_pacific, _.departure_time_pacific,
#              _.key, _.gtfs_dataset_key, _.base64_url,
#               _._extract_ts, _.trip_update_timestamp, _.trip_id,
#              _.stop_sequence, _.stop_id, _.service_date)
#     >> collect()
# )

In [None]:
df.base64_url.iloc[0]

In [None]:
df.columns

In [None]:
df.to_parquet('bbb_example.parquet')

In [None]:
df.stop_sequence.isna().value_counts()

In [None]:
filtered._extract_ts.iloc[0]

In [None]:
filtered = df >> filter(_.trip_id == '893051', _.stop_sequence == 2) >> arrange(_._extract_ts)
filtered = filtered >> mutate(arrival_time_pacific = _.arrival_time_pacific.apply(lambda x: pacific.localize(x)))
filtered = filtered >> mutate(time_in_advance = _.arrival_time_pacific - _._extract_ts)
filtered = filtered >> select(-_.key, -_.gtfs_dataset_key, -_.base64_url)

In [None]:
## Yes, this is the pattern we want to see!
## Could we use a tighter arrival-time filter, e.g. 30-minute periods?

In [None]:
shared_utils.rt_utils.show_full_df(filtered)

In [None]:
## Sampling approach: use 8 AM, noon, and 11 PM for all operators?

In [None]:
# Localize to Pacific so the .timestamp() epoch seconds computed below mean
# 8:00 Pacific, not 8:00 in whatever timezone the kernel machine runs in.
pac_ts = pacific.localize(dt.datetime.combine(analysis_date, dt.time(8, 0)))

In [None]:
pac_ts_sec = int(pac_ts.timestamp())

In [None]:
pac_ts2 = pac_ts + dt.timedelta(hours = 1)

In [None]:
pac_ts2_sec = int(pac_ts2.timestamp())