# Interpolation method - break down steps to debug
* When finding the nearest vp, some stops end up with out-of-order arrival times: a later stop gets an earlier arrival time than the stop before it.
   * E.g., stop 1 "arrives" after stop 2. This can't happen in real life.
   * Filter down to these stops and try again.
   * How do we distinguish whether stop 1 or stop 2 is the "correct" one?
* Want to add direction to deal with loop or inlining shapes --> maybe this can be expanded to all shapes
   * Do not even allow opposite direction vp to be selected as nearest vp

In [None]:
import dask.dataframe as dd
import numpy as np
import pandas as pd

from segment_speed_utils import helpers, wrangle_shapes
from segment_speed_utils.project_vars import SEGMENT_GCS, PROJECT_CRS 
                                              
from shared_utils import rt_dates

# Analysis date is looked up by key in the shared rt_dates table.
analysis_date = rt_dates.DATES["sep2023"]

# Parameters stored under "stop_segments" in the pipeline config file.
STOP_SEG_DICT = helpers.get_parameters("./scripts/config.yml", "stop_segments")

## Between stops, how to find stops behaving not as expected
There are erroneous calculations here.

The prior stop's arrival time can't take place **after** the current stop's arrival time.

Ex: `trip_instance_key == "0000412e74a0e91993ce66bcbc4e3e73"`, stop 1's nearest position is after stop 2's nearest position.

Handle these along with loop or inlining shapes.

In [None]:
# File stems combine a config stage name with the analysis date.
NEAREST_VP = f"{STOP_SEG_DICT['stage2']}_error_{analysis_date}"
STOP_ARRIVALS = f"{STOP_SEG_DICT['stage3']}_{analysis_date}"

# Nearest-vp table; the cells below read its error_* flag columns.
df = pd.read_parquet(f"{SEGMENT_GCS}projection/{NEAREST_VP}.parquet")

In [None]:
# Only the merge keys and the arrival time are read from the arrivals table.
stop_arrivals = pd.read_parquet(
    f"{SEGMENT_GCS}{STOP_ARRIVALS}.parquet",
    columns=["trip_instance_key", "stop_sequence", "arrival_time"],
)

In [None]:
# How many rows carry each value of the arrival-order error flag.
df["error_arrival_order"].value_counts()

In [None]:
# How many rows carry each value of the same-endpoints error flag.
df["error_same_endpoints"].value_counts()

In [None]:
# Size of the subset where both error flags are set on the same row.
df.loc[
    (df["error_same_endpoints"] == 1) & (df["error_arrival_order"] == 1)
].shape

In [None]:
# Per-trip mean of each error flag, with the trip key restored as a column.
trip_stats = (
    df
    .groupby("trip_instance_key", observed=True, group_keys=False)
    [["error_same_endpoints", "error_arrival_order"]]
    .agg("mean")
    .reset_index()
)

In [None]:
# Very few trips are completely error-free
# (a per-trip mean of 0 on both flags means no row in the trip was flagged).
trip_stats.loc[
    (trip_stats["error_same_endpoints"] == 0)
    & (trip_stats["error_arrival_order"] == 0)
].shape

In [None]:
# Fixed list of trip keys used for the per-trip charts further down.
# Presumably drawn once with
#   trip_stats.sample(10).trip_instance_key.unique()
# and pinned here so the charts are reproducible -- TODO confirm.
subset_trip_keys = [
    "9fad69264acd8387150f45b27d4b2d09",
    "44a55d2fa2588a479065ef7702475ef1",
    "36070a2428e62b96368d072eb2a8fc1b",
    "7f665900c6b0879f4b9bda43b93fefe3",
    "8e8ba9993d52388539d06a46710c1dbc",
    "b301c2170c1ca49bbc1a9b600cccf643",
    "9373f5b0de977a718dea50fd90443619",
    "8415b3949147c9dc3d5ceb37863440b1",
    "984f598419c1d0830ef4618d495c1bd7",
    "815e4dd921cdcb61ad2dbb1ca5f08a39",
]

In [None]:
def check_if_surrounding_points_are_ok(df: pd.DataFrame) -> pd.DataFrame:
    """
    Flag arrival-order errors whose neighboring rows in the same trip
    have no arrival-order error.

    Adds three columns and returns a new DataFrame (the input is not
    modified):

    * prior_error: `error_arrival_order` of the previous row within the
      same `trip_instance_key` (NaN for a trip's first row).
    * subseq_error: `error_arrival_order` of the next row within the same
      `trip_instance_key` (NaN for a trip's last row).
    * can_be_fixed: 1 when the row itself has `error_arrival_order == 1`
      and both neighbors are 0; otherwise 0. Rows at the start or end of
      a trip are never flagged because one neighbor is missing.

    "Previous" and "next" follow the row order of `df` as passed in.
    NOTE(review): this assumes rows are already ordered by stop_sequence
    within each trip -- confirm against the caller.
    """
    error_by_trip = df.groupby(
        "trip_instance_key", observed=True, group_keys=False
    )["error_arrival_order"]

    df = df.assign(
        prior_error=error_by_trip.shift(1),
        subseq_error=error_by_trip.shift(-1),
    )

    # Vectorised replacement for a row-wise apply: comparisons against NaN
    # are False, so edge rows fall out as 0. This also works on an empty
    # frame, where DataFrame.apply(axis=1) would not return a Series.
    is_isolated_error = (
        (df["error_arrival_order"] == 1)
        & (df["prior_error"] == 0)
        & (df["subseq_error"] == 0)
    )

    return df.assign(can_be_fixed=is_isolated_error.astype("int64"))
    

In [None]:
# Attach arrival_time to each nearest-vp row; rows without a matching
# (trip, stop_sequence) in stop_arrivals are dropped by the inner join.
merge_keys = ["trip_instance_key", "stop_sequence"]
df2 = pd.merge(left=df, right=stop_arrivals, on=merge_keys, how="inner")
del merge_keys

In [None]:
# Add prior_error / subseq_error / can_be_fixed columns to the merged table.
df3 = check_if_surrounding_points_are_ok(df2)

In [None]:
# Number of rows flagged with an arrival-order error.
df3.loc[df3["error_arrival_order"] == 1].shape

In [None]:
# Number of arrival-order errors whose previous and next rows are error-free.
df3.loc[
    (df3["error_arrival_order"] == 1)
    & (df3["prior_error"] == 0)
    & (df3["subseq_error"] == 0)
].shape

In [None]:
# Inspect every row of a single trip.
df3.loc[df3["trip_instance_key"] == "0001ad7e1ef246cf6d68599de0fdcaad"]

Nearly half of the arrival errors are surrounded by stops where arrival time is correct.

Let's hit this first and then deal with endpoints.

If almost every trip has errors, it matters where in the trip the arrival order error occurs. We must take the minimum stop sequence of a clearly monotonic period and use it as the maximum, filtering out any vp_idx that occur after it.

In [None]:
import altair as alt

# For each sampled trip, draw two line charts against stop_sequence:
#   1. the arrival-order error flag for every row of the trip;
#   2. the same-endpoints error flag, restricted to rows with no
#      arrival-order error.
for t in subset_trip_keys:
    subset_df = df[df.trip_instance_key == t]

    chart = alt.Chart(subset_df).mark_line().encode(
        x="stop_sequence",
        y="error_arrival_order",
    ).properties(title=f"{t}")
    display(chart)

    chart2 = alt.Chart(
        subset_df[subset_df.error_arrival_order == 0]
    ).mark_line().encode(
        x="stop_sequence",
        y="error_same_endpoints",
    )
    display(chart2)

In [None]:
# One row per trip, with that trip's nearest_vp_idx values gathered into a
# list in the row order of df.
# NOTE(review): assumes df is ordered by stop_sequence within each trip.
nearest_array = (
    df
    .groupby("trip_instance_key")
    .agg({"nearest_vp_idx": list})
    .reset_index()
)

In [None]:
# Differences between consecutive nearest_vp_idx values for each trip.
nearest_array = nearest_array.assign(
    array_diff=nearest_array.apply(
        lambda row: np.ediff1d(np.asarray(row["nearest_vp_idx"])),
        axis=1,
    )
)

In [None]:
# Flag trips (1/0) where any consecutive nearest_vp_idx difference is
# negative, i.e. the nearest vp_idx goes backwards between two rows.
# The previous form, len(np.where(diff < 0)[0] > 0), had a misplaced
# parenthesis: it measured the length of a boolean mask rather than
# comparing a length to 0. The result was the same only by accident;
# `.any()` states the intent directly.
nearest_array = nearest_array.assign(
    wrong_times=nearest_array.apply(
        lambda row: 1 if (row.array_diff < 0).any() else 0,
        axis=1,
    )
)

In [None]:
# Count of trips with / without a backwards step in nearest_vp_idx.
nearest_array["wrong_times"].value_counts()

## Index into specific portions of array

In [None]:
# Single trip_instance_key used as the parquet filter in the cells below.
one_trip = "bf87a17838cdaff5ba78fb70edd4f1bb"

In [None]:
def rt_trips_to_shape(
    analysis_date: str,
    trip_instance_key: "str | None" = None,
) -> pd.DataFrame:
    """
    Build a trip-to-shape crosswalk for one trip found in vp_usable.

    Steps:
    1. Read `vp_usable_{analysis_date}` filtered to a single
       trip_instance_key and keep the distinct key(s).
    2. Inner-merge onto scheduled trips (via
       `helpers.import_scheduled_trips`) to get `shape_array_key`.
    3. Inner-merge onto the distinct
       (shape_array_key, loop_or_inlining, stop_primary_direction)
       combinations from `stops_projected_{analysis_date}.parquet`.

    Parameters
    ----------
    analysis_date
        Date string used in the parquet file names.
    trip_instance_key
        Trip to filter on. When None (the default), the notebook-level
        variable `one_trip` is used, so calls that pass only
        `analysis_date` behave as before.

    Returns
    -------
    DataFrame with columns shape_array_key, loop_or_inlining,
    stop_primary_direction, trip_instance_key. Because
    stop_primary_direction is part of the de-duplication, a trip can
    appear on several rows (one per distinct direction on its shape).
    """
    # Resolved at call time so a later reassignment of `one_trip` is honored.
    trip_key = one_trip if trip_instance_key is None else trip_instance_key

    # Get RT trips
    rt_trips = pd.read_parquet(
        f"{SEGMENT_GCS}vp_usable_{analysis_date}",
        filters = [[("trip_instance_key", "==", trip_key)]],
        columns = ["trip_instance_key"]
    ).drop_duplicates()

    # Find the shape_array_key for RT trips
    trip_to_shape = helpers.import_scheduled_trips(
        analysis_date,
        columns = ["trip_instance_key", "shape_array_key"],
        get_pandas = True
    ).merge(
        rt_trips,
        on = "trip_instance_key",
        how = "inner"
    )

    # Find whether it's loop or inlining
    shapes_loop_inlining = pd.read_parquet(
        f"{SEGMENT_GCS}stops_projected_{analysis_date}.parquet",
        columns = [
            "shape_array_key", "loop_or_inlining",
            "stop_primary_direction",
        ]
    ).drop_duplicates().merge(
        trip_to_shape,
        on = "shape_array_key",
        how = "inner"
    )

    return shapes_loop_inlining

In [None]:
# Crosswalk of trip_instance_key -> shape_array_key (plus loop_or_inlining
# and stop_primary_direction) for the trip selected by `one_trip`.
trip_shape_crosswalk = rt_trips_to_shape(analysis_date)

In [None]:
# Vehicle positions for the selected trip: row id plus direction only.
vp = pd.read_parquet(
    f"{SEGMENT_GCS}vp_usable_{analysis_date}",
    filters=[[("trip_instance_key", "==", one_trip)]],
    columns=["trip_instance_key", "vp_idx", "vp_primary_direction"],
)

In [None]:
# Distinct vp_idx values for the selected trip, used as a parquet filter next.
subset_vp = vp["vp_idx"].unique()

In [None]:
# Projected positions, restricted to the vp_idx values of the selected trip.
projected_shape_meters = pd.read_parquet(
    f"{SEGMENT_GCS}projection/vp_projected_{analysis_date}.parquet",
    filters=[[("vp_idx", "in", subset_vp)]],
)

# vp rows + their projection + the trip's shape attributes.
# NOTE(review): trip_shape_crosswalk may hold several rows per trip (one per
# distinct stop_primary_direction), in which case this merge repeats each
# vp row once per crosswalk row -- confirm that is intended.
vp_with_projection = (
    pd.merge(vp, projected_shape_meters, on="vp_idx", how="inner")
    .merge(trip_shape_crosswalk, on="trip_instance_key", how="inner")
)

In [None]:
# Shapes used by the selected trip.
shape_keys = trip_shape_crosswalk["shape_array_key"].unique()

# Stops on those shapes; shape_meters is renamed to stop_meters so it does
# not collide with the shape_meters column carried by the vp table.
stops_projected = pd.read_parquet(
    f"{SEGMENT_GCS}stops_projected_{analysis_date}.parquet",
    filters=[[("shape_array_key", "in", shape_keys)]],
    columns=[
        "shape_array_key",
        "stop_sequence",
        "stop_id",
        "shape_meters",
        "stop_primary_direction",
    ],
).rename(columns={"shape_meters": "stop_meters"})

In [None]:
trip_shape_cols = ["trip_instance_key", "shape_array_key"]

# Collapse each trip/shape pair to one row, gathering its vp_idx,
# shape_meters and direction values into parallel lists (*_arr columns).
trip_info = (
    vp_with_projection
    .groupby(trip_shape_cols, observed=True, group_keys=False)
    .agg({
        "vp_idx": list,
        "shape_meters": list,
        "vp_primary_direction": list,
    })
    .rename(columns={
        "vp_idx": "vp_idx_arr",
        "shape_meters": "shape_meters_arr",
        "vp_primary_direction": "vp_dir_arr",
    })
    .reset_index()
)

In [None]:
# Pair every stop on the shape with the trip's full vp arrays.
vp_to_stop = pd.merge(
    left=trip_info,
    right=stops_projected,
    on="shape_array_key",
    how="inner",
)

In [None]:
# Direction and position along the shape for the first row of vp_to_stop.
this_stop_direction = vp_to_stop["stop_primary_direction"].iloc[0]
this_stop_meters = vp_to_stop["stop_meters"].iloc[0]

In [None]:
# Display the first stop's direction.
this_stop_direction

In [None]:
# Look up the direction mapped to this stop's direction in OPPOSITE_DIRECTIONS.
opposite_to_stop_direction = wrangle_shapes.OPPOSITE_DIRECTIONS[this_stop_direction]

In [None]:
# Display the direction treated as opposite to the stop's direction.
opposite_to_stop_direction

In [None]:
# Pull the first row's list columns out as numpy arrays.
# NOTE(review): despite its name, vp_meters_array holds the vp_idx values
# (from vp_idx_arr), not meters; shape_meters_array holds the meters.
vp_dir_array = np.asarray(vp_to_stop["vp_dir_arr"].iloc[0])
shape_meters_array = np.asarray(vp_to_stop["shape_meters_arr"].iloc[0])
vp_meters_array = np.asarray(vp_to_stop["vp_idx_arr"].iloc[0])

In [None]:
# Positions of vp whose direction is NOT the opposite of the stop's
# direction; nonzero() on the boolean mask returns a tuple of index arrays.
# https://stackoverflow.com/questions/16094563/numpy-get-index-where-value-is-true
not_opposite_mask = vp_dir_array != opposite_to_stop_direction
valid_vp_idx_indices = not_opposite_mask.nonzero()
del not_opposite_mask

In [None]:
# Display the vp_idx values that survive the direction filter
# (vp_meters_array holds vp_idx values, see where it is built).
vp_meters_array[valid_vp_idx_indices]

In [None]:
# Print the direction-filtered vp_idx values for the last row of vp_to_stop.
# NOTE(review): this rebinds `this_stop_meters` (previously the first stop's
# stop_meters) to an array of vp_idx values, so the name no longer matches
# its contents. The filter indices were also computed from the first row's
# stop direction, not the last row's.
for row in vp_to_stop.tail(1).itertuples():
    this_stop_meters = np.asarray(row.vp_idx_arr)[valid_vp_idx_indices]
    print(this_stop_meters)

## Plot interpolated arrivals -> speed

In [None]:
# Interpolated stop arrivals with speeds.
# NOTE: this rebinds `df`, replacing the nearest-vp table loaded earlier.
speed_path = f"{SEGMENT_GCS}stop_arrivals_speed_{analysis_date}_2.parquet"
df = pd.read_parquet(speed_path)
del speed_path

In [None]:
# Histogram of speeds below 80 mph in 5 mph bins.
# The bin edges must reach 80 to match the filter: with edges stopping at 70
# (range(0, 75, 5)), rows between 70 and 80 mph passed the filter but fell
# outside every bin and were silently left out of the plot.
bins = range(0, 85, 5)

df[df.speed_mph < 80].hist(
    "speed_mph",
    bins=bins,
)